From df9a2294aaf0eff2665632f346e251b0cfd090ee Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Mon, 19 Jan 2026 15:28:50 +0100 Subject: [PATCH 01/21] user settings update: init moved to providers --- backend/app/core/providers.py | 18 +++++++++-- backend/app/services/user_settings_service.py | 32 ++++++------------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 6ce30a01..ce1cc071 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -50,7 +50,7 @@ from app.services.admin import AdminEventsService, AdminSettingsService, AdminUserService from app.services.auth_service import AuthService from app.services.coordinator.coordinator import ExecutionCoordinator -from app.services.event_bus import EventBusManager +from app.services.event_bus import EventBusEvent, EventBusManager from app.services.event_replay.replay_service import EventReplayService from app.services.event_service import EventService from app.services.execution_service import ExecutionService @@ -473,8 +473,20 @@ async def get_user_settings_service( event_bus_manager: EventBusManager, logger: logging.Logger, ) -> UserSettingsService: - service = UserSettingsService(repository, kafka_event_service, logger) - await service.initialize(event_bus_manager) + service = UserSettingsService(repository, kafka_event_service, logger, event_bus_manager) + + # Subscribe to settings update events for cross-instance cache invalidation. + # EventBus filters out self-published messages, so this handler only + # runs for events from OTHER instances. + bus = await event_bus_manager.get_event_bus() + + async def _handle_settings_update(evt: EventBusEvent) -> None: + uid = evt.payload.get("user_id") + if uid: + await service.invalidate_cache(str(uid)) + + await bus.subscribe("user.settings.updated*", _handle_settings_update) + return service diff --git a/backend/app/services/user_settings_service.py b/backend/app/services/user_settings_service.py index 75817055..21f4e000 100644 --- a/backend/app/services/user_settings_service.py +++ b/backend/app/services/user_settings_service.py @@ -16,7 +16,7 @@ DomainUserSettingsChangedEvent, DomainUserSettingsUpdate, ) -from app.services.event_bus import EventBusEvent, EventBusManager +from app.services.event_bus import EventBusManager from app.services.kafka_event_service import KafkaEventService _settings_adapter = TypeAdapter(DomainUserSettings) @@ -25,19 +25,22 @@ class UserSettingsService: def __init__( - self, repository: UserSettingsRepository, event_service: KafkaEventService, logger: logging.Logger + self, + repository: UserSettingsRepository, + event_service: KafkaEventService, + logger: logging.Logger, + event_bus_manager: EventBusManager, ) -> None: self.repository = repository self.event_service = event_service self.logger = logger + self._event_bus_manager = event_bus_manager self._cache_ttl = timedelta(minutes=5) self._max_cache_size = 1000 self._cache: TTLCache[str, DomainUserSettings] = TTLCache( maxsize=self._max_cache_size, ttl=self._cache_ttl.total_seconds(), ) - self._event_bus_manager: EventBusManager | None = None - self._subscription_id: str | None = None self.logger.info( "UserSettingsService initialized", @@ -53,22 +56,6 @@ async def get_user_settings(self, user_id: str) -> DomainUserSettings: return await self.get_user_settings_fresh(user_id) - async def initialize(self, event_bus_manager: EventBusManager) -> None: - """Subscribe to settings update events for cross-instance cache 
invalidation. - - Note: EventBus filters out self-published messages, so this handler only - runs for events from OTHER instances. - """ - self._event_bus_manager = event_bus_manager - bus = await event_bus_manager.get_event_bus() - - async def _handle(evt: EventBusEvent) -> None: - uid = evt.payload.get("user_id") - if uid: - await self.invalidate_cache(str(uid)) - - self._subscription_id = await bus.subscribe("user.settings.updated*", _handle) - async def get_user_settings_fresh(self, user_id: str) -> DomainUserSettings: """Bypass cache and rebuild settings from snapshot + events.""" snapshot = await self.repository.get_snapshot(user_id) @@ -108,9 +95,8 @@ async def update_user_settings( changes_json = _update_adapter.dump_python(updates, exclude_none=True, mode="json") await self._publish_settings_event(user_id, changes_json, reason) - if self._event_bus_manager is not None: - bus = await self._event_bus_manager.get_event_bus() - await bus.publish("user.settings.updated", {"user_id": user_id}) + bus = await self._event_bus_manager.get_event_bus() + await bus.publish("user.settings.updated", {"user_id": user_id}) self._add_to_cache(user_id, new_settings) if (await self.repository.count_events_since_snapshot(user_id)) >= 10: From 6d151b92e2a0c4ef61c2e9ee739ef5e6df3523a9 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Mon, 19 Jan 2026 15:35:28 +0100 Subject: [PATCH 02/21] notification service: update and removal of initialize() --- backend/app/core/providers.py | 4 +-- backend/app/services/notification_service.py | 34 +++++++------------- 2 files changed, 12 insertions(+), 26 deletions(-) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index ce1cc071..b3590988 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -523,7 +523,7 @@ def get_notification_service( notification_metrics: NotificationMetrics, event_metrics: EventMetrics, ) -> NotificationService: - service = NotificationService( + return NotificationService( notification_repository=notification_repository, event_service=kafka_event_service, event_bus_manager=event_bus_manager, @@ -534,8 +534,6 @@ def get_notification_service( notification_metrics=notification_metrics, event_metrics=event_metrics, ) - service.initialize() - return service @provide def get_grafana_alert_processor( diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 780f1279..6ed63068 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -59,8 +59,6 @@ class ServiceState(StringEnum): """Service lifecycle states.""" - IDLE = auto() - INITIALIZING = auto() RUNNING = auto() STOPPING = auto() STOPPED = auto() @@ -136,7 +134,7 @@ def __init__( self.logger = logger # State - self._state = ServiceState.IDLE + self._state = ServiceState.RUNNING self._throttle_cache = ThrottleCache() # Tasks @@ -146,6 +144,16 @@ def __init__( self._dispatcher: EventDispatcher | None = None self._consumer_task: asyncio.Task[None] | None = None + # Channel handlers mapping + self._channel_handlers: dict[NotificationChannel, ChannelHandler] = { + NotificationChannel.IN_APP: self._send_in_app, + NotificationChannel.WEBHOOK: self._send_webhook, + NotificationChannel.SLACK: self._send_slack, + } + + # Start background processors + self._start_background_tasks() + self.logger.info( "NotificationService initialized", extra={ @@ -155,30 +163,10 @@ def __init__( }, ) - # Channel handlers mapping - self._channel_handlers: 
dict[NotificationChannel, ChannelHandler] = { - NotificationChannel.IN_APP: self._send_in_app, - NotificationChannel.WEBHOOK: self._send_webhook, - NotificationChannel.SLACK: self._send_slack, - } - @property def state(self) -> ServiceState: return self._state - def initialize(self) -> None: - if self._state != ServiceState.IDLE: - self.logger.warning(f"Cannot initialize in state: {self._state}") - return - - self._state = ServiceState.INITIALIZING - - # Start processors - self._state = ServiceState.RUNNING - self._start_background_tasks() - - self.logger.info("Notification service initialized (without Kafka consumer)") - async def shutdown(self) -> None: """Shutdown notification service.""" if self._state == ServiceState.STOPPED: From 1aa13776b2deff747d3e867ec2c07b6b5e392304 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Mon, 19 Jan 2026 15:48:12 +0100 Subject: [PATCH 03/21] event bus: removal of manager (bus is passed directly) --- backend/app/core/providers.py | 25 +++++++--------- backend/app/services/event_bus.py | 30 ------------------- .../app/services/grafana_alert_processor.py | 2 -- backend/app/services/notification_service.py | 15 ++++------ backend/app/services/user_settings_service.py | 9 +++--- .../services/events/test_event_bus.py | 5 ++-- 6 files changed, 22 insertions(+), 64 deletions(-) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index b3590988..4ec446e9 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -50,7 +50,7 @@ from app.services.admin import AdminEventsService, AdminSettingsService, AdminUserService from app.services.auth_service import AuthService from app.services.coordinator.coordinator import ExecutionCoordinator -from app.services.event_bus import EventBusEvent, EventBusManager +from app.services.event_bus import EventBus, EventBusEvent from app.services.event_replay.replay_service import EventReplayService from app.services.event_service import EventService from app.services.execution_service import ExecutionService @@ -230,14 +230,11 @@ async def get_event_store_consumer( yield consumer @provide - async def get_event_bus_manager( + async def get_event_bus( self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics - ) -> AsyncIterator[EventBusManager]: - manager = EventBusManager(settings, logger, connection_metrics) - try: - yield manager - finally: - await manager.close() + ) -> AsyncIterator[EventBus]: + async with EventBus(settings, logger, connection_metrics) as bus: + yield bus class KubernetesProvider(Provider): @@ -470,22 +467,20 @@ async def get_user_settings_service( self, repository: UserSettingsRepository, kafka_event_service: KafkaEventService, - event_bus_manager: EventBusManager, + event_bus: EventBus, logger: logging.Logger, ) -> UserSettingsService: - service = UserSettingsService(repository, kafka_event_service, logger, event_bus_manager) + service = UserSettingsService(repository, kafka_event_service, logger, event_bus) # Subscribe to settings update events for cross-instance cache invalidation. # EventBus filters out self-published messages, so this handler only # runs for events from OTHER instances. 
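A stdlib-only sketch of the self-filtering idea the comment above relies on: each bus instance stamps outgoing messages with its own id and drops them again on receipt, so a locally triggered settings update never re-invalidates the local cache, while updates from other instances do. TinyBus, broker, and on_update are invented names for the sketch, not code from this repository.

import asyncio
import uuid


class TinyBus:
    """Stand-in for the Kafka-backed bus; a shared list plays the broker."""

    def __init__(self, broker: list[dict]) -> None:
        self.instance_id = str(uuid.uuid4())  # unique per running instance
        self.broker = broker
        self.handlers: list = []

    def publish(self, payload: dict) -> None:
        # Outgoing messages carry the publisher's instance id.
        self.broker.append({"src": self.instance_id, "payload": payload})

    async def drain(self) -> None:
        # Receiving side: skip anything this instance published itself.
        for msg in self.broker:
            if msg["src"] == self.instance_id:
                continue
            for handler in self.handlers:
                await handler(msg["payload"])


async def main() -> None:
    broker: list[dict] = []
    a, b = TinyBus(broker), TinyBus(broker)
    invalidated: list[str] = []

    async def on_update(payload: dict) -> None:
        invalidated.append(payload["user_id"])  # stand-in for invalidate_cache()

    a.handlers.append(on_update)
    a.publish({"user_id": "u1"})  # self-published: instance a ignores it
    b.publish({"user_id": "u2"})  # from the other instance: handled
    await a.drain()
    assert invalidated == ["u2"]


asyncio.run(main())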
- bus = await event_bus_manager.get_event_bus() - async def _handle_settings_update(evt: EventBusEvent) -> None: uid = evt.payload.get("user_id") if uid: await service.invalidate_cache(str(uid)) - await bus.subscribe("user.settings.updated*", _handle_settings_update) + await event_bus.subscribe("user.settings.updated*", _handle_settings_update) return service @@ -515,7 +510,7 @@ def get_notification_service( self, notification_repository: NotificationRepository, kafka_event_service: KafkaEventService, - event_bus_manager: EventBusManager, + event_bus: EventBus, schema_registry: SchemaRegistryManager, sse_redis_bus: SSERedisBus, settings: Settings, @@ -526,7 +521,7 @@ def get_notification_service( return NotificationService( notification_repository=notification_repository, event_service=kafka_event_service, - event_bus_manager=event_bus_manager, + event_bus=event_bus, schema_registry_manager=schema_registry, sse_bus=sse_redis_bus, settings=settings, diff --git a/backend/app/services/event_bus.py b/backend/app/services/event_bus.py index bd0080ee..25b0824c 100644 --- a/backend/app/services/event_bus.py +++ b/backend/app/services/event_bus.py @@ -9,7 +9,6 @@ from aiokafka import AIOKafkaConsumer, AIOKafkaProducer from aiokafka.errors import KafkaError -from fastapi import Request from pydantic import BaseModel, ConfigDict from app.core.lifecycle import LifecycleEnabled @@ -316,32 +315,3 @@ async def get_statistics(self) -> dict[str, Any]: } -class EventBusManager: - """Manages EventBus lifecycle as a singleton.""" - - def __init__(self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics) -> None: - self.settings = settings - self.logger = logger - self._connection_metrics = connection_metrics - self._event_bus: Optional[EventBus] = None - self._lock = asyncio.Lock() - - async def get_event_bus(self) -> EventBus: - """Get or create the event bus instance.""" - async with self._lock: - if self._event_bus is None: - self._event_bus = EventBus(self.settings, self.logger, self._connection_metrics) - await self._event_bus.__aenter__() - return self._event_bus - - async def close(self) -> None: - """Stop and clean up the event bus.""" - async with self._lock: - if self._event_bus: - await self._event_bus.aclose() - self._event_bus = None - - -async def get_event_bus(request: Request) -> EventBus: - manager: EventBusManager = request.app.state.event_bus_manager - return await manager.get_event_bus() diff --git a/backend/app/services/grafana_alert_processor.py b/backend/app/services/grafana_alert_processor.py index a78d6d6c..e0103faa 100644 --- a/backend/app/services/grafana_alert_processor.py +++ b/backend/app/services/grafana_alert_processor.py @@ -1,5 +1,3 @@ -"""Grafana alert processing service.""" - import logging from typing import Any diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 6ed63068..13a48418 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -40,7 +40,7 @@ from app.events.schema.schema_registry import SchemaRegistryManager from app.infrastructure.kafka.mappings import get_topic_for_event from app.schemas_pydantic.sse import RedisNotificationMessage -from app.services.event_bus import EventBusManager +from app.services.event_bus import EventBus from app.services.kafka_event_service import KafkaEventService from app.services.sse.redis_bus import SSERedisBus from app.settings import Settings @@ -115,7 +115,7 @@ def __init__( self, 
notification_repository: NotificationRepository, event_service: KafkaEventService, - event_bus_manager: EventBusManager, + event_bus: EventBus, schema_registry_manager: SchemaRegistryManager, sse_bus: SSERedisBus, settings: Settings, @@ -125,7 +125,7 @@ def __init__( ) -> None: self.repository = notification_repository self.event_service = event_service - self.event_bus_manager = event_bus_manager + self.event_bus = event_bus self.metrics = notification_metrics self._event_metrics = event_metrics self.settings = settings @@ -312,8 +312,7 @@ async def create_notification( notification = await self.repository.create_notification(create_data) # Publish event - event_bus = await self.event_bus_manager.get_event_bus() - await event_bus.publish( + await self.event_bus.publish( "notifications.created", { "notification_id": str(notification.notification_id), @@ -682,9 +681,8 @@ async def mark_as_read(self, user_id: str, notification_id: str) -> bool: """Mark notification as read.""" success = await self.repository.mark_as_read(notification_id, user_id) - event_bus = await self.event_bus_manager.get_event_bus() if success: - await event_bus.publish( + await self.event_bus.publish( "notifications.read", {"notification_id": str(notification_id), "user_id": user_id, "read_at": datetime.now(UTC).isoformat()}, ) @@ -766,9 +764,8 @@ async def mark_all_as_read(self, user_id: str) -> int: """Mark all notifications as read for a user.""" count = await self.repository.mark_all_as_read(user_id) - event_bus = await self.event_bus_manager.get_event_bus() if count > 0: - await event_bus.publish( + await self.event_bus.publish( "notifications.all_read", {"user_id": user_id, "count": count, "read_at": datetime.now(UTC).isoformat()} ) diff --git a/backend/app/services/user_settings_service.py b/backend/app/services/user_settings_service.py index 21f4e000..44d69f87 100644 --- a/backend/app/services/user_settings_service.py +++ b/backend/app/services/user_settings_service.py @@ -16,7 +16,7 @@ DomainUserSettingsChangedEvent, DomainUserSettingsUpdate, ) -from app.services.event_bus import EventBusManager +from app.services.event_bus import EventBus from app.services.kafka_event_service import KafkaEventService _settings_adapter = TypeAdapter(DomainUserSettings) @@ -29,12 +29,12 @@ def __init__( repository: UserSettingsRepository, event_service: KafkaEventService, logger: logging.Logger, - event_bus_manager: EventBusManager, + event_bus: EventBus, ) -> None: self.repository = repository self.event_service = event_service self.logger = logger - self._event_bus_manager = event_bus_manager + self._event_bus = event_bus self._cache_ttl = timedelta(minutes=5) self._max_cache_size = 1000 self._cache: TTLCache[str, DomainUserSettings] = TTLCache( @@ -95,8 +95,7 @@ async def update_user_settings( changes_json = _update_adapter.dump_python(updates, exclude_none=True, mode="json") await self._publish_settings_event(user_id, changes_json, reason) - bus = await self._event_bus_manager.get_event_bus() - await bus.publish("user.settings.updated", {"user_id": user_id}) + await self._event_bus.publish("user.settings.updated", {"user_id": user_id}) self._add_to_cache(user_id, new_settings) if (await self.repository.count_events_since_snapshot(user_id)) >= 10: diff --git a/backend/tests/integration/services/events/test_event_bus.py b/backend/tests/integration/services/events/test_event_bus.py index 6f17670b..0a0ef543 100644 --- a/backend/tests/integration/services/events/test_event_bus.py +++ 
b/backend/tests/integration/services/events/test_event_bus.py @@ -5,7 +5,7 @@ import pytest from aiokafka import AIOKafkaProducer from app.domain.enums.kafka import KafkaTopic -from app.services.event_bus import EventBusEvent, EventBusManager +from app.services.event_bus import EventBus, EventBusEvent from app.settings import Settings from dishka import AsyncContainer @@ -15,8 +15,7 @@ @pytest.mark.asyncio async def test_event_bus_publish_subscribe(scope: AsyncContainer, test_settings: Settings) -> None: """Test EventBus receives events from other instances (cross-instance communication).""" - manager: EventBusManager = await scope.get(EventBusManager) - bus = await manager.get_event_bus() + bus: EventBus = await scope.get(EventBus) # Future resolves when handler receives the event - no polling needed received_future: asyncio.Future[EventBusEvent] = asyncio.get_running_loop().create_future() From caf1422c267c36235edc3d4dd0e86fb667becd18 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Mon, 19 Jan 2026 16:02:58 +0100 Subject: [PATCH 04/21] sse service: removed getattr sse shutdown manager: moved to DI --- backend/app/core/providers.py | 14 +++++-- backend/app/services/sse/sse_service.py | 36 ++++++++-------- .../app/services/sse/sse_shutdown_manager.py | 42 ++----------------- .../services/sse/test_shutdown_manager.py | 28 ++++++++++--- .../services/sse/test_sse_shutdown_manager.py | 20 +++++++-- 5 files changed, 71 insertions(+), 69 deletions(-) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 4ec446e9..fc6dca31 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -73,7 +73,7 @@ from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge, create_sse_kafka_redis_bridge from app.services.sse.redis_bus import SSERedisBus from app.services.sse.sse_service import SSEService -from app.services.sse.sse_shutdown_manager import SSEShutdownManager, create_sse_shutdown_manager +from app.services.sse.sse_shutdown_manager import SSEShutdownManager from app.services.user_settings_service import UserSettingsService from app.settings import Settings @@ -395,9 +395,16 @@ async def get_sse_kafka_redis_bridge( @provide(scope=Scope.REQUEST) def get_sse_shutdown_manager( - self, logger: logging.Logger, connection_metrics: ConnectionMetrics + self, + router: SSEKafkaRedisBridge, + logger: logging.Logger, + connection_metrics: ConnectionMetrics, ) -> SSEShutdownManager: - return create_sse_shutdown_manager(logger=logger, connection_metrics=connection_metrics) + return SSEShutdownManager( + router=router, + logger=logger, + connection_metrics=connection_metrics, + ) @provide(scope=Scope.REQUEST) def get_sse_service( @@ -410,7 +417,6 @@ def get_sse_service( logger: logging.Logger, connection_metrics: ConnectionMetrics, ) -> SSEService: - shutdown_manager.set_router(router) return SSEService( repository=sse_repository, router=router, diff --git a/backend/app/services/sse/sse_service.py b/backend/app/services/sse/sse_service.py index e474fc41..3af70fb0 100644 --- a/backend/app/services/sse/sse_service.py +++ b/backend/app/services/sse/sse_service.py @@ -32,14 +32,14 @@ class SSEService: } def __init__( - self, - repository: SSERepository, - router: SSEKafkaRedisBridge, - sse_bus: SSERedisBus, - shutdown_manager: SSEShutdownManager, - settings: Settings, - logger: logging.Logger, - connection_metrics: ConnectionMetrics, + self, + repository: SSERepository, + router: SSEKafkaRedisBridge, + sse_bus: SSERedisBus, + shutdown_manager: SSEShutdownManager, + 
settings: Settings, + logger: logging.Logger, + connection_metrics: ConnectionMetrics, ) -> None: self.repository = repository self.router = router @@ -48,7 +48,7 @@ def __init__( self.settings = settings self.logger = logger self.metrics = connection_metrics - self.heartbeat_interval = getattr(settings, "SSE_HEARTBEAT_INTERVAL", 30) + self.heartbeat_interval = settings.SSE_HEARTBEAT_INTERVAL async def create_execution_stream(self, execution_id: str, user_id: str) -> AsyncGenerator[Dict[str, Any], None]: connection_id = f"sse_{execution_id}_{datetime.now(timezone.utc).timestamp()}" @@ -106,10 +106,10 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn self.metrics.record_sse_message_sent("executions", "status") async for event_data in self._stream_events_redis( - execution_id, - subscription, - shutdown_event, - include_heartbeat=False, + execution_id, + subscription, + shutdown_event, + include_heartbeat=False, ): yield event_data @@ -120,11 +120,11 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn self.logger.info("SSE connection closed", extra={"execution_id": execution_id}) async def _stream_events_redis( - self, - execution_id: str, - subscription: Any, - shutdown_event: asyncio.Event, - include_heartbeat: bool = True, + self, + execution_id: str, + subscription: Any, + shutdown_event: asyncio.Event, + include_heartbeat: bool = True, ) -> AsyncGenerator[Dict[str, Any], None]: last_heartbeat = datetime.now(timezone.utc) while True: diff --git a/backend/app/services/sse/sse_shutdown_manager.py b/backend/app/services/sse/sse_shutdown_manager.py index 4551e812..5c303c54 100644 --- a/backend/app/services/sse/sse_shutdown_manager.py +++ b/backend/app/services/sse/sse_shutdown_manager.py @@ -35,12 +35,14 @@ class SSEShutdownManager: def __init__( self, + router: LifecycleEnabled, logger: logging.Logger, connection_metrics: ConnectionMetrics, drain_timeout: float = 30.0, notification_timeout: float = 5.0, force_close_timeout: float = 10.0, ): + self._router = router self.logger = logger self.drain_timeout = drain_timeout self.notification_timeout = notification_timeout @@ -57,9 +59,6 @@ def __init__( self._connection_callbacks: Dict[str, asyncio.Event] = {} # connection_id -> shutdown event self._draining_connections: Set[str] = set() - # Router reference (set during initialization) - self._router: LifecycleEnabled | None = None - # Synchronization self._lock = asyncio.Lock() self._shutdown_event = asyncio.Event() @@ -74,10 +73,6 @@ def __init__( extra={"drain_timeout": drain_timeout, "notification_timeout": notification_timeout}, ) - def set_router(self, router: LifecycleEnabled) -> None: - """Set the router reference for shutdown coordination.""" - self._router = router - async def register_connection(self, execution_id: str, connection_id: str) -> asyncio.Event | None: """ Register a new SSE connection. 
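For orientation, a reduced model of the drain handshake this class implements now that the router arrives through the constructor: register hands each stream an asyncio.Event, shutdown notifies and drains, anything left over is force-closed and the router is closed, and new registrations are rejected. MiniShutdown and the connection ids below are invented for the sketch; the real manager adds per-phase timeouts, metrics, and logging.

import asyncio
from typing import Awaitable, Callable


class MiniShutdown:
    """Reduced model of the drain/force-close handshake, not the real manager."""

    def __init__(self, router_close: Callable[[], Awaitable[None]]) -> None:
        self._router_close = router_close        # injected, like the router above
        self._events: dict[str, asyncio.Event] = {}
        self._shutting_down = False

    def register(self, conn_id: str) -> asyncio.Event | None:
        if self._shutting_down:
            return None                          # reject connections during shutdown
        self._events[conn_id] = asyncio.Event()
        return self._events[conn_id]

    def unregister(self, conn_id: str) -> None:
        self._events.pop(conn_id, None)

    async def shutdown(self, drain_timeout: float = 0.5) -> None:
        self._shutting_down = True
        for ev in self._events.values():
            ev.set()                             # notify streams to wind down
        deadline = asyncio.get_running_loop().time() + drain_timeout
        while self._events and asyncio.get_running_loop().time() < deadline:
            await asyncio.sleep(0.01)            # drain phase
        self._events.clear()                     # force-close whatever is left
        await self._router_close()               # stop accepting new subscriptions


async def main() -> None:
    router_closed = asyncio.Event()

    async def close_router() -> None:
        router_closed.set()

    mgr = MiniShutdown(close_router)
    ev = mgr.register("conn-1")
    assert ev is not None

    async def stream() -> None:
        await ev.wait()                          # stream loop notices shutdown
        mgr.unregister("conn-1")

    task = asyncio.create_task(stream())
    await mgr.shutdown()
    await task
    assert router_closed.is_set()
    assert mgr.register("conn-2") is None


asyncio.run(main())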
@@ -260,9 +255,8 @@ async def _force_close_connections(self) -> None: self._connection_callbacks.clear() self._draining_connections.clear() - # If we have a router, tell it to stop accepting new subscriptions - if self._router: - await self._router.aclose() + # Tell router to stop accepting new subscriptions + await self._router.aclose() self.metrics.update_sse_draining_connections(0) self.logger.info("Force close phase complete") @@ -306,31 +300,3 @@ async def _wait_for_complete(self) -> None: """Wait for shutdown to complete""" while not self._shutdown_complete: await asyncio.sleep(0.1) - - -def create_sse_shutdown_manager( - logger: logging.Logger, - connection_metrics: ConnectionMetrics, - drain_timeout: float = 30.0, - notification_timeout: float = 5.0, - force_close_timeout: float = 10.0, -) -> SSEShutdownManager: - """Factory function to create an SSE shutdown manager. - - Args: - logger: Logger instance - connection_metrics: Connection metrics for tracking SSE connections - drain_timeout: Time to wait for connections to close gracefully - notification_timeout: Time to wait for shutdown notifications to be sent - force_close_timeout: Time before force closing connections - - Returns: - A new SSE shutdown manager instance - """ - return SSEShutdownManager( - logger=logger, - connection_metrics=connection_metrics, - drain_timeout=drain_timeout, - notification_timeout=notification_timeout, - force_close_timeout=force_close_timeout, - ) diff --git a/backend/tests/unit/services/sse/test_shutdown_manager.py b/backend/tests/unit/services/sse/test_shutdown_manager.py index 05f6e023..7c2a484d 100644 --- a/backend/tests/unit/services/sse/test_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_shutdown_manager.py @@ -23,7 +23,14 @@ async def _on_stop(self) -> None: @pytest.mark.asyncio async def test_shutdown_graceful_notify_and_drain(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager(drain_timeout=1.0, notification_timeout=0.01, force_close_timeout=0.1, logger=_test_logger, connection_metrics=connection_metrics) + mgr = SSEShutdownManager( + router=_FakeRouter(), + logger=_test_logger, + connection_metrics=connection_metrics, + drain_timeout=1.0, + notification_timeout=0.01, + force_close_timeout=0.1, + ) # Register two connections and arrange that they unregister when notified ev1 = await mgr.register_connection("e1", "c1") @@ -47,11 +54,15 @@ async def on_shutdown(event: asyncio.Event, cid: str) -> None: @pytest.mark.asyncio async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection_metrics: ConnectionMetrics) -> None: + router = _FakeRouter() mgr = SSEShutdownManager( - drain_timeout=0.01, notification_timeout=0.01, force_close_timeout=0.01, logger=_test_logger, connection_metrics=connection_metrics + router=router, + logger=_test_logger, + connection_metrics=connection_metrics, + drain_timeout=0.01, + notification_timeout=0.01, + force_close_timeout=0.01, ) - router = _FakeRouter() - mgr.set_router(router) # Register a connection but never unregister -> force close path ev = await mgr.register_connection("e1", "c1") @@ -71,7 +82,14 @@ async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection @pytest.mark.asyncio async def test_get_shutdown_status_transitions(connection_metrics: ConnectionMetrics) -> None: - m = SSEShutdownManager(drain_timeout=0.01, notification_timeout=0.0, force_close_timeout=0.0, logger=_test_logger, connection_metrics=connection_metrics) + m = SSEShutdownManager( + router=_FakeRouter(), + 
logger=_test_logger, + connection_metrics=connection_metrics, + drain_timeout=0.01, + notification_timeout=0.0, + force_close_timeout=0.0, + ) st0 = m.get_shutdown_status() assert st0.phase == "ready" await m.initiate_shutdown() diff --git a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py index fc7ffb3b..3f424605 100644 --- a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py @@ -25,8 +25,14 @@ async def _on_stop(self) -> None: @pytest.mark.asyncio async def test_register_unregister_and_shutdown_flow(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager(drain_timeout=0.5, notification_timeout=0.1, force_close_timeout=0.1, logger=_test_logger, connection_metrics=connection_metrics) - mgr.set_router(_FakeRouter()) + mgr = SSEShutdownManager( + router=_FakeRouter(), + logger=_test_logger, + connection_metrics=connection_metrics, + drain_timeout=0.5, + notification_timeout=0.1, + force_close_timeout=0.1, + ) # Register two connections e1 = await mgr.register_connection("exec-1", "c1") @@ -52,8 +58,14 @@ async def test_register_unregister_and_shutdown_flow(connection_metrics: Connect @pytest.mark.asyncio async def test_reject_new_connection_during_shutdown(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager(drain_timeout=0.5, notification_timeout=0.01, force_close_timeout=0.01, - logger=_test_logger, connection_metrics=connection_metrics) + mgr = SSEShutdownManager( + router=_FakeRouter(), + logger=_test_logger, + connection_metrics=connection_metrics, + drain_timeout=0.5, + notification_timeout=0.01, + force_close_timeout=0.01, + ) # Pre-register one active connection - shutdown will block waiting for it e = await mgr.register_connection("e", "c0") assert e is not None From 22c0554dd8932780bf22040402e04e476ebda54b Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Mon, 19 Jan 2026 20:42:02 +0100 Subject: [PATCH 05/21] removed LifecycleEnabled - we have DI and handling stuff via _aenter / _aexit --- backend/app/core/dishka_lifespan.py | 6 +- backend/app/core/lifecycle.py | 62 --- backend/app/core/providers.py | 53 ++- backend/app/dlq/manager.py | 110 ++---- backend/app/events/core/consumer.py | 73 ++-- backend/app/events/core/producer.py | 36 +- backend/app/events/core/types.py | 1 - backend/app/events/event_store_consumer.py | 43 +- .../app/services/coordinator/coordinator.py | 46 +-- .../app/services/coordinator/queue_manager.py | 54 --- backend/app/services/event_bus.py | 15 +- backend/app/services/k8s_worker/worker.py | 11 +- backend/app/services/kafka_event_service.py | 4 - backend/app/services/pod_monitor/monitor.py | 261 +++--------- .../services/result_processor/processor.py | 9 +- backend/app/services/saga/__init__.py | 3 +- .../app/services/saga/saga_orchestrator.py | 62 +-- .../app/services/sse/kafka_redis_bridge.py | 28 +- .../app/services/sse/sse_shutdown_manager.py | 8 +- .../tests/integration/dlq/test_dlq_manager.py | 22 +- .../events/test_consumer_lifecycle.py | 2 +- .../sse/test_partitioned_event_router.py | 6 +- .../coordinator/test_queue_manager.py | 4 - .../unit/services/pod_monitor/test_monitor.py | 373 +++++++----------- .../saga/test_saga_orchestrator_unit.py | 5 +- .../services/sse/test_kafka_redis_bridge.py | 2 +- .../services/sse/test_shutdown_manager.py | 20 +- .../services/sse/test_sse_shutdown_manager.py | 12 +- backend/workers/run_coordinator.py | 4 +- 
backend/workers/run_k8s_worker.py | 4 +- backend/workers/run_pod_monitor.py | 6 +- backend/workers/run_saga_orchestrator.py | 6 +- 32 files changed, 437 insertions(+), 914 deletions(-) delete mode 100644 backend/app/core/lifecycle.py diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index 3a91ee1d..956c5ebd 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -100,11 +100,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # Phase 3: Start Kafka consumers in parallel async with AsyncExitStack() as stack: - stack.push_async_callback(sse_bridge.aclose) - stack.push_async_callback(event_store_consumer.aclose) await asyncio.gather( - sse_bridge.__aenter__(), - event_store_consumer.__aenter__(), + stack.enter_async_context(sse_bridge), + stack.enter_async_context(event_store_consumer), ) logger.info("SSE bridge and EventStoreConsumer started") yield diff --git a/backend/app/core/lifecycle.py b/backend/app/core/lifecycle.py deleted file mode 100644 index 2e0d8f85..00000000 --- a/backend/app/core/lifecycle.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from types import TracebackType -from typing import Self - - -class LifecycleEnabled: - """Base class for services with async lifecycle management. - - Usage: - async with MyService() as service: - # service is running - # service is stopped - - Subclasses override _on_start() and _on_stop() for their logic. - Base class handles idempotency and context manager protocol. - - For internal component cleanup, use aclose() which follows Python's - standard async cleanup pattern (like aiofiles, aiohttp). - """ - - def __init__(self) -> None: - self._lifecycle_started: bool = False - - async def _on_start(self) -> None: - """Override with startup logic. Called once on enter.""" - pass - - async def _on_stop(self) -> None: - """Override with cleanup logic. Called once on exit.""" - pass - - async def aclose(self) -> None: - """Close the service. For internal component cleanup. - - Mirrors Python's standard aclose() pattern (like aiofiles, aiohttp). - Idempotent - safe to call multiple times. 
- """ - if not self._lifecycle_started: - return - self._lifecycle_started = False - await self._on_stop() - - @property - def is_running(self) -> bool: - """Check if service is currently running.""" - return self._lifecycle_started - - async def __aenter__(self) -> Self: - if self._lifecycle_started: - return self # Already started, idempotent - await self._on_start() - self._lifecycle_started = True - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc: BaseException | None, - tb: TracebackType | None, - ) -> None: - await self.aclose() diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index fc6dca31..ad01da2a 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -2,6 +2,7 @@ from typing import AsyncIterator import redis.asyncio as redis +from aiokafka import AIOKafkaConsumer, AIOKafkaProducer from dishka import Provider, Scope, from_context, provide from pymongo.asynchronous.mongo_client import AsyncMongoClient @@ -40,11 +41,13 @@ from app.db.repositories.replay_repository import ReplayRepository from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository from app.db.repositories.user_settings_repository import UserSettingsRepository -from app.dlq.manager import DLQManager, create_dlq_manager +from app.dlq.manager import DLQManager +from app.dlq.models import RetryPolicy, RetryStrategy +from app.domain.enums.kafka import GroupId, KafkaTopic from app.domain.saga.models import SagaConfig from app.events.core import UnifiedProducer from app.events.event_store import EventStore, create_event_store -from app.events.event_store_consumer import EventStoreConsumer, create_event_store_consumer +from app.events.event_store_consumer import EventStoreConsumer from app.events.schema.schema_registry import SchemaRegistryManager from app.infrastructure.kafka.topics import get_all_topics from app.services.admin import AdminEventsService, AdminSettingsService, AdminUserService @@ -67,10 +70,10 @@ from app.services.pod_monitor.monitor import PodMonitor from app.services.rate_limit_service import RateLimitService from app.services.replay_service import ReplayService -from app.services.saga import SagaOrchestrator, create_saga_orchestrator +from app.services.saga import SagaOrchestrator from app.services.saga.saga_service import SagaService from app.services.saved_script_service import SavedScriptService -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge, create_sse_kafka_redis_bridge +from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import SSERedisBus from app.services.sse.sse_service import SSEService from app.services.sse.sse_shutdown_manager import SSEShutdownManager @@ -171,7 +174,39 @@ async def get_dlq_manager( logger: logging.Logger, dlq_metrics: DLQMetrics, ) -> AsyncIterator[DLQManager]: - async with create_dlq_manager(settings, schema_registry, logger, dlq_metrics) as manager: + topic_name = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DEAD_LETTER_QUEUE}" + consumer = AIOKafkaConsumer( + topic_name, + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"{GroupId.DLQ_MANAGER}.{settings.KAFKA_GROUP_SUFFIX}", + enable_auto_commit=False, + auto_offset_reset="earliest", + client_id="dlq-manager-consumer", + session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, + heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, + max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, + 
request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + producer = AIOKafkaProducer( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + client_id="dlq-manager-producer", + acks="all", + compression_type="gzip", + max_batch_size=16384, + linger_ms=10, + enable_idempotence=True, + ) + manager = DLQManager( + settings=settings, + consumer=consumer, + producer=producer, + schema_registry=schema_registry, + logger=logger, + dlq_metrics=dlq_metrics, + dlq_topic=KafkaTopic.DEAD_LETTER_QUEUE, + default_retry_policy=RetryPolicy(topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF), + ) + async with manager: yield manager @provide @@ -218,7 +253,7 @@ async def get_event_store_consumer( event_metrics: EventMetrics, ) -> AsyncIterator[EventStoreConsumer]: topics = get_all_topics() - async with create_event_store_consumer( + async with EventStoreConsumer( event_store=event_store, topics=list(topics), schema_registry_manager=schema_registry, @@ -384,7 +419,7 @@ async def get_sse_kafka_redis_bridge( sse_redis_bus: SSERedisBus, logger: logging.Logger, ) -> AsyncIterator[SSEKafkaRedisBridge]: - async with create_sse_kafka_redis_bridge( + async with SSEKafkaRedisBridge( schema_registry=schema_registry, settings=settings, event_metrics=event_metrics, @@ -571,7 +606,8 @@ async def _provide_saga_orchestrator( event_metrics: EventMetrics, ) -> AsyncIterator[SagaOrchestrator]: """Shared factory for SagaOrchestrator with lifecycle management.""" - async with create_saga_orchestrator( + async with SagaOrchestrator( + config=_create_default_saga_config(), saga_repository=saga_repository, producer=kafka_producer, schema_registry_manager=schema_registry, @@ -579,7 +615,6 @@ async def _provide_saga_orchestrator( event_store=event_store, idempotency_manager=idempotency_manager, resource_allocation_repository=resource_allocation_repository, - config=_create_default_saga_config(), logger=logger, event_metrics=event_metrics, ) as orchestrator: diff --git a/backend/app/dlq/manager.py b/backend/app/dlq/manager.py index 1d450a03..da434964 100644 --- a/backend/app/dlq/manager.py +++ b/backend/app/dlq/manager.py @@ -7,7 +7,6 @@ from aiokafka import AIOKafkaConsumer, AIOKafkaProducer from opentelemetry.trace import SpanKind -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import DLQMetrics from app.core.tracing import EventAttributes from app.core.tracing.utils import extract_trace_context, get_tracer, inject_trace_context @@ -21,7 +20,7 @@ RetryPolicy, RetryStrategy, ) -from app.domain.enums.kafka import GroupId, KafkaTopic +from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import ( DLQMessageDiscardedEvent, DLQMessageReceivedEvent, @@ -32,7 +31,7 @@ from app.settings import Settings -class DLQManager(LifecycleEnabled): +class DLQManager: def __init__( self, settings: Settings, @@ -45,7 +44,6 @@ def __init__( retry_topic_suffix: str = "-retry", default_retry_policy: RetryPolicy | None = None, ): - super().__init__() self.settings = settings self.metrics = dlq_metrics self.schema_registry = schema_registry @@ -76,7 +74,7 @@ def _kafka_msg_to_message(self, msg: Any) -> DLQMessage: headers = {k: v.decode() for k, v in (msg.headers or [])} return DLQMessage(**data, dlq_offset=msg.offset, dlq_partition=msg.partition, headers=headers) - async def _on_start(self) -> None: + async def __aenter__(self) -> "DLQManager": """Start DLQ manager.""" # Start producer and consumer in parallel for faster startup await asyncio.gather(self.producer.start(), self.consumer.start()) @@ 
-86,8 +84,9 @@ async def _on_start(self) -> None: self._monitor_task = asyncio.create_task(self._monitor_dlq()) self.logger.info("DLQ Manager started") + return self - async def _on_stop(self) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: """Stop DLQ manager.""" # Cancel tasks for task in [self._process_task, self._monitor_task]: @@ -273,35 +272,40 @@ async def _discard_message(self, message: DLQMessage, reason: str) -> None: self.logger.warning("Discarded message", extra={"event_id": message.event.event_id, "reason": reason}) async def _monitor_dlq(self) -> None: - while self.is_running: - try: - # Find messages ready for retry using Beanie - now = datetime.now(timezone.utc) - - docs = ( - await DLQMessageDocument.find( - { - "status": DLQMessageStatus.SCHEDULED, - "next_retry_at": {"$lte": now}, - } + try: + while True: + try: + # Find messages ready for retry using Beanie + now = datetime.now(timezone.utc) + + docs = ( + await DLQMessageDocument.find( + { + "status": DLQMessageStatus.SCHEDULED, + "next_retry_at": {"$lte": now}, + } + ) + .limit(100) + .to_list() ) - .limit(100) - .to_list() - ) - for doc in docs: - message = DLQMessage.model_validate(doc, from_attributes=True) - await self._retry_message(message) + for doc in docs: + message = DLQMessage.model_validate(doc, from_attributes=True) + await self._retry_message(message) - # Update queue size metrics - await self._update_queue_metrics() + # Update queue size metrics + await self._update_queue_metrics() - # Sleep before next check - await asyncio.sleep(10) + # Sleep before next check + await asyncio.sleep(10) - except Exception as e: - self.logger.error(f"Error in DLQ monitor: {e}") - await asyncio.sleep(60) + except asyncio.CancelledError: + raise + except Exception as e: + self.logger.error(f"Error in DLQ monitor: {e}") + await asyncio.sleep(60) + except asyncio.CancelledError: + self.logger.info("DLQ monitor cancelled") async def _update_queue_metrics(self) -> None: # Get counts by topic using Beanie aggregation @@ -438,49 +442,3 @@ async def discard_message_manually(self, event_id: str, reason: str) -> bool: message = DLQMessage.model_validate(doc, from_attributes=True) await self._discard_message(message, reason) return True - - -def create_dlq_manager( - settings: Settings, - schema_registry: SchemaRegistryManager, - logger: logging.Logger, - dlq_metrics: DLQMetrics, - dlq_topic: KafkaTopic = KafkaTopic.DEAD_LETTER_QUEUE, - retry_topic_suffix: str = "-retry", - default_retry_policy: RetryPolicy | None = None, -) -> DLQManager: - topic_name = f"{settings.KAFKA_TOPIC_PREFIX}{dlq_topic}" - consumer = AIOKafkaConsumer( - topic_name, - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{GroupId.DLQ_MANAGER}.{settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - auto_offset_reset="earliest", - client_id="dlq-manager-consumer", - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - producer = AIOKafkaProducer( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - client_id="dlq-manager-producer", - acks="all", - compression_type="gzip", - max_batch_size=16384, - linger_ms=10, - enable_idempotence=True, - ) - if default_retry_policy is None: - default_retry_policy = RetryPolicy(topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF) - return DLQManager( - settings=settings, - 
consumer=consumer, - producer=producer, - schema_registry=schema_registry, - logger=logger, - dlq_metrics=dlq_metrics, - dlq_topic=dlq_topic, - retry_topic_suffix=retry_topic_suffix, - default_retry_policy=default_retry_policy, - ) diff --git a/backend/app/events/core/consumer.py b/backend/app/events/core/consumer.py index d0532f37..bb37a134 100644 --- a/backend/app/events/core/consumer.py +++ b/backend/app/events/core/consumer.py @@ -36,7 +36,6 @@ def __init__( self._dispatcher = event_dispatcher self._consumer: AIOKafkaConsumer | None = None self._state = ConsumerState.STOPPED - self._running = False self._metrics = ConsumerMetrics() self._event_metrics = event_metrics self._error_callback: "Callable[[Exception, DomainEvent], Awaitable[None]] | None" = None @@ -64,7 +63,6 @@ async def start(self, topics: list[KafkaTopic]) -> None: ) await self._consumer.start() - self._running = True self._consume_task = asyncio.create_task(self._consume_loop()) self._state = ConsumerState.RUNNING @@ -78,8 +76,6 @@ async def stop(self) -> None: else self._state ) - self._running = False - if self._consume_task: self._consume_task.cancel() await asyncio.gather(self._consume_task, return_exceptions=True) @@ -98,37 +94,39 @@ async def _consume_loop(self) -> None: poll_count = 0 message_count = 0 - while self._running and self._consumer: - poll_count += 1 - if poll_count % 100 == 0: # Log every 100 polls - self.logger.debug(f"Consumer loop active: polls={poll_count}, messages={message_count}") - - try: - # Use getone() with timeout for single message consumption - msg = await asyncio.wait_for( - self._consumer.getone(), - timeout=0.1 - ) - - message_count += 1 - self.logger.debug( - f"Message received from topic {msg.topic}, partition {msg.partition}, offset {msg.offset}" - ) - await self._process_message(msg) - if not self._config.enable_auto_commit: - await self._consumer.commit() - - except asyncio.TimeoutError: - # No message available within timeout, continue polling - await asyncio.sleep(0.01) - except KafkaError as e: - self.logger.error(f"Consumer error: {e}") - self._metrics.processing_errors += 1 - - self.logger.warning( - f"Consumer loop ended for group {self._config.group_id}: " - f"running={self._running}, consumer={self._consumer is not None}" - ) + try: + while True: + if not self._consumer: + break + + poll_count += 1 + if poll_count % 100 == 0: # Log every 100 polls + self.logger.debug(f"Consumer loop active: polls={poll_count}, messages={message_count}") + + try: + # Use getone() with timeout for single message consumption + msg = await asyncio.wait_for( + self._consumer.getone(), + timeout=0.1 + ) + + message_count += 1 + self.logger.debug( + f"Message received from topic {msg.topic}, partition {msg.partition}, offset {msg.offset}" + ) + await self._process_message(msg) + if not self._config.enable_auto_commit: + await self._consumer.commit() + + except asyncio.TimeoutError: + # No message available within timeout, continue polling + await asyncio.sleep(0.01) + except KafkaError as e: + self.logger.error(f"Consumer error: {e}") + self._metrics.processing_errors += 1 + + except asyncio.CancelledError: + self.logger.info(f"Consumer loop cancelled for group {self._config.group_id}") async def _process_message(self, message: Any) -> None: """Process a ConsumerRecord from aiokafka.""" @@ -203,10 +201,6 @@ def state(self) -> ConsumerState: def metrics(self) -> ConsumerMetrics: return self._metrics - @property - def is_running(self) -> bool: - return self._state == ConsumerState.RUNNING - 
@property def consumer(self) -> AIOKafkaConsumer | None: return self._consumer @@ -214,7 +208,6 @@ def consumer(self) -> AIOKafkaConsumer | None: def get_status(self) -> ConsumerStatus: return ConsumerStatus( state=self._state, - is_running=self.is_running, group_id=self._config.group_id, client_id=self._config.client_id, metrics=ConsumerMetricsSnapshot( diff --git a/backend/app/events/core/producer.py b/backend/app/events/core/producer.py index a41188c7..69e136ff 100644 --- a/backend/app/events/core/producer.py +++ b/backend/app/events/core/producer.py @@ -8,7 +8,6 @@ from aiokafka import AIOKafkaProducer from aiokafka.errors import KafkaError -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics from app.dlq.models import DLQMessage, DLQMessageStatus from app.domain.enums.kafka import KafkaTopic @@ -20,7 +19,7 @@ from .types import ProducerMetrics, ProducerState -class UnifiedProducer(LifecycleEnabled): +class UnifiedProducer: """Fully async Kafka producer using aiokafka.""" def __init__( @@ -30,7 +29,6 @@ def __init__( settings: Settings, event_metrics: EventMetrics, ): - super().__init__() self._settings = settings self._schema_registry = schema_registry_manager self.logger = logger @@ -40,10 +38,6 @@ def __init__( self._event_metrics = event_metrics self._topic_prefix = settings.KAFKA_TOPIC_PREFIX - @property - def is_running(self) -> bool: - return self._state == ProducerState.RUNNING - @property def state(self) -> ProducerState: return self._state @@ -56,7 +50,7 @@ def metrics(self) -> ProducerMetrics: def producer(self) -> AIOKafkaProducer | None: return self._producer - async def _on_start(self) -> None: + async def __aenter__(self) -> "UnifiedProducer": """Start the Kafka producer.""" self._state = ProducerState.STARTING self.logger.info("Starting producer...") @@ -74,11 +68,23 @@ async def _on_start(self) -> None: await self._producer.start() self._state = ProducerState.RUNNING self.logger.info(f"Producer started: {self._settings.KAFKA_BOOTSTRAP_SERVERS}") + return self + + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: + """Stop the Kafka producer.""" + self._state = ProducerState.STOPPING + self.logger.info("Stopping producer...") + + if self._producer: + await self._producer.stop() + self._producer = None + + self._state = ProducerState.STOPPED + self.logger.info("Producer stopped") def get_status(self) -> dict[str, Any]: return { "state": self._state, - "running": self.is_running, "config": { "bootstrap_servers": self._settings.KAFKA_BOOTSTRAP_SERVERS, "client_id": f"{self._settings.SERVICE_NAME}-producer", @@ -94,18 +100,6 @@ def get_status(self) -> dict[str, Any]: }, } - async def _on_stop(self) -> None: - """Stop the Kafka producer.""" - self._state = ProducerState.STOPPING - self.logger.info("Stopping producer...") - - if self._producer: - await self._producer.stop() - self._producer = None - - self._state = ProducerState.STOPPED - self.logger.info("Producer stopped") - async def produce( self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None ) -> None: diff --git a/backend/app/events/core/types.py b/backend/app/events/core/types.py index 1912f1be..7d1eaf14 100644 --- a/backend/app/events/core/types.py +++ b/backend/app/events/core/types.py @@ -112,7 +112,6 @@ class ConsumerStatus(BaseModel): model_config = ConfigDict(from_attributes=True) state: str - is_running: bool group_id: str client_id: str metrics: ConsumerMetricsSnapshot diff --git 
a/backend/app/events/event_store_consumer.py b/backend/app/events/event_store_consumer.py index 41135a95..c7b712ac 100644 --- a/backend/app/events/event_store_consumer.py +++ b/backend/app/events/event_store_consumer.py @@ -3,7 +3,6 @@ from opentelemetry.trace import SpanKind -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics from app.core.tracing.utils import trace_span from app.domain.enums.events import EventType @@ -15,7 +14,7 @@ from app.settings import Settings -class EventStoreConsumer(LifecycleEnabled): +class EventStoreConsumer: """Consumes events from Kafka and stores them in MongoDB.""" def __init__( @@ -31,7 +30,6 @@ def __init__( batch_size: int = 100, batch_timeout_seconds: float = 5.0, ): - super().__init__() self.event_store = event_store self.topics = topics self.settings = settings @@ -49,7 +47,7 @@ def __init__( self._last_batch_time: float = 0.0 self._batch_task: asyncio.Task[None] | None = None - async def _on_start(self) -> None: + async def __aenter__(self) -> "EventStoreConsumer": """Start consuming and storing events.""" self._last_batch_time = asyncio.get_running_loop().time() config = ConsumerConfig( @@ -95,8 +93,9 @@ async def _on_start(self) -> None: self._batch_task = asyncio.create_task(self._batch_processor()) self.logger.info(f"Event store consumer started for topics: {self.topics}") + return self - async def _on_stop(self) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: """Stop consumer.""" await self._flush_batch() @@ -128,8 +127,8 @@ async def _handle_error_with_event(self, error: Exception, event: DomainEvent) - async def _batch_processor(self) -> None: """Periodically flush batches based on timeout.""" - while self.is_running: - try: + try: + while True: await asyncio.sleep(1) async with self._batch_lock: @@ -138,8 +137,8 @@ async def _batch_processor(self) -> None: if self._batch_buffer and time_since_last_batch >= self.batch_timeout: await self._flush_batch() - except Exception as e: - self.logger.error(f"Error in batch processor: {e}") + except asyncio.CancelledError: + self.logger.info("Batch processor cancelled") async def _flush_batch(self) -> None: if not self._batch_buffer: @@ -162,29 +161,3 @@ async def _flush_batch(self) -> None: f"stored={results['stored']}, duplicates={results['duplicates']}, " f"failed={results['failed']}" ) - - -def create_event_store_consumer( - event_store: EventStore, - topics: list[KafkaTopic], - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - logger: logging.Logger, - event_metrics: EventMetrics, - producer: UnifiedProducer | None = None, - group_id: GroupId = GroupId.EVENT_STORE_CONSUMER, - batch_size: int = 100, - batch_timeout_seconds: float = 5.0, -) -> EventStoreConsumer: - return EventStoreConsumer( - event_store=event_store, - topics=topics, - group_id=group_id, - batch_size=batch_size, - batch_timeout_seconds=batch_timeout_seconds, - schema_registry_manager=schema_registry_manager, - settings=settings, - logger=logger, - event_metrics=event_metrics, - producer=producer, - ) diff --git a/backend/app/services/coordinator/coordinator.py b/backend/app/services/coordinator/coordinator.py index 5f93ceb6..4bc09a69 100644 --- a/backend/app/services/coordinator/coordinator.py +++ b/backend/app/services/coordinator/coordinator.py @@ -5,7 +5,6 @@ from typing import Any, TypeAlias from uuid import uuid4 -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import CoordinatorMetrics, EventMetrics from 
app.db.repositories.execution_repository import ExecutionRepository from app.domain.enums.events import EventType @@ -35,7 +34,7 @@ ExecutionMap: TypeAlias = dict[str, ResourceAllocation] -class ExecutionCoordinator(LifecycleEnabled): +class ExecutionCoordinator: """ Coordinates execution scheduling across the system. @@ -62,7 +61,6 @@ def __init__( max_concurrent_scheduling: int = 10, scheduling_interval_seconds: float = 0.5, ): - super().__init__() self.logger = logger self.metrics = coordinator_metrics self._event_metrics = event_metrics @@ -111,11 +109,11 @@ def __init__( self._schema_registry_manager = schema_registry_manager self.dispatcher = EventDispatcher(logger=self.logger) - async def _on_start(self) -> None: + async def __aenter__(self) -> "ExecutionCoordinator": """Start the coordinator service.""" self.logger.info("Starting ExecutionCoordinator service...") - await self.queue_manager.start() + self.logger.info("Queue manager initialized") await self.idempotency_manager.initialize() @@ -176,8 +174,9 @@ async def handle_cancelled(event: ExecutionCancelledEvent) -> None: self._scheduling_task = asyncio.create_task(self._scheduling_loop()) self.logger.info("ExecutionCoordinator service started successfully") + return self - async def _on_stop(self) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: """Stop the coordinator service.""" self.logger.info("Stopping ExecutionCoordinator service...") @@ -193,8 +192,6 @@ async def _on_stop(self) -> None: if self.idempotent_consumer: await self.idempotent_consumer.stop() - await self.queue_manager.stop() - # Close idempotency manager if hasattr(self, "idempotency_manager") and self.idempotency_manager: await self.idempotency_manager.close() @@ -308,21 +305,26 @@ async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: async def _scheduling_loop(self) -> None: """Main scheduling loop""" - while self.is_running: - try: - # Get next execution from queue - execution = await self.queue_manager.get_next_execution() + try: + while True: + try: + # Get next execution from queue + execution = await self.queue_manager.get_next_execution() - if execution: - # Schedule execution - asyncio.create_task(self._schedule_execution(execution)) - else: - # No executions in queue, wait - await asyncio.sleep(self.scheduling_interval) + if execution: + # Schedule execution + asyncio.create_task(self._schedule_execution(execution)) + else: + # No executions in queue, wait + await asyncio.sleep(self.scheduling_interval) - except Exception as e: - self.logger.error(f"Error in scheduling loop: {e}", exc_info=True) - await asyncio.sleep(5) # Wait before retrying + except asyncio.CancelledError: + raise + except Exception as e: + self.logger.error(f"Error in scheduling loop: {e}", exc_info=True) + await asyncio.sleep(5) # Wait before retrying + except asyncio.CancelledError: + self.logger.info("Scheduling loop cancelled") async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None: """Schedule a single execution""" @@ -492,7 +494,7 @@ async def _publish_scheduling_failed(self, request: ExecutionRequestedEvent, err async def get_status(self) -> dict[str, Any]: """Get coordinator status""" return { - "running": self.is_running, + "scheduling_task_active": self._scheduling_task is not None and not self._scheduling_task.done(), "active_executions": len(self._active_executions), "queue_stats": await self.queue_manager.get_queue_stats(), "resource_stats": await 
self.resource_manager.get_resource_stats(), diff --git a/backend/app/services/coordinator/queue_manager.py b/backend/app/services/coordinator/queue_manager.py index b8ac98eb..76b15c3c 100644 --- a/backend/app/services/coordinator/queue_manager.py +++ b/backend/app/services/coordinator/queue_manager.py @@ -58,31 +58,6 @@ def __init__( self._queue_lock = asyncio.Lock() self._user_execution_count: Dict[str, int] = defaultdict(int) self._execution_users: Dict[str, str] = {} - self._cleanup_task: asyncio.Task[None] | None = None - self._running = False - - async def start(self) -> None: - if self._running: - return - - self._running = True - self._cleanup_task = asyncio.create_task(self._cleanup_stale_executions()) - self.logger.info("Queue manager started") - - async def stop(self) -> None: - if not self._running: - return - - self._running = False - - if self._cleanup_task: - self._cleanup_task.cancel() - try: - await self._cleanup_task - except asyncio.CancelledError: - pass - - self.logger.info(f"Queue manager stopped. Final queue size: {len(self._queue)}") async def add_execution( self, event: ExecutionRequestedEvent, priority: QueuePriority | None = None @@ -240,32 +215,3 @@ def _update_add_metrics(self, priority: QueuePriority) -> None: def _update_queue_size(self) -> None: self.metrics.update_execution_request_queue_size(len(self._queue)) - - async def _cleanup_stale_executions(self) -> None: - while self._running: - try: - await asyncio.sleep(300) - - async with self._queue_lock: - stale_executions = [] - active_executions = [] - - for queued in self._queue: - if self._is_stale(queued): - stale_executions.append(queued) - else: - active_executions.append(queued) - - if stale_executions: - self._queue = active_executions - heapq.heapify(self._queue) - - for queued in stale_executions: - self._untrack_execution(queued.execution_id) - - # Update metric after stale cleanup - self.metrics.update_execution_request_queue_size(len(self._queue)) - self.logger.info(f"Cleaned {len(stale_executions)} stale executions from queue") - - except Exception as e: - self.logger.error(f"Error in queue cleanup: {e}") diff --git a/backend/app/services/event_bus.py b/backend/app/services/event_bus.py index 25b0824c..6ae60f87 100644 --- a/backend/app/services/event_bus.py +++ b/backend/app/services/event_bus.py @@ -11,7 +11,6 @@ from aiokafka.errors import KafkaError from pydantic import BaseModel, ConfigDict -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import ConnectionMetrics from app.domain.enums.kafka import KafkaTopic from app.settings import Settings @@ -37,7 +36,7 @@ class Subscription: handler: Callable[[EventBusEvent], Any] = field(default=lambda _: None) -class EventBus(LifecycleEnabled): +class EventBus: """ Distributed event bus for cross-instance communication via Kafka. 
@@ -53,7 +52,6 @@ class EventBus(LifecycleEnabled): """ def __init__(self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics) -> None: - super().__init__() self.logger = logger self.settings = settings self.metrics = connection_metrics @@ -66,11 +64,12 @@ def __init__(self, settings: Settings, logger: logging.Logger, connection_metric self._topic = f"{self.settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" self._instance_id = str(uuid4()) # Unique ID for filtering self-published messages - async def _on_start(self) -> None: + async def __aenter__(self) -> "EventBus": """Start the event bus with Kafka backing.""" await self._initialize_kafka() self._consumer_task = asyncio.create_task(self._kafka_listener()) self.logger.info("Event bus started with Kafka backing") + return self async def _initialize_kafka(self) -> None: """Initialize Kafka producer and consumer.""" @@ -100,7 +99,7 @@ async def _initialize_kafka(self) -> None: # Start both in parallel for faster startup await asyncio.gather(self.producer.start(), self.consumer.start()) - async def _on_stop(self) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: """Stop the event bus and clean up resources.""" # Cancel consumer task if self._consumer_task and not self._consumer_task.done(): @@ -269,7 +268,7 @@ async def _kafka_listener(self) -> None: self.logger.info("Kafka listener started") try: - while self.is_running: + while True: try: msg = await asyncio.wait_for(self.consumer.getone(), timeout=0.1) @@ -294,8 +293,6 @@ async def _kafka_listener(self) -> None: except asyncio.CancelledError: self.logger.info("Kafka listener cancelled") - except Exception as e: - self.logger.error(f"Fatal error in Kafka listener: {e}") def _update_metrics(self, pattern: str) -> None: """Update metrics for a pattern (must be called within lock).""" @@ -311,7 +308,7 @@ async def get_statistics(self) -> dict[str, Any]: "total_patterns": len(self._pattern_index), "total_subscriptions": len(self._subscriptions), "kafka_enabled": self.producer is not None, - "running": self.is_running, + "consumer_task_active": self._consumer_task is not None and not self._consumer_task.done(), } diff --git a/backend/app/services/k8s_worker/worker.py b/backend/app/services/k8s_worker/worker.py index cd9af936..7fd8f5aa 100644 --- a/backend/app/services/k8s_worker/worker.py +++ b/backend/app/services/k8s_worker/worker.py @@ -9,7 +9,6 @@ from kubernetes import config as k8s_config from kubernetes.client.rest import ApiException -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics, ExecutionMetrics, KubernetesMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId @@ -35,7 +34,7 @@ from app.settings import Settings -class KubernetesWorker(LifecycleEnabled): +class KubernetesWorker: """ Worker service that creates Kubernetes pods from execution events. 
@@ -58,7 +57,6 @@ def __init__( logger: logging.Logger, event_metrics: EventMetrics, ): - super().__init__() self._event_metrics = event_metrics self.logger = logger self.metrics = KubernetesMetrics(settings) @@ -87,7 +85,7 @@ def __init__( self._creation_semaphore = asyncio.Semaphore(self.config.max_concurrent_pods) self._schema_registry_manager = schema_registry_manager - async def _on_start(self) -> None: + async def __aenter__(self) -> "KubernetesWorker": """Start the Kubernetes worker.""" self.logger.info("Starting KubernetesWorker service...") self.logger.info("DEBUG: About to initialize Kubernetes client") @@ -150,8 +148,9 @@ async def _on_start(self) -> None: self.logger.info("Image pre-puller daemonset task scheduled") self.logger.info("KubernetesWorker service started successfully") + return self - async def _on_stop(self) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: """Stop the Kubernetes worker.""" self.logger.info("Stopping KubernetesWorker service...") @@ -431,7 +430,7 @@ async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, err async def get_status(self) -> dict[str, Any]: """Get worker status""" return { - "running": self.is_running, + "running": self.idempotent_consumer is not None, "active_creations": len(self._active_creations), "config": { "namespace": self.config.namespace, diff --git a/backend/app/services/kafka_event_service.py b/backend/app/services/kafka_event_service.py index b0a3bcb7..da756f2d 100644 --- a/backend/app/services/kafka_event_service.py +++ b/backend/app/services/kafka_event_service.py @@ -201,7 +201,3 @@ async def publish_domain_event(self, event: DomainEvent, key: str | None = None) self.metrics.record_event_processing_duration(time.time() - start_time, event.event_type) self.logger.info("Domain event published", extra={"event_id": event.event_id}) return event.event_id - - async def close(self) -> None: - """Close event service resources""" - await self.kafka_producer.aclose() diff --git a/backend/app/services/pod_monitor/monitor.py b/backend/app/services/pod_monitor/monitor.py index ecbb4556..ae95f6a7 100644 --- a/backend/app/services/pod_monitor/monitor.py +++ b/backend/app/services/pod_monitor/monitor.py @@ -1,8 +1,6 @@ import asyncio import logging import time -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from dataclasses import dataclass from enum import auto from typing import Any @@ -10,8 +8,7 @@ from kubernetes import client as k8s_client from kubernetes.client.rest import ApiException -from app.core.k8s_clients import K8sClients, close_k8s_clients, create_k8s_clients -from app.core.lifecycle import LifecycleEnabled +from app.core.k8s_clients import K8sClients from app.core.metrics import KubernetesMetrics from app.core.utils import StringEnum from app.domain.events.typed import DomainEvent @@ -28,7 +25,6 @@ # Constants MAX_BACKOFF_SECONDS: int = 300 # 5 minutes -RECONCILIATION_LOG_INTERVAL: int = 60 # 1 minute class WatchEventType(StringEnum): @@ -39,15 +35,6 @@ class WatchEventType(StringEnum): DELETED = "DELETED" -class MonitorState(StringEnum): - """Pod monitor states.""" - - IDLE = auto() - RUNNING = auto() - STOPPING = auto() - STOPPED = auto() - - class ErrorType(StringEnum): """Error types for metrics.""" @@ -77,24 +64,13 @@ class PodEvent: resource_version: ResourceVersion | None -@dataclass(frozen=True, slots=True) -class ReconciliationResult: - """Result of state reconciliation.""" - - missing_pods: set[PodName] - 
extra_pods: set[PodName] - duration_seconds: float - success: bool - error: str | None = None - - -class PodMonitor(LifecycleEnabled): +class PodMonitor: """ Monitors Kubernetes pods and publishes lifecycle events. - This service watches pods with specific labels using the K8s watch API, + Watches pods with specific labels using the K8s watch API, maps Kubernetes events to application events, and publishes them to Kafka. - Events are stored in the events collection AND published to Kafka via KafkaEventService. + Reconciles state when watch restarts (every watch_timeout_seconds or on error). """ def __init__( @@ -106,148 +82,116 @@ def __init__( event_mapper: PodEventMapper, kubernetes_metrics: KubernetesMetrics, ) -> None: - """Initialize the pod monitor with all required dependencies. - - All dependencies must be provided - use create_pod_monitor() factory - for automatic dependency creation in production. - """ - super().__init__() self.logger = logger self.config = config - # Kubernetes clients (required, no nullability) + # Kubernetes clients self._clients = k8s_clients self._v1 = k8s_clients.v1 self._watch = k8s_clients.watch - # Components (required, no nullability) + # Components self._event_mapper = event_mapper self._kafka_event_service = kafka_event_service # State - self._state = MonitorState.IDLE self._tracked_pods: set[PodName] = set() self._reconnect_attempts: int = 0 self._last_resource_version: ResourceVersion | None = None - # Tasks + # Task self._watch_task: asyncio.Task[None] | None = None - self._reconcile_task: asyncio.Task[None] | None = None # Metrics self._metrics = kubernetes_metrics - @property - def state(self) -> MonitorState: - """Get current monitor state.""" - return self._state - - async def _on_start(self) -> None: + async def __aenter__(self) -> "PodMonitor": """Start the pod monitor.""" self.logger.info("Starting PodMonitor service...") - # Verify K8s connectivity (all clients already injected via __init__) + # Verify K8s connectivity await asyncio.to_thread(self._v1.get_api_resources) self.logger.info("Successfully connected to Kubernetes API") - # Start monitoring - self._state = MonitorState.RUNNING - self._watch_task = asyncio.create_task(self._watch_pods()) - - # Start reconciliation if enabled - if self.config.enable_state_reconciliation: - self._reconcile_task = asyncio.create_task(self._reconciliation_loop()) + # Start watch task + self._watch_task = asyncio.create_task(self._watch_loop()) self.logger.info("PodMonitor service started successfully") + return self - async def _on_stop(self) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: """Stop the pod monitor.""" self.logger.info("Stopping PodMonitor service...") - self._state = MonitorState.STOPPING - # Cancel tasks - tasks = [t for t in [self._watch_task, self._reconcile_task] if t] - for task in tasks: - task.cancel() - - # Wait for cancellation - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) + if self._watch_task: + self._watch_task.cancel() + try: + await self._watch_task + except asyncio.CancelledError: + pass - # Close watch if self._watch: self._watch.stop() - # Clear state self._tracked_pods.clear() self._event_mapper.clear_cache() - - self._state = MonitorState.STOPPED self.logger.info("PodMonitor service stopped") - async def _watch_pods(self) -> None: - """Main watch loop for pods.""" - while self._state == MonitorState.RUNNING: - try: - self._reconnect_attempts = 0 - await self._watch_pod_events() + async def 
_watch_loop(self) -> None: + """Main watch loop - reconciles on each restart.""" + try: + while True: + try: + # Reconcile before starting watch (catches missed events) + if self.config.enable_state_reconciliation: + await self._reconcile() - except ApiException as e: - match e.status: - case 410: # Gone - resource version too old + self._reconnect_attempts = 0 + await self._run_watch() + + except ApiException as e: + if e.status == 410: # Resource version expired self.logger.warning("Resource version expired, resetting watch") self._last_resource_version = None self._metrics.record_pod_monitor_watch_error(ErrorType.RESOURCE_VERSION_EXPIRED) - case _: + else: self.logger.error(f"API error in watch: {e}") self._metrics.record_pod_monitor_watch_error(ErrorType.API_ERROR) + await self._backoff() - await self._handle_watch_error() + except asyncio.CancelledError: + raise - except Exception as e: - self.logger.error(f"Unexpected error in watch: {e}", exc_info=True) - self._metrics.record_pod_monitor_watch_error(ErrorType.UNEXPECTED) - await self._handle_watch_error() + except Exception as e: + self.logger.error(f"Unexpected error in watch: {e}", exc_info=True) + self._metrics.record_pod_monitor_watch_error(ErrorType.UNEXPECTED) + await self._backoff() - async def _watch_pod_events(self) -> None: - """Watch for pod events.""" - # self._v1 and self._watch are guaranteed initialized by start() + except asyncio.CancelledError: + self.logger.info("Watch loop cancelled") - context = WatchContext( - namespace=self.config.namespace, - label_selector=self.config.label_selector, - field_selector=self.config.field_selector, - timeout_seconds=self.config.watch_timeout_seconds, - resource_version=self._last_resource_version, + async def _run_watch(self) -> None: + """Run a single watch session.""" + self.logger.info( + f"Starting pod watch: selector={self.config.label_selector}, namespace={self.config.namespace}" ) - self.logger.info(f"Starting pod watch with selector: {context.label_selector}, namespace: {context.namespace}") - - # Create watch stream - kwargs = { - "namespace": context.namespace, - "label_selector": context.label_selector, - "timeout_seconds": context.timeout_seconds, + kwargs: dict[str, Any] = { + "namespace": self.config.namespace, + "label_selector": self.config.label_selector, + "timeout_seconds": self.config.watch_timeout_seconds, } + if self.config.field_selector: + kwargs["field_selector"] = self.config.field_selector + if self._last_resource_version: + kwargs["resource_version"] = self._last_resource_version - if context.field_selector: - kwargs["field_selector"] = context.field_selector - - if context.resource_version: - kwargs["resource_version"] = context.resource_version - - # Watch stream (clients guaranteed by __init__) stream = self._watch.stream(self._v1.list_namespaced_pod, **kwargs) try: for event in stream: - if self._state != MonitorState.RUNNING: - break - await self._process_raw_event(event) - finally: - # Store resource version for next watch self._update_resource_version(stream) def _update_resource_version(self, stream: Any) -> None: @@ -342,16 +286,15 @@ async def _publish_event(self, event: DomainEvent, pod: k8s_client.V1Pod) -> Non except Exception as e: self.logger.error(f"Error publishing event: {e}", exc_info=True) - async def _handle_watch_error(self) -> None: + async def _backoff(self) -> None: """Handle watch errors with exponential backoff.""" self._reconnect_attempts += 1 if self._reconnect_attempts > self.config.max_reconnect_attempts: 
self.logger.error( - f"Max reconnect attempts ({self.config.max_reconnect_attempts}) exceeded, stopping pod monitor" + f"Max reconnect attempts ({self.config.max_reconnect_attempts}) exceeded" ) - self._state = MonitorState.STOPPING - return + raise RuntimeError("Max reconnect attempts exceeded") # Calculate exponential backoff backoff = min(self.config.watch_reconnect_delay * (2 ** (self._reconnect_attempts - 1)), MAX_BACKOFF_SECONDS) @@ -364,27 +307,14 @@ async def _handle_watch_error(self) -> None: self._metrics.increment_pod_monitor_watch_reconnects() await asyncio.sleep(backoff) - async def _reconciliation_loop(self) -> None: - """Periodically reconcile state with Kubernetes.""" - while self._state == MonitorState.RUNNING: - try: - await asyncio.sleep(self.config.reconcile_interval_seconds) - - if self._state == MonitorState.RUNNING: - result = await self._reconcile_state() - self._log_reconciliation_result(result) - - except Exception as e: - self.logger.error(f"Error in reconciliation loop: {e}", exc_info=True) - - async def _reconcile_state(self) -> ReconciliationResult: + async def _reconcile(self) -> None: """Reconcile tracked pods with actual state.""" start_time = time.time() try: self.logger.info("Starting pod state reconciliation") - # List all pods matching selector (clients guaranteed by __init__) + # List all pods matching selector pods = await asyncio.to_thread( self._v1.list_namespaced_pod, namespace=self.config.namespace, label_selector=self.config.label_selector ) @@ -415,90 +345,25 @@ async def _reconcile_state(self) -> ReconciliationResult: self._metrics.record_pod_monitor_reconciliation_run("success") duration = time.time() - start_time - - return ReconciliationResult( - missing_pods=missing_pods, extra_pods=extra_pods, duration_seconds=duration, success=True + self.logger.info( + f"Reconciliation completed in {duration:.2f}s. " + f"Found {len(missing_pods)} missing, {len(extra_pods)} extra pods" ) except Exception as e: self.logger.error(f"Failed to reconcile state: {e}", exc_info=True) self._metrics.record_pod_monitor_reconciliation_run("failed") - return ReconciliationResult( - missing_pods=set(), - extra_pods=set(), - duration_seconds=time.time() - start_time, - success=False, - error=str(e), - ) - - def _log_reconciliation_result(self, result: ReconciliationResult) -> None: - """Log reconciliation result.""" - if result.success: - self.logger.info( - f"Reconciliation completed in {result.duration_seconds:.2f}s. " - f"Found {len(result.missing_pods)} missing, " - f"{len(result.extra_pods)} extra pods" - ) - else: - self.logger.error(f"Reconciliation failed after {result.duration_seconds:.2f}s: {result.error}") - async def get_status(self) -> StatusDict: """Get monitor status.""" return { - "state": self._state, "tracked_pods": len(self._tracked_pods), "reconnect_attempts": self._reconnect_attempts, "last_resource_version": self._last_resource_version, + "watch_task_active": self._watch_task is not None and not self._watch_task.done(), "config": { "namespace": self.config.namespace, "label_selector": self.config.label_selector, "enable_reconciliation": self.config.enable_state_reconciliation, }, } - - -@asynccontextmanager -async def create_pod_monitor( - config: PodMonitorConfig, - kafka_event_service: KafkaEventService, - logger: logging.Logger, - kubernetes_metrics: KubernetesMetrics, - k8s_clients: K8sClients | None = None, - event_mapper: PodEventMapper | None = None, -) -> AsyncIterator[PodMonitor]: - """Create and manage a pod monitor instance. 
- - This factory handles production dependency creation: - - Creates K8sClients if not provided (using config settings) - - Creates PodEventMapper if not provided - - Cleans up created K8sClients on exit - """ - # Track whether we created clients (so we know to close them) - owns_clients = k8s_clients is None - - if k8s_clients is None: - k8s_clients = create_k8s_clients( - logger=logger, - kubeconfig_path=config.kubeconfig_path, - in_cluster=config.in_cluster, - ) - - if event_mapper is None: - event_mapper = PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) - - monitor = PodMonitor( - config=config, - kafka_event_service=kafka_event_service, - logger=logger, - k8s_clients=k8s_clients, - event_mapper=event_mapper, - kubernetes_metrics=kubernetes_metrics, - ) - - try: - async with monitor: - yield monitor - finally: - if owns_clients: - close_k8s_clients(k8s_clients) diff --git a/backend/app/services/result_processor/processor.py b/backend/app/services/result_processor/processor.py index 3f9864db..dddb9eb7 100644 --- a/backend/app/services/result_processor/processor.py +++ b/backend/app/services/result_processor/processor.py @@ -4,7 +4,6 @@ from pydantic import BaseModel, ConfigDict, Field -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics, ExecutionMetrics from app.core.utils import StringEnum from app.db.repositories.execution_repository import ExecutionRepository @@ -51,7 +50,7 @@ class ResultProcessorConfig(BaseModel): processing_timeout: int = Field(default=300) -class ResultProcessor(LifecycleEnabled): +class ResultProcessor: """Service for processing execution completion events and storing results.""" def __init__( @@ -66,7 +65,6 @@ def __init__( event_metrics: EventMetrics, ) -> None: """Initialize the result processor.""" - super().__init__() self.config = ResultProcessorConfig() self._execution_repo = execution_repo self._producer = producer @@ -80,7 +78,7 @@ def __init__( self._dispatcher: EventDispatcher | None = None self.logger = logger - async def _on_start(self) -> None: + async def __aenter__(self) -> "ResultProcessor": """Start the result processor.""" self.logger.info("Starting ResultProcessor...") @@ -92,8 +90,9 @@ async def _on_start(self) -> None: self._consumer = await self._create_consumer() self._state = ProcessingState.PROCESSING self.logger.info("ResultProcessor started successfully with idempotency protection") + return self - async def _on_stop(self) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: """Stop the result processor.""" self.logger.info("Stopping ResultProcessor...") self._state = ProcessingState.STOPPED diff --git a/backend/app/services/saga/__init__.py b/backend/app/services/saga/__init__.py index e89535ae..ec47a201 100644 --- a/backend/app/services/saga/__init__.py +++ b/backend/app/services/saga/__init__.py @@ -12,7 +12,7 @@ RemoveFromQueueCompensation, ValidateExecutionStep, ) -from app.services.saga.saga_orchestrator import SagaOrchestrator, create_saga_orchestrator +from app.services.saga.saga_orchestrator import SagaOrchestrator from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep __all__ = [ @@ -34,5 +34,4 @@ "ReleaseResourcesCompensation", "RemoveFromQueueCompensation", "DeletePodCompensation", - "create_saga_orchestrator", ] diff --git a/backend/app/services/saga/saga_orchestrator.py b/backend/app/services/saga/saga_orchestrator.py index 194d6ac3..4c293165 100644 --- a/backend/app/services/saga/saga_orchestrator.py +++ 
b/backend/app/services/saga/saga_orchestrator.py @@ -5,7 +5,6 @@ from opentelemetry.trace import SpanKind -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics from app.core.tracing import EventAttributes from app.core.tracing.utils import get_tracer @@ -27,7 +26,7 @@ from .saga_step import SagaContext -class SagaOrchestrator(LifecycleEnabled): +class SagaOrchestrator: """Orchestrates saga execution and compensation""" def __init__( @@ -43,7 +42,6 @@ def __init__( logger: logging.Logger, event_metrics: EventMetrics, ): - super().__init__() self.config = config self._sagas: dict[str, type[BaseSaga]] = {} self._running_instances: dict[str, Saga] = {} @@ -67,7 +65,7 @@ def _register_default_sagas(self) -> None: self.register_saga(ExecutionSaga) self.logger.info("Registered default sagas") - async def _on_start(self) -> None: + async def __aenter__(self) -> "SagaOrchestrator": """Start the saga orchestrator.""" self.logger.info(f"Starting saga orchestrator: {self.config.name}") @@ -79,8 +77,9 @@ async def _on_start(self) -> None: self._tasks.append(timeout_task) self.logger.info("Saga orchestrator started") + return self - async def _on_stop(self) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: """Stop the saga orchestrator.""" self.logger.info("Stopping saga orchestrator...") @@ -250,8 +249,6 @@ async def _execute_saga( # Execute each step for step in steps: - if not self.is_running: - break # Update current step instance.current_step = step.name @@ -363,8 +360,8 @@ async def _fail_saga(self, instance: Saga, error_message: str) -> None: async def _check_timeouts(self) -> None: """Check for saga timeouts""" - while self.is_running: - try: + try: + while True: # Check every 30 seconds await asyncio.sleep(30) @@ -382,8 +379,8 @@ async def _check_timeouts(self) -> None: await self._save_saga(instance) self._running_instances.pop(instance.saga_id, None) - except Exception as e: - self.logger.error(f"Error checking timeouts: {e}") + except asyncio.CancelledError: + self.logger.info("Timeout checker cancelled") async def _save_saga(self, instance: Saga) -> None: """Persist saga through repository""" @@ -534,46 +531,3 @@ async def _publish_saga_cancelled_event(self, saga_instance: Saga) -> None: except Exception as e: self.logger.error(f"Failed to publish saga cancellation event: {e}") - - -def create_saga_orchestrator( - saga_repository: SagaRepository, - producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - resource_allocation_repository: ResourceAllocationRepository, - config: SagaConfig, - logger: logging.Logger, - event_metrics: EventMetrics, -) -> SagaOrchestrator: - """Factory function to create a saga orchestrator. 
- - Args: - saga_repository: Repository for saga persistence - producer: Kafka producer instance - schema_registry_manager: Schema registry manager for event serialization - settings: Application settings - event_store: Event store instance for event sourcing - idempotency_manager: Manager for idempotent event processing - resource_allocation_repository: Repository for resource allocations - config: Saga configuration - logger: Logger instance - event_metrics: Event metrics for tracking Kafka consumption - - Returns: - A new saga orchestrator instance - """ - return SagaOrchestrator( - config, - saga_repository=saga_repository, - producer=producer, - schema_registry_manager=schema_registry_manager, - settings=settings, - event_store=event_store, - idempotency_manager=idempotency_manager, - resource_allocation_repository=resource_allocation_repository, - logger=logger, - event_metrics=event_metrics, - ) diff --git a/backend/app/services/sse/kafka_redis_bridge.py b/backend/app/services/sse/kafka_redis_bridge.py index 07e03c44..2a6b8c64 100644 --- a/backend/app/services/sse/kafka_redis_bridge.py +++ b/backend/app/services/sse/kafka_redis_bridge.py @@ -3,7 +3,6 @@ import asyncio import logging -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId @@ -14,7 +13,7 @@ from app.settings import Settings -class SSEKafkaRedisBridge(LifecycleEnabled): +class SSEKafkaRedisBridge: """ Bridges Kafka events to Redis channels for SSE delivery. @@ -31,7 +30,6 @@ def __init__( sse_bus: SSERedisBus, logger: logging.Logger, ) -> None: - super().__init__() self.schema_registry = schema_registry self.settings = settings self.event_metrics = event_metrics @@ -41,7 +39,7 @@ def __init__( self.num_consumers = settings.SSE_CONSUMER_POOL_SIZE self.consumers: list[UnifiedConsumer] = [] - async def _on_start(self) -> None: + async def __aenter__(self) -> "SSEKafkaRedisBridge": """Start the SSE Kafka→Redis bridge.""" self.logger.info(f"Starting SSE Kafka→Redis bridge with {self.num_consumers} consumers") @@ -53,8 +51,9 @@ async def _on_start(self) -> None: await asyncio.gather(*[c.start(topics) for c in self.consumers]) self.logger.info("SSE Kafka→Redis bridge started successfully") + return self - async def _on_stop(self) -> None: + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: """Stop the SSE Kafka→Redis bridge.""" self.logger.info("Stopping SSE Kafka→Redis bridge") await asyncio.gather(*[c.stop() for c in self.consumers], return_exceptions=True) @@ -127,26 +126,9 @@ async def route_event(event: DomainEvent) -> None: for et in relevant_events: dispatcher.register_handler(et, route_event) - def get_stats(self) -> dict[str, int | bool]: + def get_stats(self) -> dict[str, int]: return { "num_consumers": len(self.consumers), "active_executions": 0, "total_buffers": 0, - "is_running": self.is_running, } - - -def create_sse_kafka_redis_bridge( - schema_registry: SchemaRegistryManager, - settings: Settings, - event_metrics: EventMetrics, - sse_bus: SSERedisBus, - logger: logging.Logger, -) -> SSEKafkaRedisBridge: - return SSEKafkaRedisBridge( - schema_registry=schema_registry, - settings=settings, - event_metrics=event_metrics, - sse_bus=sse_bus, - logger=logger, - ) diff --git a/backend/app/services/sse/sse_shutdown_manager.py b/backend/app/services/sse/sse_shutdown_manager.py index 5c303c54..63f22eb4 100644 --- 
a/backend/app/services/sse/sse_shutdown_manager.py +++ b/backend/app/services/sse/sse_shutdown_manager.py @@ -2,9 +2,8 @@ import logging import time from enum import Enum -from typing import Dict, Set +from typing import Any, Dict, Set -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import ConnectionMetrics from app.domain.sse import ShutdownStatus @@ -35,7 +34,7 @@ class SSEShutdownManager: def __init__( self, - router: LifecycleEnabled, + router: Any, logger: logging.Logger, connection_metrics: ConnectionMetrics, drain_timeout: float = 30.0, @@ -255,8 +254,7 @@ async def _force_close_connections(self) -> None: self._connection_callbacks.clear() self._draining_connections.clear() - # Tell router to stop accepting new subscriptions - await self._router.aclose() + # Router lifecycle is managed by DI container self.metrics.update_sse_draining_connections(0) self.logger.info("Force close phase complete") diff --git a/backend/tests/integration/dlq/test_dlq_manager.py b/backend/tests/integration/dlq/test_dlq_manager.py index 6af47303..8ee6029f 100644 --- a/backend/tests/integration/dlq/test_dlq_manager.py +++ b/backend/tests/integration/dlq/test_dlq_manager.py @@ -6,8 +6,7 @@ import pytest from aiokafka import AIOKafkaConsumer, AIOKafkaProducer -from app.core.metrics import DLQMetrics -from app.dlq.manager import create_dlq_manager +from app.dlq.manager import DLQManager from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DLQMessageReceivedEvent @@ -28,9 +27,8 @@ @pytest.mark.asyncio async def test_dlq_manager_persists_and_emits_event(scope: AsyncContainer, test_settings: Settings) -> None: """Test that DLQ manager persists messages and emits DLQMessageReceivedEvent.""" - schema_registry = SchemaRegistryManager(test_settings, _test_logger) - dlq_metrics: DLQMetrics = await scope.get(DLQMetrics) - manager = create_dlq_manager(settings=test_settings, schema_registry=schema_registry, logger=_test_logger, dlq_metrics=dlq_metrics) + schema_registry = await scope.get(SchemaRegistryManager) + await scope.get(DLQManager) # Ensure DI starts the manager prefix = test_settings.KAFKA_TOPIC_PREFIX ev = make_execution_requested_event(execution_id=f"exec-dlq-persist-{uuid.uuid4().hex[:8]}") @@ -89,14 +87,12 @@ async def consume_dlq_events() -> None: consume_task = asyncio.create_task(consume_dlq_events()) try: - # Start manager - it will consume from DLQ, persist, and emit DLQMessageReceivedEvent - async with manager: - # Await the DLQMessageReceivedEvent - true async, no polling - received = await asyncio.wait_for(received_future, timeout=15.0) - assert received.dlq_event_id == ev.event_id - assert received.event_type == EventType.DLQ_MESSAGE_RECEIVED - assert received.original_event_type == str(EventType.EXECUTION_REQUESTED) - assert received.error == "handler failed" + # Manager is already started by DI - just wait for the event + received = await asyncio.wait_for(received_future, timeout=15.0) + assert received.dlq_event_id == ev.event_id + assert received.event_type == EventType.DLQ_MESSAGE_RECEIVED + assert received.original_event_type == str(EventType.EXECUTION_REQUESTED) + assert received.error == "handler failed" finally: consume_task.cancel() try: diff --git a/backend/tests/integration/events/test_consumer_lifecycle.py b/backend/tests/integration/events/test_consumer_lifecycle.py index 5374e152..f2e69c27 100644 --- a/backend/tests/integration/events/test_consumer_lifecycle.py +++ 
b/backend/tests/integration/events/test_consumer_lifecycle.py @@ -37,7 +37,7 @@ async def test_consumer_start_status_seek_and_stop(scope: AsyncContainer) -> Non await c.start([KafkaTopic.EXECUTION_EVENTS]) try: st = c.get_status() - assert st.state == "running" and st.is_running is True + assert st.state == "running" # Exercise seek functions; don't force specific partition offsets await c.seek_to_beginning() await c.seek_to_end() diff --git a/backend/tests/integration/services/sse/test_partitioned_event_router.py b/backend/tests/integration/services/sse/test_partitioned_event_router.py index 7e1c4ac6..f31391c1 100644 --- a/backend/tests/integration/services/sse/test_partitioned_event_router.py +++ b/backend/tests/integration/services/sse/test_partitioned_event_router.py @@ -72,10 +72,10 @@ async def test_router_start_and_stop(redis_client: redis.Redis, test_settings: S await router.__aenter__() stats = router.get_stats() assert stats["num_consumers"] == 1 - await router.aclose() + await router.__aexit__(None, None, None) assert router.get_stats()["num_consumers"] == 0 # idempotent start/stop await router.__aenter__() await router.__aenter__() - await router.aclose() - await router.aclose() + await router.__aexit__(None, None, None) + await router.__aexit__(None, None, None) diff --git a/backend/tests/unit/services/coordinator/test_queue_manager.py b/backend/tests/unit/services/coordinator/test_queue_manager.py index b4b39b2d..ebec3a6b 100644 --- a/backend/tests/unit/services/coordinator/test_queue_manager.py +++ b/backend/tests/unit/services/coordinator/test_queue_manager.py @@ -19,23 +19,19 @@ def ev(execution_id: str, priority: int = QueuePriority.NORMAL.value) -> Executi @pytest.mark.asyncio async def test_requeue_execution_increments_priority(coordinator_metrics: CoordinatorMetrics) -> None: qm = QueueManager(max_queue_size=10, logger=_test_logger, coordinator_metrics=coordinator_metrics) - await qm.start() # Use NORMAL priority which can be incremented to LOW e = ev("x", priority=QueuePriority.NORMAL.value) await qm.add_execution(e) await qm.requeue_execution(e, increment_retry=True) nxt = await qm.get_next_execution() assert nxt is not None - await qm.stop() @pytest.mark.asyncio async def test_queue_stats_empty_and_after_add(coordinator_metrics: CoordinatorMetrics) -> None: qm = QueueManager(max_queue_size=5, logger=_test_logger, coordinator_metrics=coordinator_metrics) - await qm.start() stats0 = await qm.get_queue_stats() assert stats0["total_size"] == 0 await qm.add_execution(ev("a")) st = await qm.get_queue_stats() assert st["total_size"] == 1 - await qm.stop() diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index dc93a150..691b4e6f 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock import pytest -from app.core import k8s_clients as k8s_clients_module from app.core.k8s_clients import K8sClients from app.core.metrics import EventMetrics, KubernetesMetrics from app.db.repositories.event_repository import EventRepository @@ -16,12 +15,9 @@ from app.services.pod_monitor.config import PodMonitorConfig from app.services.pod_monitor.event_mapper import PodEventMapper from app.services.pod_monitor.monitor import ( - MonitorState, PodEvent, PodMonitor, - ReconciliationResult, WatchEventType, - create_pod_monitor, ) from app.settings import Settings from kubernetes.client.rest 
import ApiException @@ -155,23 +151,21 @@ async def test_start_and_stop_lifecycle(event_metrics: EventMetrics, kubernetes_ spy = SpyMapper() pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=spy) # type: ignore[arg-type] - # Replace _watch_pods to avoid real watch loop + # Replace _watch_loop to avoid real watch loop async def _quick_watch() -> None: return None - pm._watch_pods = _quick_watch # type: ignore[method-assign] + pm._watch_loop = _quick_watch # type: ignore[method-assign] await pm.__aenter__() - assert pm.state == MonitorState.RUNNING + assert pm._watch_task is not None - await pm.aclose() - final_state: MonitorState = pm.state - assert final_state == MonitorState.STOPPED + await pm.__aexit__(None, None, None) assert spy.cleared is True @pytest.mark.asyncio -async def test_watch_pod_events_flow_and_publish(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_run_watch_flow_and_publish(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: cfg = PodMonitorConfig() cfg.enable_state_reconciliation = False @@ -179,14 +173,15 @@ async def test_watch_pod_events_flow_and_publish(event_metrics: EventMetrics, ku k8s_clients = make_k8s_clients_di(events=[{"type": "MODIFIED", "object": pod}], resource_version="rv2") pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) - pm._state = MonitorState.RUNNING - await pm._watch_pod_events() + await pm._run_watch() assert pm._last_resource_version == "rv2" @pytest.mark.asyncio -async def test_process_raw_event_invalid_and_handle_watch_error(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_process_raw_event_invalid_and_backoff( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics +) -> None: cfg = PodMonitorConfig() pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) @@ -194,8 +189,8 @@ async def test_process_raw_event_invalid_and_handle_watch_error(event_metrics: E pm.config.watch_reconnect_delay = 0 pm._reconnect_attempts = 0 - await pm._handle_watch_error() - await pm._handle_watch_error() + await pm._backoff() + await pm._backoff() assert pm._reconnect_attempts >= 2 @@ -212,7 +207,6 @@ async def test_get_status(event_metrics: EventMetrics, kubernetes_metrics: Kuber pm._last_resource_version = "v123" status = await pm.get_status() - assert "idle" in status["state"].lower() assert status["tracked_pods"] == 2 assert status["reconnect_attempts"] == 3 assert status["last_resource_version"] == "v123" @@ -222,41 +216,7 @@ async def test_get_status(event_metrics: EventMetrics, kubernetes_metrics: Kuber @pytest.mark.asyncio -async def test_reconciliation_loop_and_state(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = True - cfg.reconcile_interval_seconds = 0 # sleep(0) yields control immediately - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - reconcile_called: list[bool] = [] - - async def mock_reconcile() -> ReconciliationResult: - reconcile_called.append(True) - return ReconciliationResult(missing_pods={"p1"}, extra_pods={"p2"}, duration_seconds=0.1, success=True) - - evt = asyncio.Event() - - async def wrapped_reconcile() -> ReconciliationResult: - res = await mock_reconcile() - evt.set() - return res - - pm._reconcile_state = wrapped_reconcile # type: ignore[method-assign] - - task = 
asyncio.create_task(pm._reconciliation_loop()) - await asyncio.wait_for(evt.wait(), timeout=1.0) - pm._state = MonitorState.STOPPED - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - assert len(reconcile_called) > 0 - - -@pytest.mark.asyncio -async def test_reconcile_state_success(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_reconcile_success(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: cfg = PodMonitorConfig() cfg.namespace = "test" cfg.label_selector = "app=test" @@ -275,17 +235,16 @@ async def mock_process(event: PodEvent) -> None: pm._process_pod_event = mock_process # type: ignore[method-assign] - result = await pm._reconcile_state() + await pm._reconcile() - assert result.success is True - assert result.missing_pods == {"pod1"} - assert result.extra_pods == {"pod3"} + # pod1 was missing and should have been processed assert "pod1" in processed + # pod3 was extra and should have been removed from tracking assert "pod3" not in pm._tracked_pods @pytest.mark.asyncio -async def test_reconcile_state_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_reconcile_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: cfg = PodMonitorConfig() class FailV1(FakeV1Api): @@ -303,10 +262,8 @@ def list_namespaced_pod(self, namespace: str, label_selector: str) -> Any: pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) - result = await pm._reconcile_state() - assert result.success is False - assert result.error is not None - assert "API error" in result.error + # Should not raise - errors are caught and logged + await pm._reconcile() @pytest.mark.asyncio @@ -368,7 +325,9 @@ async def mock_publish(event: Any, pod: Any) -> None: # noqa: ARG001 @pytest.mark.asyncio -async def test_process_pod_event_exception_handling(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_process_pod_event_exception_handling( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics +) -> None: cfg = PodMonitorConfig() class FailMapper: @@ -412,7 +371,9 @@ async def test_publish_event_full_flow(event_metrics: EventMetrics, kubernetes_m @pytest.mark.asyncio -async def test_publish_event_exception_handling(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_publish_event_exception_handling( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics +) -> None: cfg = PodMonitorConfig() class FailingProducer(FakeUnifiedProducer): @@ -449,133 +410,104 @@ async def produce( @pytest.mark.asyncio -async def test_handle_watch_error_max_attempts(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_backoff_max_attempts(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: cfg = PodMonitorConfig() cfg.max_reconnect_attempts = 2 pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING pm._reconnect_attempts = 2 - await pm._handle_watch_error() - - assert pm._state == MonitorState.STOPPING + with pytest.raises(RuntimeError, match="Max reconnect attempts exceeded"): + await pm._backoff() @pytest.mark.asyncio -async def test_watch_pods_main_loop(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_watch_loop_with_cancellation(event_metrics: EventMetrics, 
kubernetes_metrics: KubernetesMetrics) -> None: cfg = PodMonitorConfig() + cfg.enable_state_reconciliation = False pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING watch_count: list[int] = [] - async def mock_watch() -> None: + async def mock_run_watch() -> None: watch_count.append(1) - if len(watch_count) > 2: - pm._state = MonitorState.STOPPED + if len(watch_count) >= 3: + raise asyncio.CancelledError() - async def mock_handle_error() -> None: - pass + pm._run_watch = mock_run_watch # type: ignore[method-assign] - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle_error # type: ignore[method-assign] + # watch_loop catches CancelledError and exits gracefully (doesn't propagate) + await pm._watch_loop() - await pm._watch_pods() - assert len(watch_count) > 2 + assert len(watch_count) == 3 @pytest.mark.asyncio -async def test_watch_pods_api_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_watch_loop_api_exception_410(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: cfg = PodMonitorConfig() + cfg.enable_state_reconciliation = False pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - async def mock_watch() -> None: - raise ApiException(status=410) + pm._last_resource_version = "v123" + call_count = 0 - error_handled: list[bool] = [] + async def mock_run_watch() -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ApiException(status=410) + raise asyncio.CancelledError() - async def mock_handle() -> None: - error_handled.append(True) - pm._state = MonitorState.STOPPED + async def mock_backoff() -> None: + pass - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle # type: ignore[method-assign] + pm._run_watch = mock_run_watch # type: ignore[method-assign] + pm._backoff = mock_backoff # type: ignore[method-assign] - await pm._watch_pods() + # watch_loop catches CancelledError and exits gracefully + await pm._watch_loop() + # Resource version should be reset on 410 assert pm._last_resource_version is None - assert len(error_handled) > 0 @pytest.mark.asyncio -async def test_watch_pods_generic_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_watch_loop_generic_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: cfg = PodMonitorConfig() + cfg.enable_state_reconciliation = False pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - async def mock_watch() -> None: - raise RuntimeError("Unexpected error") - error_handled: list[bool] = [] + call_count = 0 + backoff_count = 0 - async def mock_handle() -> None: - error_handled.append(True) - pm._state = MonitorState.STOPPED + async def mock_run_watch() -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("Unexpected error") + raise asyncio.CancelledError() - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle # type: ignore[method-assign] + async def mock_backoff() -> None: + nonlocal backoff_count + backoff_count += 1 - await pm._watch_pods() - assert len(error_handled) > 0 + pm._run_watch = mock_run_watch # type: ignore[method-assign] + pm._backoff = mock_backoff # type: ignore[method-assign] + # watch_loop 
catches CancelledError and exits gracefully + await pm._watch_loop() -@pytest.mark.asyncio -async def test_create_pod_monitor_context_manager(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, monkeypatch: pytest.MonkeyPatch) -> None: - """Test create_pod_monitor factory with auto-created dependencies.""" - # Mock create_k8s_clients to avoid real K8s connection - mock_v1 = FakeV1Api() - mock_watch = make_watch([]) - mock_clients = K8sClients( - api_client=MagicMock(), - v1=mock_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=mock_watch, - ) - - def mock_create_clients( - logger: logging.Logger, # noqa: ARG001 - kubeconfig_path: str | None = None, # noqa: ARG001 - in_cluster: bool | None = None, # noqa: ARG001 - ) -> K8sClients: - return mock_clients - - monkeypatch.setattr(k8s_clients_module, "create_k8s_clients", mock_create_clients) - monkeypatch.setattr("app.services.pod_monitor.monitor.create_k8s_clients", mock_create_clients) - - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - - service, _ = create_test_kafka_event_service(event_metrics) - - # Use the actual create_pod_monitor which will use our mocked create_k8s_clients - async with create_pod_monitor(cfg, service, _test_logger, kubernetes_metrics=kubernetes_metrics) as monitor: - assert monitor.state == MonitorState.RUNNING - - final_state: MonitorState = monitor.state - assert final_state == MonitorState.STOPPED + assert backoff_count == 1 @pytest.mark.asyncio -async def test_create_pod_monitor_with_injected_k8s_clients(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test create_pod_monitor with injected K8sClients (DI path).""" +async def test_pod_monitor_context_manager_lifecycle( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics +) -> None: + """Test PodMonitor lifecycle via async context manager.""" cfg = PodMonitorConfig() cfg.enable_state_reconciliation = False - service, _ = create_test_kafka_event_service(event_metrics) - mock_v1 = FakeV1Api() mock_watch = make_watch([]) mock_k8s_clients = K8sClients( @@ -586,60 +518,38 @@ async def test_create_pod_monitor_with_injected_k8s_clients(event_metrics: Event watch=mock_watch, ) - async with create_pod_monitor( - cfg, service, _test_logger, k8s_clients=mock_k8s_clients, kubernetes_metrics=kubernetes_metrics - ) as monitor: - assert monitor.state == MonitorState.RUNNING - assert monitor._clients is mock_k8s_clients - assert monitor._v1 is mock_v1 - - final_state: MonitorState = monitor.state - assert final_state == MonitorState.STOPPED - - -@pytest.mark.asyncio -async def test_start_already_running(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test idempotent start via __aenter__.""" - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - - # Simulate already started state - pm._lifecycle_started = True - pm._state = MonitorState.RUNNING - - # Should be idempotent - just return self - await pm.__aenter__() - + service, _ = create_test_kafka_event_service(event_metrics) + event_mapper = PodEventMapper(logger=_test_logger, k8s_api=mock_v1) -@pytest.mark.asyncio -async def test_stop_already_stopped(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test idempotent stop via aclose().""" - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.STOPPED - # Not started, so aclose should be a no-op + monitor 
= PodMonitor( + config=cfg, + kafka_event_service=service, + logger=_test_logger, + k8s_clients=mock_k8s_clients, + event_mapper=event_mapper, + kubernetes_metrics=kubernetes_metrics, + ) - await pm.aclose() + async with monitor: + assert monitor._watch_task is not None + assert monitor._clients is mock_k8s_clients + assert monitor._v1 is mock_v1 @pytest.mark.asyncio async def test_stop_with_tasks(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test cleanup of tasks on aclose().""" + """Test cleanup of tasks on __aexit__.""" cfg = PodMonitorConfig() pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - pm._lifecycle_started = True async def dummy_task() -> None: await asyncio.Event().wait() pm._watch_task = asyncio.create_task(dummy_task()) - pm._reconcile_task = asyncio.create_task(dummy_task()) pm._tracked_pods = {"pod1"} - await pm.aclose() + await pm.__aexit__(None, None, None) - assert pm._state == MonitorState.STOPPED assert len(pm._tracked_pods) == 0 @@ -660,7 +570,9 @@ class BadStream: @pytest.mark.asyncio -async def test_process_raw_event_with_metadata(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_process_raw_event_with_metadata( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics +) -> None: cfg = PodMonitorConfig() pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) @@ -688,29 +600,40 @@ async def mock_process(event: PodEvent) -> None: @pytest.mark.asyncio -async def test_watch_pods_api_exception_other_status(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_watch_loop_api_exception_other_status( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics +) -> None: cfg = PodMonitorConfig() + cfg.enable_state_reconciliation = False pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - async def mock_watch() -> None: - raise ApiException(status=500) + call_count = 0 + backoff_count = 0 - error_handled: list[bool] = [] + async def mock_run_watch() -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ApiException(status=500) + raise asyncio.CancelledError() - async def mock_handle() -> None: - error_handled.append(True) - pm._state = MonitorState.STOPPED + async def mock_backoff() -> None: + nonlocal backoff_count + backoff_count += 1 - pm._watch_pod_events = mock_watch # type: ignore[method-assign] - pm._handle_watch_error = mock_handle # type: ignore[method-assign] + pm._run_watch = mock_run_watch # type: ignore[method-assign] + pm._backoff = mock_backoff # type: ignore[method-assign] - await pm._watch_pods() - assert len(error_handled) > 0 + # watch_loop catches CancelledError and exits gracefully + await pm._watch_loop() + + assert backoff_count == 1 @pytest.mark.asyncio -async def test_watch_pod_events_with_field_selector(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_run_watch_with_field_selector( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics +) -> None: cfg = PodMonitorConfig() cfg.field_selector = "status.phase=Running" cfg.enable_state_reconciliation = False @@ -736,56 +659,40 @@ def stream(self, func: Any, **kwargs: Any) -> FakeWatchStream: ) pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) - pm._state = MonitorState.RUNNING - await pm._watch_pod_events() + await 
pm._run_watch() assert any("field_selector" in kw for kw in watch_kwargs) @pytest.mark.asyncio -async def test_reconciliation_loop_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: +async def test_watch_loop_with_reconciliation( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics +) -> None: + """Test that reconciliation is called before each watch restart.""" cfg = PodMonitorConfig() cfg.enable_state_reconciliation = True - cfg.reconcile_interval_seconds = 0 # sleep(0) yields control immediately - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - pm._state = MonitorState.RUNNING - - hit = asyncio.Event() - async def raising() -> ReconciliationResult: - hit.set() - raise RuntimeError("Reconcile error") - - pm._reconcile_state = raising # type: ignore[method-assign] - - task = asyncio.create_task(pm._reconciliation_loop()) - await asyncio.wait_for(hit.wait(), timeout=1.0) - pm._state = MonitorState.STOPPED - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - -@pytest.mark.asyncio -async def test_start_with_reconciliation(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = True - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) - - async def mock_watch() -> None: - return None + reconcile_count = 0 + watch_count = 0 async def mock_reconcile() -> None: - return None + nonlocal reconcile_count + reconcile_count += 1 - pm._watch_pods = mock_watch # type: ignore[method-assign] - pm._reconciliation_loop = mock_reconcile # type: ignore[method-assign] + async def mock_run_watch() -> None: + nonlocal watch_count + watch_count += 1 + if watch_count >= 2: + raise asyncio.CancelledError() - await pm.__aenter__() - assert pm._watch_task is not None - assert pm._reconcile_task is not None + pm._reconcile = mock_reconcile # type: ignore[method-assign] + pm._run_watch = mock_run_watch # type: ignore[method-assign] + + # watch_loop catches CancelledError and exits gracefully + await pm._watch_loop() - await pm.aclose() + # Reconcile should be called before each watch restart + assert reconcile_count == 2 + assert watch_count == 2 diff --git a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py index b414884a..35e71820 100644 --- a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py +++ b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py @@ -119,11 +119,10 @@ def _orch(event_metrics: EventMetrics) -> SagaOrchestrator: async def test_min_success_flow(event_metrics: EventMetrics) -> None: orch = _orch(event_metrics) orch.register_saga(_Saga) - # Set orchestrator running state via lifecycle property - orch._lifecycle_started = True + # Handle the event await orch._handle_event(make_execution_requested_event(execution_id="e")) # basic sanity; deep behavior covered by integration - assert orch.is_running is True + assert len(orch._sagas) > 0 @pytest.mark.asyncio diff --git a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py index 6fa5d1ef..15e3ff9f 100644 --- a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py +++ b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py @@ -61,4 +61,4 @@ async def test_register_and_route_events_without_kafka() -> None: assert fake_bus.published and fake_bus.published[-1][0] == "exec-123" s 
= bridge.get_stats() - assert s["num_consumers"] == 0 and s["is_running"] is False + assert s["num_consumers"] == 0 diff --git a/backend/tests/unit/services/sse/test_shutdown_manager.py b/backend/tests/unit/services/sse/test_shutdown_manager.py index 7c2a484d..7f9ab3c2 100644 --- a/backend/tests/unit/services/sse/test_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_shutdown_manager.py @@ -2,22 +2,24 @@ import logging import pytest -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import ConnectionMetrics from app.services.sse.sse_shutdown_manager import SSEShutdownManager +pytestmark = pytest.mark.unit + _test_logger = logging.getLogger("test.services.sse.shutdown_manager") -class _FakeRouter(LifecycleEnabled): - """Fake router that tracks whether aclose was called.""" +class _FakeRouter: + """Fake router for testing.""" def __init__(self) -> None: - super().__init__() self.stopped = False - self._lifecycle_started = True # Simulate already-started router - async def _on_stop(self) -> None: + async def __aenter__(self) -> "_FakeRouter": + return self + + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: self.stopped = True @@ -53,10 +55,9 @@ async def on_shutdown(event: asyncio.Event, cid: str) -> None: @pytest.mark.asyncio -async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection_metrics: ConnectionMetrics) -> None: - router = _FakeRouter() +async def test_shutdown_force_close_and_rejects_new(connection_metrics: ConnectionMetrics) -> None: mgr = SSEShutdownManager( - router=router, + router=_FakeRouter(), logger=_test_logger, connection_metrics=connection_metrics, drain_timeout=0.01, @@ -70,7 +71,6 @@ async def test_shutdown_force_close_calls_router_stop_and_rejects_new(connection # Initiate shutdown await mgr.initiate_shutdown() - assert router.stopped is True assert mgr.is_shutting_down() is True status = mgr.get_shutdown_status() assert status.draining_connections == 0 diff --git a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py index 3f424605..54ab54f3 100644 --- a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py +++ b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py @@ -2,7 +2,6 @@ import logging import pytest -from app.core.lifecycle import LifecycleEnabled from app.core.metrics import ConnectionMetrics from app.services.sse.sse_shutdown_manager import SSEShutdownManager @@ -11,15 +10,16 @@ _test_logger = logging.getLogger("test.services.sse.sse_shutdown_manager") -class _FakeRouter(LifecycleEnabled): - """Fake router that tracks whether aclose was called.""" +class _FakeRouter: + """Fake router for testing.""" def __init__(self) -> None: - super().__init__() self.stopped = False - self._lifecycle_started = True # Simulate already-started router - async def _on_stop(self) -> None: + async def __aenter__(self) -> "_FakeRouter": + return self + + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: self.stopped = True diff --git a/backend/workers/run_coordinator.py b/backend/workers/run_coordinator.py index 12004bf1..77346b3a 100644 --- a/backend/workers/run_coordinator.py +++ b/backend/workers/run_coordinator.py @@ -39,8 +39,8 @@ async def run_coordinator(settings: Settings) -> None: logger.info("ExecutionCoordinator started and running") try: - # Wait for shutdown signal or service to stop - while coordinator.is_running and not shutdown_event.is_set(): + # Wait for shutdown 
signal + while not shutdown_event.is_set(): await asyncio.sleep(60) status = await coordinator.get_status() logger.info(f"Coordinator status: {status}") diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index d3b857ad..657785f8 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -39,8 +39,8 @@ async def run_kubernetes_worker(settings: Settings) -> None: logger.info("KubernetesWorker started and running") try: - # Wait for shutdown signal or service to stop - while worker.is_running and not shutdown_event.is_set(): + # Wait for shutdown signal + while not shutdown_event.is_set(): await asyncio.sleep(60) status = await worker.get_status() logger.info(f"Kubernetes worker status: {status}") diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index 4b4dd325..36e9f7f7 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -9,7 +9,7 @@ from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.pod_monitor.monitor import MonitorState, PodMonitor +from app.services.pod_monitor.monitor import PodMonitor from app.settings import Settings from beanie import init_beanie @@ -41,8 +41,8 @@ async def run_pod_monitor(settings: Settings) -> None: logger.info("PodMonitor started and running") try: - # Wait for shutdown signal or service to stop - while monitor.state == MonitorState.RUNNING and not shutdown_event.is_set(): + # Wait for shutdown signal + while not shutdown_event.is_set(): await asyncio.sleep(RECONCILIATION_LOG_INTERVAL) status = await monitor.get_status() logger.info(f"Pod monitor status: {status}") diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 7fd0c359..8027e2e4 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -28,7 +28,7 @@ async def run_saga_orchestrator(settings: Settings) -> None: await initialize_event_schemas(schema_registry) # Services are already started by the DI container providers - orchestrator = await container.get(SagaOrchestrator) + await container.get(SagaOrchestrator) # Shutdown event - signal handlers just set this shutdown_event = asyncio.Event() @@ -39,8 +39,8 @@ async def run_saga_orchestrator(settings: Settings) -> None: logger.info("Saga orchestrator started and running") try: - # Wait for shutdown signal or service to stop - while orchestrator.is_running and not shutdown_event.is_set(): + # Wait for shutdown signal + while not shutdown_event.is_set(): await asyncio.sleep(1) finally: # Container cleanup stops everything From f5ca16d7dc02d77d269bc7d0eaa3feb5fa33bce5 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Mon, 19 Jan 2026 21:57:57 +0100 Subject: [PATCH 06/21] tests update --- backend/tests/integration/conftest.py | 87 +++++++ .../events/test_consume_roundtrip.py | 33 +-- .../events/test_consumer_lifecycle.py | 22 +- .../events/test_event_dispatcher.py | 31 +-- .../events/test_producer_roundtrip.py | 2 +- .../events/test_schema_registry_real.py | 14 +- .../events/test_schema_registry_roundtrip.py | 24 +- .../idempotency/test_consumer_idempotent.py | 26 +- .../idempotency/test_idempotency.py | 184 ++++++------- .../result_processor/test_result_processor.py | 34 +-- .../sse/test_partitioned_event_router.py | 54 ++-- backend/tests/load/plot_report.py | 2 +- backend/tests/unit/conftest.py 
| 17 ++ .../events/test_schema_registry_manager.py | 15 +- .../coordinator/test_resource_manager.py | 15 +- .../unit/services/pod_monitor/test_monitor.py | 241 +++++++++++------- 16 files changed, 436 insertions(+), 365 deletions(-) diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 0d824014..e188a3b7 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -1,11 +1,24 @@ +import logging +import uuid from collections.abc import AsyncGenerator +import pytest import pytest_asyncio import redis.asyncio as redis from app.core.database_context import Database +from app.core.metrics import DatabaseMetrics, EventMetrics +from app.events.core import ConsumerConfig +from app.events.schema.schema_registry import SchemaRegistryManager +from app.services.idempotency.idempotency_manager import IdempotencyConfig, IdempotencyManager +from app.services.idempotency.redis_repository import RedisIdempotencyRepository +from app.services.sse.redis_bus import SSERedisBus +from app.settings import Settings +from dishka import AsyncContainer from tests.helpers.cleanup import cleanup_db_and_redis +_test_logger = logging.getLogger("test.integration") + @pytest_asyncio.fixture(autouse=True) async def _cleanup(db: Database, redis_client: redis.Redis) -> AsyncGenerator[None, None]: @@ -17,3 +30,77 @@ async def _cleanup(db: Database, redis_client: redis.Redis) -> AsyncGenerator[No await cleanup_db_and_redis(db, redis_client) yield # No post-test cleanup to avoid "Event loop is closed" errors + + +# ===== DI-based fixtures for integration tests ===== + + +@pytest_asyncio.fixture +async def schema_registry(scope: AsyncContainer) -> SchemaRegistryManager: + """Provide SchemaRegistryManager via DI.""" + return await scope.get(SchemaRegistryManager) + + +@pytest_asyncio.fixture +async def event_metrics(scope: AsyncContainer) -> EventMetrics: + """Provide EventMetrics via DI.""" + return await scope.get(EventMetrics) + + +@pytest_asyncio.fixture +async def database_metrics(scope: AsyncContainer) -> DatabaseMetrics: + """Provide DatabaseMetrics via DI.""" + return await scope.get(DatabaseMetrics) + + +# ===== Config fixtures ===== + + +@pytest.fixture +def consumer_config(test_settings: Settings) -> ConsumerConfig: + """Provide a unique ConsumerConfig for each test. 
+ + Defaults for integration tests: + - enable_auto_commit=True: Commit offsets automatically for simpler test cleanup + - auto_offset_reset="earliest": Read all messages from start (default in ConsumerConfig) + """ + return ConsumerConfig( + bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"test-consumer-{uuid.uuid4().hex[:6]}", + enable_auto_commit=True, + ) + + +@pytest_asyncio.fixture +async def sse_redis_bus(redis_client: redis.Redis) -> SSERedisBus: + """Provide SSERedisBus with unique prefixes for test isolation.""" + suffix = uuid.uuid4().hex[:6] + return SSERedisBus( + redis_client, + exec_prefix=f"sse:exec:{suffix}:", + notif_prefix=f"sse:notif:{suffix}:", + logger=_test_logger, + ) + + +@pytest_asyncio.fixture +async def idempotency_manager( + redis_client: redis.Redis, database_metrics: DatabaseMetrics +) -> AsyncGenerator[IdempotencyManager, None]: + """Provide IdempotencyManager with unique prefix for test isolation.""" + prefix = f"idemp:{uuid.uuid4().hex[:6]}" + cfg = IdempotencyConfig( + key_prefix=prefix, + default_ttl_seconds=3600, + processing_timeout_seconds=5, + enable_result_caching=True, + max_result_size_bytes=1024, + enable_metrics=False, + ) + repo = RedisIdempotencyRepository(redis_client, key_prefix=prefix) + mgr = IdempotencyManager(cfg, repo, _test_logger, database_metrics=database_metrics) + await mgr.initialize() + try: + yield mgr + finally: + await mgr.close() diff --git a/backend/tests/integration/events/test_consume_roundtrip.py b/backend/tests/integration/events/test_consume_roundtrip.py index 94193247..40b0d490 100644 --- a/backend/tests/integration/events/test_consume_roundtrip.py +++ b/backend/tests/integration/events/test_consume_roundtrip.py @@ -1,15 +1,13 @@ import asyncio import logging -import uuid import pytest from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent -from app.events.core import UnifiedConsumer, UnifiedProducer +from app.events.core import ConsumerConfig, UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher -from app.events.core.types import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.settings import Settings from dishka import AsyncContainer @@ -24,12 +22,15 @@ @pytest.mark.asyncio -async def test_produce_consume_roundtrip(scope: AsyncContainer) -> None: +async def test_produce_consume_roundtrip( + scope: AsyncContainer, + schema_registry: SchemaRegistryManager, + event_metrics: EventMetrics, + consumer_config: ConsumerConfig, + test_settings: Settings, +) -> None: # Ensure schemas are registered - registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) - await initialize_event_schemas(registry) + await initialize_event_schemas(schema_registry) # Real producer from DI producer: UnifiedProducer = await scope.get(UnifiedProducer) @@ -42,19 +43,11 @@ async def test_produce_consume_roundtrip(scope: AsyncContainer) -> None: async def _handle(_event: DomainEvent) -> None: received.set() - group_id = f"test-consumer.{uuid.uuid4().hex[:6]}" - config = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=group_id, - enable_auto_commit=True, - auto_offset_reset="earliest", - ) - consumer = UnifiedConsumer( - config, + 
consumer_config, dispatcher, - schema_registry=registry, - settings=settings, + schema_registry=schema_registry, + settings=test_settings, logger=_test_logger, event_metrics=event_metrics, ) @@ -62,7 +55,7 @@ async def _handle(_event: DomainEvent) -> None: try: # Produce a request event - execution_id = f"exec-{uuid.uuid4().hex[:8]}" + execution_id = f"exec-{consumer_config.group_id}" evt = make_execution_requested_event(execution_id=execution_id) await producer.produce(evt, key=execution_id) diff --git a/backend/tests/integration/events/test_consumer_lifecycle.py b/backend/tests/integration/events/test_consumer_lifecycle.py index f2e69c27..8272e772 100644 --- a/backend/tests/integration/events/test_consumer_lifecycle.py +++ b/backend/tests/integration/events/test_consumer_lifecycle.py @@ -1,5 +1,4 @@ import logging -from uuid import uuid4 import pytest from app.core.metrics import EventMetrics @@ -7,7 +6,6 @@ from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer from app.events.schema.schema_registry import SchemaRegistryManager from app.settings import Settings -from dishka import AsyncContainer # xdist_group: Kafka consumer creation can crash librdkafka when multiple workers # instantiate Consumer() objects simultaneously. Serial execution prevents this. @@ -17,20 +15,18 @@ @pytest.mark.asyncio -async def test_consumer_start_status_seek_and_stop(scope: AsyncContainer) -> None: - registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) - cfg = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"test-consumer-{uuid4().hex[:6]}", - ) +async def test_consumer_start_status_seek_and_stop( + schema_registry: SchemaRegistryManager, + event_metrics: EventMetrics, + consumer_config: ConsumerConfig, + test_settings: Settings, +) -> None: disp = EventDispatcher(logger=_test_logger) c = UnifiedConsumer( - cfg, + consumer_config, event_dispatcher=disp, - schema_registry=registry, - settings=settings, + schema_registry=schema_registry, + settings=test_settings, logger=_test_logger, event_metrics=event_metrics, ) diff --git a/backend/tests/integration/events/test_event_dispatcher.py b/backend/tests/integration/events/test_event_dispatcher.py index 3d166cec..0940a88e 100644 --- a/backend/tests/integration/events/test_event_dispatcher.py +++ b/backend/tests/integration/events/test_event_dispatcher.py @@ -1,15 +1,13 @@ import asyncio import logging -import uuid import pytest from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent -from app.events.core import UnifiedConsumer, UnifiedProducer +from app.events.core import ConsumerConfig, UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher -from app.events.core.types import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.settings import Settings from dishka import AsyncContainer @@ -24,12 +22,15 @@ @pytest.mark.asyncio -async def test_dispatcher_with_multiple_handlers(scope: AsyncContainer) -> None: +async def test_dispatcher_with_multiple_handlers( + scope: AsyncContainer, + schema_registry: SchemaRegistryManager, + event_metrics: EventMetrics, + consumer_config: ConsumerConfig, + test_settings: Settings, +) -> None: # Ensure schema registry is ready - 
registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) - await initialize_event_schemas(registry) + await initialize_event_schemas(schema_registry) # Build dispatcher with two handlers for the same event dispatcher = EventDispatcher(logger=_test_logger) @@ -45,17 +46,11 @@ async def h2(_e: DomainEvent) -> None: h2_called.set() # Real consumer against execution-events - cfg = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"dispatcher-it.{uuid.uuid4().hex[:6]}", - enable_auto_commit=True, - auto_offset_reset="earliest", - ) consumer = UnifiedConsumer( - cfg, + consumer_config, dispatcher, - schema_registry=registry, - settings=settings, + schema_registry=schema_registry, + settings=test_settings, logger=_test_logger, event_metrics=event_metrics, ) @@ -63,7 +58,7 @@ async def h2(_e: DomainEvent) -> None: # Produce a request event via DI producer: UnifiedProducer = await scope.get(UnifiedProducer) - evt = make_execution_requested_event(execution_id=f"exec-{uuid.uuid4().hex[:8]}") + evt = make_execution_requested_event(execution_id=f"exec-{consumer_config.group_id}") await producer.produce(evt, key="k") try: diff --git a/backend/tests/integration/events/test_producer_roundtrip.py b/backend/tests/integration/events/test_producer_roundtrip.py index cb91df15..eda0c299 100644 --- a/backend/tests/integration/events/test_producer_roundtrip.py +++ b/backend/tests/integration/events/test_producer_roundtrip.py @@ -38,4 +38,4 @@ async def test_unified_producer_start_produce_send_to_dlq_stop( await prod.send_to_dlq(ev, original_topic=topic, error=RuntimeError("forced"), retry_count=1) st = prod.get_status() - assert st["running"] is True and st["state"] == "running" + assert st["state"] == "running" diff --git a/backend/tests/integration/events/test_schema_registry_real.py b/backend/tests/integration/events/test_schema_registry_real.py index 3e9da631..f102c032 100644 --- a/backend/tests/integration/events/test_schema_registry_real.py +++ b/backend/tests/integration/events/test_schema_registry_real.py @@ -1,28 +1,24 @@ -import logging - import pytest from app.domain.events.typed import EventMetadata, PodCreatedEvent from app.events.schema.schema_registry import SchemaRegistryManager from app.infrastructure.kafka.mappings import get_topic_for_event -from app.settings import Settings pytestmark = [pytest.mark.integration, pytest.mark.kafka] -_test_logger = logging.getLogger("test.events.schema_registry_real") - @pytest.mark.asyncio -async def test_serialize_and_deserialize_event_real_registry(test_settings: Settings) -> None: +async def test_serialize_and_deserialize_event_real_registry( + schema_registry: SchemaRegistryManager, +) -> None: # Uses real Schema Registry configured via env (SCHEMA_REGISTRY_URL) - m = SchemaRegistryManager(settings=test_settings, logger=_test_logger) ev = PodCreatedEvent( execution_id="e1", pod_name="p", namespace="n", metadata=EventMetadata(service_name="s", service_version="1"), ) - data = await m.serialize_event(ev) + data = await schema_registry.serialize_event(ev) topic = str(get_topic_for_event(ev.event_type)) - obj = await m.deserialize_event(data, topic=topic) + obj = await schema_registry.deserialize_event(data, topic=topic) assert isinstance(obj, PodCreatedEvent) assert obj.namespace == "n" diff --git a/backend/tests/integration/events/test_schema_registry_roundtrip.py 
b/backend/tests/integration/events/test_schema_registry_roundtrip.py index f23b2fe6..fba233cb 100644 --- a/backend/tests/integration/events/test_schema_registry_roundtrip.py +++ b/backend/tests/integration/events/test_schema_registry_roundtrip.py @@ -1,35 +1,31 @@ -import logging - import pytest from app.events.schema.schema_registry import MAGIC_BYTE, SchemaRegistryManager from app.infrastructure.kafka.mappings import get_topic_for_event -from app.settings import Settings -from dishka import AsyncContainer from tests.helpers import make_execution_requested_event pytestmark = [pytest.mark.integration] -_test_logger = logging.getLogger("test.events.schema_registry_roundtrip") - @pytest.mark.asyncio -async def test_schema_registry_serialize_deserialize_roundtrip(scope: AsyncContainer) -> None: - reg: SchemaRegistryManager = await scope.get(SchemaRegistryManager) +async def test_schema_registry_serialize_deserialize_roundtrip( + schema_registry: SchemaRegistryManager, +) -> None: # Schema registration happens lazily in serialize_event ev = make_execution_requested_event(execution_id="e-rt") - data = await reg.serialize_event(ev) + data = await schema_registry.serialize_event(ev) assert data.startswith(MAGIC_BYTE) topic = str(get_topic_for_event(ev.event_type)) - back = await reg.deserialize_event(data, topic=topic) + back = await schema_registry.deserialize_event(data, topic=topic) assert back.event_id == ev.event_id and getattr(back, "execution_id", None) == ev.execution_id # initialize_schemas should be a no-op if already initialized; call to exercise path - await reg.initialize_schemas() + await schema_registry.initialize_schemas() @pytest.mark.asyncio -async def test_schema_registry_deserialize_invalid_header(test_settings: Settings) -> None: - reg = SchemaRegistryManager(settings=test_settings, logger=_test_logger) +async def test_schema_registry_deserialize_invalid_header( + schema_registry: SchemaRegistryManager, +) -> None: with pytest.raises(ValueError): - await reg.deserialize_event(b"\x01\x00\x00\x00\x01", topic="t") # wrong magic byte + await schema_registry.deserialize_event(b"\x01\x00\x00\x00\x01", topic="t") # wrong magic byte diff --git a/backend/tests/integration/idempotency/test_consumer_idempotent.py b/backend/tests/integration/idempotency/test_consumer_idempotent.py index 19d4b05f..658a553e 100644 --- a/backend/tests/integration/idempotency/test_consumer_idempotent.py +++ b/backend/tests/integration/idempotency/test_consumer_idempotent.py @@ -1,6 +1,5 @@ import asyncio import logging -import uuid import pytest from app.core.metrics import EventMetrics @@ -30,12 +29,15 @@ @pytest.mark.asyncio -async def test_consumer_idempotent_wrapper_blocks_duplicates(scope: AsyncContainer) -> None: +async def test_consumer_idempotent_wrapper_blocks_duplicates( + scope: AsyncContainer, + schema_registry: SchemaRegistryManager, + event_metrics: EventMetrics, + consumer_config: ConsumerConfig, + test_settings: Settings, +) -> None: producer: UnifiedProducer = await scope.get(UnifiedProducer) idm: IdempotencyManager = await scope.get(IdempotencyManager) - registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) # Future resolves when handler processes an event - no polling needed handled_future: asyncio.Future[None] = asyncio.get_running_loop().create_future() @@ -51,23 +53,17 @@ async def handle(_ev: DomainEvent) -> None: handled_future.set_result(None) # Produce 
messages BEFORE starting consumer (auto_offset_reset="earliest" will read them) - execution_id = f"e-{uuid.uuid4().hex[:8]}" + execution_id = f"e-{consumer_config.group_id}" ev = make_execution_requested_event(execution_id=execution_id) await producer.produce(ev, key=execution_id) await producer.produce(ev, key=execution_id) # Real consumer with idempotent wrapper - cfg = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"test-idem-consumer.{uuid.uuid4().hex[:6]}", - enable_auto_commit=True, - auto_offset_reset="earliest", - ) base = UnifiedConsumer( - cfg, + consumer_config, event_dispatcher=disp, - schema_registry=registry, - settings=settings, + schema_registry=schema_registry, + settings=test_settings, logger=_test_logger, event_metrics=event_metrics, ) diff --git a/backend/tests/integration/idempotency/test_idempotency.py b/backend/tests/integration/idempotency/test_idempotency.py index 032a7f46..69ff00cb 100644 --- a/backend/tests/integration/idempotency/test_idempotency.py +++ b/backend/tests/integration/idempotency/test_idempotency.py @@ -2,7 +2,6 @@ import json import logging import uuid -from collections.abc import AsyncGenerator from datetime import datetime, timedelta, timezone from typing import Any @@ -14,7 +13,6 @@ from app.services.idempotency.idempotency_manager import IdempotencyConfig, IdempotencyManager from app.services.idempotency.middleware import IdempotentEventHandler, idempotent_handler from app.services.idempotency.redis_repository import RedisIdempotencyRepository -from app.settings import Settings from tests.helpers import make_execution_requested_event @@ -27,77 +25,57 @@ class TestIdempotencyManager: """IdempotencyManager backed by real Redis repository (DI-provided client).""" - @pytest.fixture - async def manager(self, redis_client: redis.Redis, test_settings: Settings) -> AsyncGenerator[IdempotencyManager, None]: - prefix = f"idemp_ut:{uuid.uuid4().hex[:6]}" - cfg = IdempotencyConfig( - key_prefix=prefix, - default_ttl_seconds=3600, - processing_timeout_seconds=5, - enable_result_caching=True, - max_result_size_bytes=1024, - enable_metrics=False, - ) - repo = RedisIdempotencyRepository(redis_client, key_prefix=prefix) - database_metrics = DatabaseMetrics(test_settings) - m = IdempotencyManager(cfg, repo, _test_logger, database_metrics=database_metrics) - await m.initialize() - try: - yield m - finally: - await m.close() - @pytest.mark.asyncio - async def test_complete_flow_new_event(self, manager: IdempotencyManager) -> None: + async def test_complete_flow_new_event(self, idempotency_manager: IdempotencyManager) -> None: """Test the complete flow for a new event""" real_event = make_execution_requested_event(execution_id="exec-123") # Check and reserve - result = await manager.check_and_reserve(real_event, key_strategy="event_based") + result = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert result.is_duplicate is False assert result.status == IdempotencyStatus.PROCESSING assert result.key.endswith(f"{real_event.event_type}:{real_event.event_id}") - assert result.key.startswith(f"{manager.config.key_prefix}:") + assert result.key.startswith(f"{idempotency_manager.config.key_prefix}:") # Verify it's in the repository - record = await manager._repo.find_by_key(result.key) + record = await idempotency_manager._repo.find_by_key(result.key) assert record is not None assert record.status == IdempotencyStatus.PROCESSING # Mark as completed - success = await 
manager.mark_completed(real_event, key_strategy="event_based") + success = await idempotency_manager.mark_completed(real_event, key_strategy="event_based") assert success is True # Verify status updated - record = await manager._repo.find_by_key(result.key) + record = await idempotency_manager._repo.find_by_key(result.key) assert record is not None assert record.status == IdempotencyStatus.COMPLETED assert record.completed_at is not None assert record.processing_duration_ms is not None @pytest.mark.asyncio - async def test_duplicate_detection(self, manager: IdempotencyManager) -> None: + async def test_duplicate_detection(self, idempotency_manager: IdempotencyManager) -> None: """Test that duplicates are properly detected""" real_event = make_execution_requested_event(execution_id="exec-dupe-1") # First request - result1 = await manager.check_and_reserve(real_event, key_strategy="event_based") + result1 = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert result1.is_duplicate is False # Mark as completed - await manager.mark_completed(real_event, key_strategy="event_based") + await idempotency_manager.mark_completed(real_event, key_strategy="event_based") # Second request with same event - result2 = await manager.check_and_reserve(real_event, key_strategy="event_based") + result2 = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert result2.is_duplicate is True assert result2.status == IdempotencyStatus.COMPLETED @pytest.mark.asyncio - async def test_concurrent_requests_race_condition(self, manager: IdempotencyManager) -> None: + async def test_concurrent_requests_race_condition(self, idempotency_manager: IdempotencyManager) -> None: """Test handling of concurrent requests for the same event""" real_event = make_execution_requested_event(execution_id="exec-race-1") # Simulate concurrent requests tasks = [ - manager.check_and_reserve(real_event, key_strategy="event_based") + idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") for _ in range(5) ] @@ -112,26 +90,26 @@ async def test_concurrent_requests_race_condition(self, manager: IdempotencyMana assert duplicate_count == 4 @pytest.mark.asyncio - async def test_processing_timeout_allows_retry(self, manager: IdempotencyManager) -> None: + async def test_processing_timeout_allows_retry(self, idempotency_manager: IdempotencyManager) -> None: """Test that stuck processing allows retry after timeout""" real_event = make_execution_requested_event(execution_id="exec-timeout-1") # First request - result1 = await manager.check_and_reserve(real_event, key_strategy="event_based") + result1 = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert result1.is_duplicate is False # Manually update the created_at to simulate old processing - record = await manager._repo.find_by_key(result1.key) + record = await idempotency_manager._repo.find_by_key(result1.key) assert record is not None record.created_at = datetime.now(timezone.utc) - timedelta(seconds=10) - await manager._repo.update_record(record) + await idempotency_manager._repo.update_record(record) # Second request should be allowed due to timeout - result2 = await manager.check_and_reserve(real_event, key_strategy="event_based") + result2 = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert result2.is_duplicate is False # Allowed to retry assert result2.status == IdempotencyStatus.PROCESSING @pytest.mark.asyncio - async def 
test_content_hash_strategy(self, manager: IdempotencyManager) -> None: + async def test_content_hash_strategy(self, idempotency_manager: IdempotencyManager) -> None: """Test content-based deduplication""" # Two events with same content and same execution_id event1 = make_execution_requested_event( @@ -145,46 +123,46 @@ async def test_content_hash_strategy(self, manager: IdempotencyManager) -> None: ) # Use content hash strategy - result1 = await manager.check_and_reserve(event1, key_strategy="content_hash") + result1 = await idempotency_manager.check_and_reserve(event1, key_strategy="content_hash") assert result1.is_duplicate is False - await manager.mark_completed(event1, key_strategy="content_hash") + await idempotency_manager.mark_completed(event1, key_strategy="content_hash") # Second event with same content should be duplicate - result2 = await manager.check_and_reserve(event2, key_strategy="content_hash") + result2 = await idempotency_manager.check_and_reserve(event2, key_strategy="content_hash") assert result2.is_duplicate is True @pytest.mark.asyncio - async def test_failed_event_handling(self, manager: IdempotencyManager) -> None: + async def test_failed_event_handling(self, idempotency_manager: IdempotencyManager) -> None: """Test marking events as failed""" real_event = make_execution_requested_event(execution_id="exec-failed-1") # Reserve - result = await manager.check_and_reserve(real_event, key_strategy="event_based") + result = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert result.is_duplicate is False # Mark as failed error_msg = "Execution failed: out of memory" - success = await manager.mark_failed(real_event, error=error_msg, key_strategy="event_based") + success = await idempotency_manager.mark_failed(real_event, error=error_msg, key_strategy="event_based") assert success is True # Verify status and error - record = await manager._repo.find_by_key(result.key) + record = await idempotency_manager._repo.find_by_key(result.key) assert record is not None assert record.status == IdempotencyStatus.FAILED assert record.error == error_msg assert record.completed_at is not None @pytest.mark.asyncio - async def test_result_caching(self, manager: IdempotencyManager) -> None: + async def test_result_caching(self, idempotency_manager: IdempotencyManager) -> None: """Test caching of results""" real_event = make_execution_requested_event(execution_id="exec-cache-1") # Reserve - result = await manager.check_and_reserve(real_event, key_strategy="event_based") + result = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert result.is_duplicate is False # Complete with cached result cached_result = json.dumps({"output": "Hello, World!", "exit_code": 0}) - success = await manager.mark_completed_with_json( + success = await idempotency_manager.mark_completed_with_json( real_event, cached_json=cached_result, key_strategy="event_based" @@ -192,16 +170,16 @@ async def test_result_caching(self, manager: IdempotencyManager) -> None: assert success is True # Retrieve cached result - retrieved = await manager.get_cached_json(real_event, "event_based", None) + retrieved = await idempotency_manager.get_cached_json(real_event, "event_based", None) assert retrieved == cached_result # Check duplicate with cached result - duplicate_result = await manager.check_and_reserve(real_event, key_strategy="event_based") + duplicate_result = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert 
duplicate_result.is_duplicate is True assert duplicate_result.has_cached_result is True @pytest.mark.asyncio - async def test_stats_aggregation(self, manager: IdempotencyManager) -> None: + async def test_stats_aggregation(self, idempotency_manager: IdempotencyManager) -> None: """Test statistics aggregation""" # Create various events with different statuses events = [] @@ -215,62 +193,49 @@ async def test_stats_aggregation(self, manager: IdempotencyManager) -> None: # Process events with different outcomes for i, event in enumerate(events): - await manager.check_and_reserve(event, key_strategy="event_based") + await idempotency_manager.check_and_reserve(event, key_strategy="event_based") if i < 6: - await manager.mark_completed(event, key_strategy="event_based") + await idempotency_manager.mark_completed(event, key_strategy="event_based") elif i < 8: - await manager.mark_failed(event, "Test error", key_strategy="event_based") + await idempotency_manager.mark_failed(event, "Test error", key_strategy="event_based") # Leave rest in processing # Get stats - stats = await manager.get_stats() + stats = await idempotency_manager.get_stats() assert stats.total_keys == 10 assert stats.status_counts[IdempotencyStatus.COMPLETED] == 6 assert stats.status_counts[IdempotencyStatus.FAILED] == 2 assert stats.status_counts[IdempotencyStatus.PROCESSING] == 2 - assert stats.prefix == manager.config.key_prefix + assert stats.prefix == idempotency_manager.config.key_prefix @pytest.mark.asyncio - async def test_remove_key(self, manager: IdempotencyManager) -> None: + async def test_remove_key(self, idempotency_manager: IdempotencyManager) -> None: """Test removing idempotency keys""" real_event = make_execution_requested_event(execution_id="exec-remove-1") # Add a key - result = await manager.check_and_reserve(real_event, key_strategy="event_based") + result = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert result.is_duplicate is False # Remove it - removed = await manager.remove(real_event, key_strategy="event_based") + removed = await idempotency_manager.remove(real_event, key_strategy="event_based") assert removed is True # Verify it's gone - record = await manager._repo.find_by_key(result.key) + record = await idempotency_manager._repo.find_by_key(result.key) assert record is None # Can process again - result2 = await manager.check_and_reserve(real_event, key_strategy="event_based") + result2 = await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") assert result2.is_duplicate is False class TestIdempotentEventHandlerIntegration: """Test IdempotentEventHandler with real components""" - @pytest.fixture - async def manager(self, redis_client: redis.Redis, test_settings: Settings) -> AsyncGenerator[IdempotencyManager, None]: - prefix = f"handler_test:{uuid.uuid4().hex[:6]}" - config = IdempotencyConfig(key_prefix=prefix, enable_metrics=False) - repo = RedisIdempotencyRepository(redis_client, key_prefix=prefix) - database_metrics = DatabaseMetrics(test_settings) - m = IdempotencyManager(config, repo, _test_logger, database_metrics=database_metrics) - await m.initialize() - try: - yield m - finally: - await m.close() - @pytest.mark.asyncio - async def test_handler_processes_new_event(self, manager: IdempotencyManager) -> None: + async def test_handler_processes_new_event(self, idempotency_manager: IdempotencyManager) -> None: """Test that handler processes new events""" processed_events: list[DomainEvent] = [] @@ -280,7 +245,7 @@ async def 
actual_handler(event: DomainEvent) -> None: # Create idempotent handler handler = IdempotentEventHandler( handler=actual_handler, - idempotency_manager=manager, + idempotency_manager=idempotency_manager, key_strategy="event_based", logger=_test_logger, ) @@ -294,7 +259,7 @@ async def actual_handler(event: DomainEvent) -> None: assert processed_events[0] == real_event @pytest.mark.asyncio - async def test_handler_blocks_duplicate(self, manager: IdempotencyManager) -> None: + async def test_handler_blocks_duplicate(self, idempotency_manager: IdempotencyManager) -> None: """Test that handler blocks duplicate events""" processed_events: list[DomainEvent] = [] @@ -304,7 +269,7 @@ async def actual_handler(event: DomainEvent) -> None: # Create idempotent handler handler = IdempotentEventHandler( handler=actual_handler, - idempotency_manager=manager, + idempotency_manager=idempotency_manager, key_strategy="event_based", logger=_test_logger, ) @@ -318,7 +283,7 @@ async def actual_handler(event: DomainEvent) -> None: assert len(processed_events) == 1 @pytest.mark.asyncio - async def test_handler_with_failure(self, manager: IdempotencyManager) -> None: + async def test_handler_with_failure(self, idempotency_manager: IdempotencyManager) -> None: """Test handler marks failure on exception""" async def failing_handler(event: DomainEvent) -> None: # noqa: ARG001 @@ -326,7 +291,7 @@ async def failing_handler(event: DomainEvent) -> None: # noqa: ARG001 handler = IdempotentEventHandler( handler=failing_handler, - idempotency_manager=manager, + idempotency_manager=idempotency_manager, key_strategy="event_based", logger=_test_logger, ) @@ -337,15 +302,15 @@ async def failing_handler(event: DomainEvent) -> None: # noqa: ARG001 await handler(real_event) # Verify marked as failed - key = f"{manager.config.key_prefix}:{real_event.event_type}:{real_event.event_id}" - record = await manager._repo.find_by_key(key) + key = f"{idempotency_manager.config.key_prefix}:{real_event.event_type}:{real_event.event_id}" + record = await idempotency_manager._repo.find_by_key(key) assert record is not None assert record.status == IdempotencyStatus.FAILED assert record.error is not None assert "Processing failed" in record.error @pytest.mark.asyncio - async def test_handler_duplicate_callback(self, manager: IdempotencyManager) -> None: + async def test_handler_duplicate_callback(self, idempotency_manager: IdempotencyManager) -> None: """Test duplicate callback is invoked""" duplicate_events: list[tuple[DomainEvent, Any]] = [] @@ -357,7 +322,7 @@ async def on_duplicate(event: DomainEvent, result: Any) -> None: handler = IdempotentEventHandler( handler=actual_handler, - idempotency_manager=manager, + idempotency_manager=idempotency_manager, key_strategy="event_based", on_duplicate=on_duplicate, logger=_test_logger, @@ -374,12 +339,12 @@ async def on_duplicate(event: DomainEvent, result: Any) -> None: assert duplicate_events[0][1].is_duplicate is True @pytest.mark.asyncio - async def test_decorator_integration(self, manager: IdempotencyManager) -> None: + async def test_decorator_integration(self, idempotency_manager: IdempotencyManager) -> None: """Test the @idempotent_handler decorator""" processed_events: list[DomainEvent] = [] @idempotent_handler( - idempotency_manager=manager, + idempotency_manager=idempotency_manager, key_strategy="content_hash", ttl_seconds=300, logger=_test_logger, @@ -406,7 +371,7 @@ async def my_handler(event: DomainEvent) -> None: assert len(processed_events) == 1 # Still only one @pytest.mark.asyncio - 
async def test_custom_key_function(self, manager: IdempotencyManager) -> None: + async def test_custom_key_function(self, idempotency_manager: IdempotencyManager) -> None: """Test handler with custom key function""" processed_scripts: list[str] = [] @@ -421,7 +386,7 @@ def extract_script_key(event: DomainEvent) -> str: handler = IdempotentEventHandler( handler=process_script, - idempotency_manager=manager, + idempotency_manager=idempotency_manager, key_strategy="custom", custom_key_func=extract_script_key, logger=_test_logger, @@ -457,45 +422,45 @@ def extract_script_key(event: DomainEvent) -> str: assert processed_scripts[0] == "print('hello')" @pytest.mark.asyncio - async def test_invalid_key_strategy(self, manager: IdempotencyManager) -> None: + async def test_invalid_key_strategy(self, idempotency_manager: IdempotencyManager) -> None: """Test that invalid key strategy raises error""" real_event = make_execution_requested_event(execution_id="invalid-strategy-1") with pytest.raises(ValueError, match="Invalid key strategy"): - await manager.check_and_reserve(real_event, key_strategy="invalid_strategy") + await idempotency_manager.check_and_reserve(real_event, key_strategy="invalid_strategy") @pytest.mark.asyncio - async def test_custom_key_without_custom_key_param(self, manager: IdempotencyManager) -> None: + async def test_custom_key_without_custom_key_param(self, idempotency_manager: IdempotencyManager) -> None: """Test that custom strategy without custom_key raises error""" real_event = make_execution_requested_event(execution_id="custom-key-missing-1") with pytest.raises(ValueError, match="Invalid key strategy"): - await manager.check_and_reserve(real_event, key_strategy="custom") + await idempotency_manager.check_and_reserve(real_event, key_strategy="custom") @pytest.mark.asyncio - async def test_get_cached_json_existing(self, manager: IdempotencyManager) -> None: + async def test_get_cached_json_existing(self, idempotency_manager: IdempotencyManager) -> None: """Test retrieving cached JSON result""" # First complete with cached result real_event = make_execution_requested_event(execution_id="cache-exist-1") - await manager.check_and_reserve(real_event, key_strategy="event_based") + await idempotency_manager.check_and_reserve(real_event, key_strategy="event_based") cached_data = json.dumps({"output": "test", "code": 0}) - await manager.mark_completed_with_json(real_event, cached_data, "event_based") + await idempotency_manager.mark_completed_with_json(real_event, cached_data, "event_based") # Retrieve cached result - retrieved = await manager.get_cached_json(real_event, "event_based", None) + retrieved = await idempotency_manager.get_cached_json(real_event, "event_based", None) assert retrieved == cached_data @pytest.mark.asyncio - async def test_get_cached_json_non_existing(self, manager: IdempotencyManager) -> None: + async def test_get_cached_json_non_existing(self, idempotency_manager: IdempotencyManager) -> None: """Test retrieving non-existing cached result raises assertion""" real_event = make_execution_requested_event(execution_id="cache-miss-1") # Trying to get cached result for non-existent key should raise with pytest.raises(AssertionError, match="cached result must exist"): - await manager.get_cached_json(real_event, "event_based", None) + await idempotency_manager.get_cached_json(real_event, "event_based", None) @pytest.mark.asyncio - async def test_cleanup_expired_keys(self, manager: IdempotencyManager) -> None: + async def test_cleanup_expired_keys(self, 
idempotency_manager: IdempotencyManager) -> None: """Test cleanup of expired keys""" # Create expired record - expired_key = f"{manager.config.key_prefix}:expired" + expired_key = f"{idempotency_manager.config.key_prefix}:expired" expired_record = IdempotencyRecord( key=expired_key, status=IdempotencyStatus.COMPLETED, @@ -505,19 +470,18 @@ async def test_cleanup_expired_keys(self, manager: IdempotencyManager) -> None: ttl_seconds=3600, # 1 hour TTL completed_at=datetime.now(timezone.utc) - timedelta(hours=2) ) - await manager._repo.insert_processing(expired_record) + await idempotency_manager._repo.insert_processing(expired_record) # Cleanup should detect it as expired # Note: actual cleanup implementation depends on repository - record = await manager._repo.find_by_key(expired_key) + record = await idempotency_manager._repo.find_by_key(expired_key) assert record is not None # Still exists until explicit cleanup @pytest.mark.asyncio - async def test_metrics_enabled(self, redis_client: redis.Redis, test_settings: Settings) -> None: + async def test_metrics_enabled(self, redis_client: redis.Redis, database_metrics: DatabaseMetrics) -> None: """Test manager with metrics enabled""" config = IdempotencyConfig(key_prefix=f"metrics:{uuid.uuid4().hex[:6]}", enable_metrics=True) repository = RedisIdempotencyRepository(redis_client, key_prefix=config.key_prefix) - database_metrics = DatabaseMetrics(test_settings) manager = IdempotencyManager(config, repository, _test_logger, database_metrics=database_metrics) # Initialize with metrics @@ -528,7 +492,7 @@ async def test_metrics_enabled(self, redis_client: redis.Redis, test_settings: S await manager.close() @pytest.mark.asyncio - async def test_content_hash_with_fields(self, manager: IdempotencyManager) -> None: + async def test_content_hash_with_fields(self, idempotency_manager: IdempotencyManager) -> None: """Test content hash with specific fields""" event1 = make_execution_requested_event( execution_id="exec-1", @@ -537,13 +501,13 @@ async def test_content_hash_with_fields(self, manager: IdempotencyManager) -> No # Use content hash with only script field fields = {"script", "language"} - result1 = await manager.check_and_reserve( + result1 = await idempotency_manager.check_and_reserve( event1, key_strategy="content_hash", fields=fields ) assert result1.is_duplicate is False - await manager.mark_completed(event1, key_strategy="content_hash", fields=fields) + await idempotency_manager.mark_completed(event1, key_strategy="content_hash", fields=fields) # Event with same script and language but different other fields event2 = make_execution_requested_event( @@ -556,7 +520,7 @@ async def test_content_hash_with_fields(self, manager: IdempotencyManager) -> No service_name="test-service", ) - result2 = await manager.check_and_reserve( + result2 = await idempotency_manager.check_and_reserve( event2, key_strategy="content_hash", fields=fields diff --git a/backend/tests/integration/result_processor/test_result_processor.py b/backend/tests/integration/result_processor/test_result_processor.py index de2546d6..2e62554f 100644 --- a/backend/tests/integration/result_processor/test_result_processor.py +++ b/backend/tests/integration/result_processor/test_result_processor.py @@ -1,6 +1,5 @@ import asyncio import logging -import uuid import pytest from app.core.database_context import Database @@ -12,9 +11,8 @@ from app.domain.events.typed import EventMetadata, ExecutionCompletedEvent, ResultStoredEvent from app.domain.execution import DomainExecutionCreate from 
app.domain.execution.models import ResourceUsageDomain -from app.events.core import UnifiedConsumer, UnifiedProducer +from app.events.core import ConsumerConfig, UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher -from app.events.core.types import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.idempotency import IdempotencyManager from app.services.result_processor.processor import ResultProcessor @@ -34,13 +32,16 @@ @pytest.mark.asyncio -async def test_result_processor_persists_and_emits(scope: AsyncContainer) -> None: +async def test_result_processor_persists_and_emits( + scope: AsyncContainer, + schema_registry: SchemaRegistryManager, + event_metrics: EventMetrics, + consumer_config: ConsumerConfig, + test_settings: Settings, +) -> None: # Ensure schemas - registry: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - settings: Settings = await scope.get(Settings) - event_metrics: EventMetrics = await scope.get(EventMetrics) execution_metrics: ExecutionMetrics = await scope.get(ExecutionMetrics) - await initialize_event_schemas(registry) + await initialize_event_schemas(schema_registry) # Dependencies db: Database = await scope.get(Database) @@ -62,8 +63,8 @@ async def test_result_processor_persists_and_emits(scope: AsyncContainer) -> Non processor = ResultProcessor( execution_repo=repo, producer=producer, - schema_registry=registry, - settings=settings, + schema_registry=schema_registry, + settings=test_settings, idempotency_manager=idem, logger=_test_logger, execution_metrics=execution_metrics, @@ -79,18 +80,11 @@ async def _stored(event: ResultStoredEvent) -> None: if event.execution_id == execution_id: stored_received.set() - group_id = f"rp-test.{uuid.uuid4().hex[:6]}" - cconf = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=group_id, - enable_auto_commit=True, - auto_offset_reset="earliest", - ) stored_consumer = UnifiedConsumer( - cconf, + consumer_config, dispatcher, - schema_registry=registry, - settings=settings, + schema_registry=schema_registry, + settings=test_settings, logger=_test_logger, event_metrics=event_metrics, ) diff --git a/backend/tests/integration/services/sse/test_partitioned_event_router.py b/backend/tests/integration/services/sse/test_partitioned_event_router.py index f31391c1..af620341 100644 --- a/backend/tests/integration/services/sse/test_partitioned_event_router.py +++ b/backend/tests/integration/services/sse/test_partitioned_event_router.py @@ -3,7 +3,6 @@ from uuid import uuid4 import pytest -import redis.asyncio as redis from app.core.metrics import EventMetrics from app.events.core import EventDispatcher from app.events.schema.schema_registry import SchemaRegistryManager @@ -20,19 +19,17 @@ @pytest.mark.asyncio -async def test_router_bridges_to_redis(redis_client: redis.Redis, test_settings: Settings) -> None: - suffix = uuid4().hex[:6] - bus = SSERedisBus( - redis_client, - exec_prefix=f"sse:exec:{suffix}:", - notif_prefix=f"sse:notif:{suffix}:", - logger=_test_logger, - ) +async def test_router_bridges_to_redis( + sse_redis_bus: SSERedisBus, + schema_registry: SchemaRegistryManager, + event_metrics: EventMetrics, + test_settings: Settings, +) -> None: router = SSEKafkaRedisBridge( - schema_registry=SchemaRegistryManager(settings=test_settings, logger=_test_logger), + schema_registry=schema_registry, settings=test_settings, - event_metrics=EventMetrics(test_settings), - sse_bus=bus, + 
event_metrics=event_metrics, + sse_bus=sse_redis_bus, logger=_test_logger, ) disp = EventDispatcher(logger=_test_logger) @@ -40,7 +37,7 @@ async def test_router_bridges_to_redis(redis_client: redis.Redis, test_settings: # Open Redis subscription for our execution id execution_id = f"e-{uuid4().hex[:8]}" - subscription = await bus.open_subscription(execution_id) + subscription = await sse_redis_bus.open_subscription(execution_id) ev = make_execution_requested_event(execution_id=execution_id) handler = disp.get_handlers(ev.event_type)[0] @@ -53,29 +50,22 @@ async def test_router_bridges_to_redis(redis_client: redis.Redis, test_settings: @pytest.mark.asyncio -async def test_router_start_and_stop(redis_client: redis.Redis, test_settings: Settings) -> None: +async def test_router_start_and_stop( + sse_redis_bus: SSERedisBus, + schema_registry: SchemaRegistryManager, + event_metrics: EventMetrics, + test_settings: Settings, +) -> None: test_settings.SSE_CONSUMER_POOL_SIZE = 1 - suffix = uuid4().hex[:6] router = SSEKafkaRedisBridge( - schema_registry=SchemaRegistryManager(settings=test_settings, logger=_test_logger), + schema_registry=schema_registry, settings=test_settings, - event_metrics=EventMetrics(test_settings), - sse_bus=SSERedisBus( - redis_client, - exec_prefix=f"sse:exec:{suffix}:", - notif_prefix=f"sse:notif:{suffix}:", - logger=_test_logger, - ), + event_metrics=event_metrics, + sse_bus=sse_redis_bus, logger=_test_logger, ) - await router.__aenter__() - stats = router.get_stats() - assert stats["num_consumers"] == 1 - await router.__aexit__(None, None, None) + async with router: + assert router.get_stats()["num_consumers"] == 1 + assert router.get_stats()["num_consumers"] == 0 - # idempotent start/stop - await router.__aenter__() - await router.__aenter__() - await router.__aexit__(None, None, None) - await router.__aexit__(None, None, None) diff --git a/backend/tests/load/plot_report.py b/backend/tests/load/plot_report.py index b415e15e..86cb0667 100644 --- a/backend/tests/load/plot_report.py +++ b/backend/tests/load/plot_report.py @@ -114,7 +114,7 @@ def plot_endpoint_throughput(report: ReportDict, out_dir: Path, top_n: int = 10) labels = [k for k, _ in data] total = [v.get("count", 0) for _, v in data] errors = [v.get("errors", 0) for _, v in data] - successes = [t - e for t, e in zip(total, errors)] + successes = [t - e for t, e in zip(total, errors, strict=True)] x = range(len(labels)) width = 0.45 diff --git a/backend/tests/unit/conftest.py b/backend/tests/unit/conftest.py index 65b28839..a81a26a2 100644 --- a/backend/tests/unit/conftest.py +++ b/backend/tests/unit/conftest.py @@ -1,3 +1,4 @@ +import logging from typing import NoReturn import pytest @@ -15,8 +16,12 @@ ReplayMetrics, SecurityMetrics, ) +from app.events.schema.schema_registry import SchemaRegistryManager +from app.services.pod_monitor.config import PodMonitorConfig from app.settings import Settings +_test_logger = logging.getLogger("test.unit") + # Metrics fixtures - provided via DI, not global context @pytest.fixture @@ -97,3 +102,15 @@ def client() -> NoReturn: @pytest.fixture def app() -> NoReturn: raise RuntimeError("Unit tests should not use full app - use mocks or move to integration/") + + +# Config fixtures - fresh instance per test (can be customized by tests) +@pytest.fixture +def pod_monitor_config() -> PodMonitorConfig: + return PodMonitorConfig() + + +@pytest.fixture +def schema_registry(test_settings: Settings) -> SchemaRegistryManager: + """Provide SchemaRegistryManager for unit tests (no external 
connections).""" + return SchemaRegistryManager(test_settings, logger=_test_logger) diff --git a/backend/tests/unit/events/test_schema_registry_manager.py b/backend/tests/unit/events/test_schema_registry_manager.py index 6819237a..f7ac7cb1 100644 --- a/backend/tests/unit/events/test_schema_registry_manager.py +++ b/backend/tests/unit/events/test_schema_registry_manager.py @@ -1,15 +1,9 @@ -import logging - import pytest from app.domain.events.typed import ExecutionRequestedEvent from app.events.schema.schema_registry import SchemaRegistryManager -from app.settings import Settings - -_test_logger = logging.getLogger("test.events.schema_registry_manager") -def test_deserialize_json_execution_requested(test_settings: Settings) -> None: - m = SchemaRegistryManager(test_settings, logger=_test_logger) +def test_deserialize_json_execution_requested(schema_registry: SchemaRegistryManager) -> None: data = { "event_type": "execution_requested", "execution_id": "e1", @@ -27,13 +21,12 @@ def test_deserialize_json_execution_requested(test_settings: Settings) -> None: "priority": 5, "metadata": {"service_name": "t", "service_version": "1.0"}, } - ev = m.deserialize_json(data) + ev = schema_registry.deserialize_json(data) assert isinstance(ev, ExecutionRequestedEvent) assert ev.execution_id == "e1" assert ev.language == "python" -def test_deserialize_json_missing_type_raises(test_settings: Settings) -> None: - m = SchemaRegistryManager(test_settings, logger=_test_logger) +def test_deserialize_json_missing_type_raises(schema_registry: SchemaRegistryManager) -> None: with pytest.raises(ValueError): - m.deserialize_json({}) + schema_registry.deserialize_json({}) diff --git a/backend/tests/unit/services/coordinator/test_resource_manager.py b/backend/tests/unit/services/coordinator/test_resource_manager.py index 3624dae6..4f579e45 100644 --- a/backend/tests/unit/services/coordinator/test_resource_manager.py +++ b/backend/tests/unit/services/coordinator/test_resource_manager.py @@ -9,7 +9,10 @@ @pytest.mark.asyncio async def test_request_allocation_defaults_and_limits(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager(total_cpu_cores=8.0, total_memory_mb=16384, total_gpu_count=0, logger=_test_logger, coordinator_metrics=coordinator_metrics) + rm = ResourceManager( + total_cpu_cores=8.0, total_memory_mb=16384, total_gpu_count=0, + logger=_test_logger, coordinator_metrics=coordinator_metrics + ) # Default for python alloc = await rm.request_allocation("e1", "python") @@ -27,7 +30,10 @@ async def test_request_allocation_defaults_and_limits(coordinator_metrics: Coord @pytest.mark.asyncio async def test_release_and_can_allocate(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager(total_cpu_cores=4.0, total_memory_mb=8192, total_gpu_count=0, logger=_test_logger, coordinator_metrics=coordinator_metrics) + rm = ResourceManager( + total_cpu_cores=4.0, total_memory_mb=8192, total_gpu_count=0, + logger=_test_logger, coordinator_metrics=coordinator_metrics + ) a = await rm.request_allocation("e1", "python", requested_cpu=1.0, requested_memory_mb=512) assert a is not None @@ -47,7 +53,10 @@ async def test_release_and_can_allocate(coordinator_metrics: CoordinatorMetrics) @pytest.mark.asyncio async def test_resource_stats(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager(total_cpu_cores=2.0, total_memory_mb=4096, total_gpu_count=0, logger=_test_logger, coordinator_metrics=coordinator_metrics) + rm = ResourceManager( + total_cpu_cores=2.0, total_memory_mb=4096, 
total_gpu_count=0, + logger=_test_logger, coordinator_metrics=coordinator_metrics + ) # Make sure the allocation succeeds alloc = await rm.request_allocation("e1", "python", requested_cpu=0.5, requested_memory_mb=256) assert alloc is not None, "Allocation should have succeeded" diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index 691b4e6f..d775fd94 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -69,11 +69,12 @@ async def aclose(self) -> None: pass -def create_test_kafka_event_service(event_metrics: EventMetrics) -> tuple[KafkaEventService, FakeUnifiedProducer]: +def create_test_kafka_event_service( + event_metrics: EventMetrics, settings: Settings +) -> tuple[KafkaEventService, FakeUnifiedProducer]: """Create real KafkaEventService with fake dependencies for testing.""" fake_producer = FakeUnifiedProducer() fake_repo = FakeEventRepository() - settings = Settings() # Uses defaults/env vars service = KafkaEventService( event_repository=fake_repo, @@ -120,6 +121,7 @@ def make_k8s_clients_di( def make_pod_monitor( event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, + settings: Settings, config: PodMonitorConfig | None = None, kafka_service: KafkaEventService | None = None, k8s_clients: K8sClients | None = None, @@ -129,7 +131,7 @@ def make_pod_monitor( cfg = config or PodMonitorConfig() clients = k8s_clients or make_k8s_clients_di() mapper = event_mapper or PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")) - service = kafka_service or create_test_kafka_event_service(event_metrics)[0] + service = kafka_service or create_test_kafka_event_service(event_metrics, settings)[0] return PodMonitor( config=cfg, kafka_event_service=service, @@ -144,12 +146,14 @@ def make_pod_monitor( @pytest.mark.asyncio -async def test_start_and_stop_lifecycle(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False +async def test_start_and_stop_lifecycle( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + pod_monitor_config.enable_state_reconciliation = False spy = SpyMapper() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=spy) # type: ignore[arg-type] + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, event_mapper=spy) # type: ignore[arg-type] # Replace _watch_loop to avoid real watch loop async def _quick_watch() -> None: @@ -157,22 +161,25 @@ async def _quick_watch() -> None: pm._watch_loop = _quick_watch # type: ignore[method-assign] - await pm.__aenter__() - assert pm._watch_task is not None + async with pm: + assert pm._watch_task is not None - await pm.__aexit__(None, None, None) assert spy.cleared is True @pytest.mark.asyncio -async def test_run_watch_flow_and_publish(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False +async def test_run_watch_flow_and_publish( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + pod_monitor_config.enable_state_reconciliation = False pod = make_pod(name="p", phase="Succeeded", labels={"execution-id": "e1"}, term_exit=0, 
resource_version="rv1") k8s_clients = make_k8s_clients_di(events=[{"type": "MODIFIED", "object": pod}], resource_version="rv2") - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, k8s_clients=k8s_clients + ) await pm._run_watch() assert pm._last_resource_version == "rv2" @@ -180,10 +187,10 @@ async def test_run_watch_flow_and_publish(event_metrics: EventMetrics, kubernete @pytest.mark.asyncio async def test_process_raw_event_invalid_and_backoff( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, ) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) await pm._process_raw_event({}) @@ -195,13 +202,15 @@ async def test_process_raw_event_invalid_and_backoff( @pytest.mark.asyncio -async def test_get_status(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.namespace = "test-ns" - cfg.label_selector = "app=test" - cfg.enable_state_reconciliation = True +async def test_get_status( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + pod_monitor_config.namespace = "test-ns" + pod_monitor_config.label_selector = "app=test" + pod_monitor_config.enable_state_reconciliation = True - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) pm._tracked_pods = {"pod1", "pod2"} pm._reconnect_attempts = 3 pm._last_resource_version = "v123" @@ -216,16 +225,20 @@ async def test_get_status(event_metrics: EventMetrics, kubernetes_metrics: Kuber @pytest.mark.asyncio -async def test_reconcile_success(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.namespace = "test" - cfg.label_selector = "app=test" +async def test_reconcile_success( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + pod_monitor_config.namespace = "test" + pod_monitor_config.label_selector = "app=test" pod1 = make_pod(name="pod1", phase="Running", resource_version="v1") pod2 = make_pod(name="pod2", phase="Running", resource_version="v1") k8s_clients = make_k8s_clients_di(pods=[pod1, pod2]) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, k8s_clients=k8s_clients + ) pm._tracked_pods = {"pod2", "pod3"} processed: list[str] = [] @@ -244,9 +257,10 @@ async def mock_process(event: PodEvent) -> None: @pytest.mark.asyncio -async def test_reconcile_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - +async def test_reconcile_exception( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: class FailV1(FakeV1Api): def list_namespaced_pod(self, namespace: str, label_selector: str) -> 
Any: raise RuntimeError("API error") @@ -260,16 +274,20 @@ def list_namespaced_pod(self, namespace: str, label_selector: str) -> Any: watch=make_watch([]), ) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, k8s_clients=k8s_clients + ) # Should not raise - errors are caught and logged await pm._reconcile() @pytest.mark.asyncio -async def test_process_pod_event_full_flow(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.ignored_pod_phases = ["Unknown"] +async def test_process_pod_event_full_flow( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + pod_monitor_config.ignored_pod_phases = ["Unknown"] class MockMapper: def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # noqa: ARG002 @@ -283,7 +301,11 @@ class Event: def clear_cache(self) -> None: pass - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=MockMapper()) # type: ignore[arg-type] + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, test_settings, + config=pod_monitor_config, + event_mapper=MockMapper(), # type: ignore[arg-type] + ) published: list[Any] = [] @@ -326,10 +348,9 @@ async def mock_publish(event: Any, pod: Any) -> None: # noqa: ARG001 @pytest.mark.asyncio async def test_process_pod_event_exception_handling( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, ) -> None: - cfg = PodMonitorConfig() - class FailMapper: def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: raise RuntimeError("Mapping failed") @@ -337,7 +358,11 @@ def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: def clear_cache(self) -> None: pass - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, event_mapper=FailMapper()) # type: ignore[arg-type] + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, test_settings, + config=pod_monitor_config, + event_mapper=FailMapper(), # type: ignore[arg-type] + ) event = PodEvent( event_type=WatchEventType.ADDED, @@ -350,10 +375,14 @@ def clear_cache(self) -> None: @pytest.mark.asyncio -async def test_publish_event_full_flow(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - service, fake_producer = create_test_kafka_event_service(event_metrics) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, kafka_service=service) +async def test_publish_event_full_flow( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + service, fake_producer = create_test_kafka_event_service(event_metrics, test_settings) + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, kafka_service=service + ) event = ExecutionCompletedEvent( execution_id="exec1", @@ -372,10 +401,9 @@ async def test_publish_event_full_flow(event_metrics: EventMetrics, kubernetes_m @pytest.mark.asyncio async def test_publish_event_exception_handling( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics + event_metrics: EventMetrics, kubernetes_metrics: 
KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, ) -> None: - cfg = PodMonitorConfig() - class FailingProducer(FakeUnifiedProducer): async def produce( self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None @@ -388,12 +416,14 @@ async def produce( failing_service = KafkaEventService( event_repository=fake_repo, kafka_producer=failing_producer, - settings=Settings(), + settings=test_settings, logger=_test_logger, event_metrics=event_metrics, ) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, kafka_service=failing_service) + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, kafka_service=failing_service + ) event = ExecutionStartedEvent( execution_id="exec1", @@ -410,11 +440,13 @@ async def produce( @pytest.mark.asyncio -async def test_backoff_max_attempts(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.max_reconnect_attempts = 2 +async def test_backoff_max_attempts( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + pod_monitor_config.max_reconnect_attempts = 2 - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) pm._reconnect_attempts = 2 with pytest.raises(RuntimeError, match="Max reconnect attempts exceeded"): @@ -422,10 +454,12 @@ async def test_backoff_max_attempts(event_metrics: EventMetrics, kubernetes_metr @pytest.mark.asyncio -async def test_watch_loop_with_cancellation(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) +async def test_watch_loop_with_cancellation( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + pod_monitor_config.enable_state_reconciliation = False + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) watch_count: list[int] = [] @@ -443,10 +477,12 @@ async def mock_run_watch() -> None: @pytest.mark.asyncio -async def test_watch_loop_api_exception_410(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) +async def test_watch_loop_api_exception_410( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + pod_monitor_config.enable_state_reconciliation = False + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) pm._last_resource_version = "v123" call_count = 0 @@ -472,10 +508,12 @@ async def mock_backoff() -> None: @pytest.mark.asyncio -async def test_watch_loop_generic_exception(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) +async def test_watch_loop_generic_exception( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + 
pod_monitor_config: PodMonitorConfig, +) -> None: + pod_monitor_config.enable_state_reconciliation = False + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) call_count = 0 backoff_count = 0 @@ -502,11 +540,11 @@ async def mock_backoff() -> None: @pytest.mark.asyncio async def test_pod_monitor_context_manager_lifecycle( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, ) -> None: """Test PodMonitor lifecycle via async context manager.""" - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False + pod_monitor_config.enable_state_reconciliation = False mock_v1 = FakeV1Api() mock_watch = make_watch([]) @@ -518,11 +556,11 @@ async def test_pod_monitor_context_manager_lifecycle( watch=mock_watch, ) - service, _ = create_test_kafka_event_service(event_metrics) + service, _ = create_test_kafka_event_service(event_metrics, test_settings) event_mapper = PodEventMapper(logger=_test_logger, k8s_api=mock_v1) monitor = PodMonitor( - config=cfg, + config=pod_monitor_config, kafka_event_service=service, logger=_test_logger, k8s_clients=mock_k8s_clients, @@ -537,25 +575,30 @@ async def test_pod_monitor_context_manager_lifecycle( @pytest.mark.asyncio -async def test_stop_with_tasks(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - """Test cleanup of tasks on __aexit__.""" - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) +async def test_stop_with_tasks( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + """Test cleanup of tasks on context exit.""" + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - async def dummy_task() -> None: - await asyncio.Event().wait() + # Replace _watch_loop to avoid real watch and add tracked pods + async def _quick_watch() -> None: + pm._tracked_pods = {"pod1"} - pm._watch_task = asyncio.create_task(dummy_task()) - pm._tracked_pods = {"pod1"} + pm._watch_loop = _quick_watch # type: ignore[method-assign] - await pm.__aexit__(None, None, None) + async with pm: + assert pm._watch_task is not None assert len(pm._tracked_pods) == 0 -def test_update_resource_version(event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) +def test_update_resource_version( + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, +) -> None: + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) class Stream: _stop_event = types.SimpleNamespace(resource_version="v123") @@ -571,10 +614,10 @@ class BadStream: @pytest.mark.asyncio async def test_process_raw_event_with_metadata( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, ) -> None: - cfg = PodMonitorConfig() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) processed: list[PodEvent] = [] @@ -601,11 +644,11 @@ 
async def mock_process(event: PodEvent) -> None: @pytest.mark.asyncio async def test_watch_loop_api_exception_other_status( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, ) -> None: - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = False - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) + pod_monitor_config.enable_state_reconciliation = False + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) call_count = 0 backoff_count = 0 @@ -632,11 +675,11 @@ async def mock_backoff() -> None: @pytest.mark.asyncio async def test_run_watch_with_field_selector( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, ) -> None: - cfg = PodMonitorConfig() - cfg.field_selector = "status.phase=Running" - cfg.enable_state_reconciliation = False + pod_monitor_config.field_selector = "status.phase=Running" + pod_monitor_config.enable_state_reconciliation = False watch_kwargs: list[dict[str, Any]] = [] @@ -658,7 +701,9 @@ def stream(self, func: Any, **kwargs: Any) -> FakeWatchStream: watch=TrackingWatch([], "rv1"), ) - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg, k8s_clients=k8s_clients) + pm = make_pod_monitor( + event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, k8s_clients=k8s_clients + ) await pm._run_watch() @@ -667,12 +712,12 @@ def stream(self, func: Any, **kwargs: Any) -> FakeWatchStream: @pytest.mark.asyncio async def test_watch_loop_with_reconciliation( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics + event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + pod_monitor_config: PodMonitorConfig, ) -> None: """Test that reconciliation is called before each watch restart.""" - cfg = PodMonitorConfig() - cfg.enable_state_reconciliation = True - pm = make_pod_monitor(event_metrics, kubernetes_metrics, config=cfg) + pod_monitor_config.enable_state_reconciliation = True + pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) reconcile_count = 0 watch_count = 0 From ff45a521cea0819e7f2042850c2b7252f2db84d7 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Tue, 20 Jan 2026 10:40:54 +0100 Subject: [PATCH 07/21] Di usage instead of stateful services --- backend/app/core/container.py | 2 + backend/app/core/dishka_lifespan.py | 50 +- backend/app/core/providers.py | 482 +++++++++++++----- backend/app/events/core/consumer.py | 251 +++------ backend/app/events/event_store_consumer.py | 168 +++--- backend/app/services/coordinator/__init__.py | 4 +- .../{coordinator.py => coordinator_logic.py} | 185 ++----- backend/app/services/idempotency/__init__.py | 3 +- .../app/services/idempotency/middleware.py | 255 +++------ backend/app/services/k8s_worker/__init__.py | 4 +- .../k8s_worker/{worker.py => worker_logic.py} | 234 +++------ backend/app/services/notification_service.py | 273 ++-------- backend/app/services/pod_monitor/monitor.py | 105 ++-- .../app/services/result_processor/__init__.py | 5 +- .../{processor.py => processor_logic.py} | 154 +----- backend/app/services/saga/__init__.py | 4 +- .../{saga_orchestrator.py => saga_logic.py} | 183 ++----- backend/app/services/saga/saga_service.py | 
16 +- backend/app/services/sse/event_router.py | 63 +++ .../app/services/sse/kafka_redis_bridge.py | 134 ----- .../services/sse/sse_connection_registry.py | 57 +++ backend/app/services/sse/sse_service.py | 64 +-- .../app/services/sse/sse_shutdown_manager.py | 300 ----------- .../tests/e2e/test_k8s_worker_create_pod.py | 43 +- .../events/test_consume_roundtrip.py | 9 +- .../events/test_consumer_lifecycle.py | 36 +- .../events/test_event_dispatcher.py | 9 +- .../idempotency/test_consumer_idempotent.py | 11 +- .../idempotency/test_decorator_idempotent.py | 52 -- .../idempotency/test_idempotency.py | 34 +- .../result_processor/test_result_processor.py | 107 +++- .../coordinator/test_execution_coordinator.py | 8 +- .../sse/test_partitioned_event_router.py | 50 +- .../services/idempotency/test_middleware.py | 1 - .../unit/services/pod_monitor/test_monitor.py | 101 ++-- .../result_processor/test_processor.py | 38 +- .../saga/test_saga_orchestrator_unit.py | 52 +- .../services/sse/test_kafka_redis_bridge.py | 32 +- .../services/sse/test_shutdown_manager.py | 97 ---- .../sse/test_sse_connection_registry.py | 76 +++ .../unit/services/sse/test_sse_service.py | 74 +-- .../services/sse/test_sse_shutdown_manager.py | 87 ---- backend/workers/run_coordinator.py | 38 +- backend/workers/run_k8s_worker.py | 40 +- backend/workers/run_pod_monitor.py | 28 +- backend/workers/run_result_processor.py | 61 +-- backend/workers/run_saga_orchestrator.py | 42 +- 47 files changed, 1518 insertions(+), 2604 deletions(-) rename backend/app/services/coordinator/{coordinator.py => coordinator_logic.py} (70%) rename backend/app/services/k8s_worker/{worker.py => worker_logic.py} (70%) rename backend/app/services/result_processor/{processor.py => processor_logic.py} (58%) rename backend/app/services/saga/{saga_orchestrator.py => saga_logic.py} (76%) create mode 100644 backend/app/services/sse/event_router.py delete mode 100644 backend/app/services/sse/kafka_redis_bridge.py create mode 100644 backend/app/services/sse/sse_connection_registry.py delete mode 100644 backend/app/services/sse/sse_shutdown_manager.py delete mode 100644 backend/tests/integration/idempotency/test_decorator_idempotent.py delete mode 100644 backend/tests/unit/services/sse/test_shutdown_manager.py create mode 100644 backend/tests/unit/services/sse/test_sse_connection_registry.py delete mode 100644 backend/tests/unit/services/sse/test_sse_shutdown_manager.py diff --git a/backend/app/core/container.py b/backend/app/core/container.py index b67f133a..f45c2033 100644 --- a/backend/app/core/container.py +++ b/backend/app/core/container.py @@ -19,6 +19,7 @@ PodMonitorProvider, RedisProvider, RepositoryProvider, + ResultProcessorProvider, SagaOrchestratorProvider, SettingsProvider, SSEProvider, @@ -73,6 +74,7 @@ def create_result_processor_container(settings: Settings) -> AsyncContainer: RepositoryProvider(), EventProvider(), MessagingProvider(), + ResultProcessorProvider(), context={Settings: settings}, ) diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index 956c5ebd..23dbcdf5 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -1,6 +1,6 @@ import asyncio import logging -from contextlib import AsyncExitStack, asynccontextmanager +from contextlib import asynccontextmanager from typing import AsyncGenerator import redis.asyncio as redis @@ -13,9 +13,10 @@ from app.core.startup import initialize_rate_limits from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS 
+from app.events.core import UnifiedConsumer from app.events.event_store_consumer import EventStoreConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge +from app.services.notification_service import NotificationService from app.settings import Settings @@ -24,10 +25,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """ Application lifespan with dishka dependency injection. - This is much cleaner than the old lifespan.py: - - No dependency_overrides - - No manual service management - - Dishka handles all lifecycle automatically + Services are already initialized by their DI providers (which handle __aenter__/__aexit__). + Lifespan just starts the run() methods as background tasks. """ # Get settings and logger from DI container (uses test settings in tests) container: AsyncContainer = app.state.dishka_container @@ -79,15 +78,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: database, redis_client, rate_limit_metrics, - sse_bridge, + sse_consumers, event_store_consumer, + notification_service, ) = await asyncio.gather( container.get(SchemaRegistryManager), container.get(Database), container.get(redis.Redis), container.get(RateLimitMetrics), - container.get(SSEKafkaRedisBridge), + container.get(list[UnifiedConsumer]), container.get(EventStoreConsumer), + container.get(NotificationService), ) # Phase 2: Initialize infrastructure in parallel (independent subsystems) @@ -98,11 +99,30 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: ) logger.info("Infrastructure initialized (schemas, beanie, rate limits)") - # Phase 3: Start Kafka consumers in parallel - async with AsyncExitStack() as stack: - await asyncio.gather( - stack.enter_async_context(sse_bridge), - stack.enter_async_context(event_store_consumer), - ) - logger.info("SSE bridge and EventStoreConsumer started") + # Phase 3: Start run() methods as background tasks + # Note: Services are already initialized by their DI providers (which handle __aenter__/__aexit__) + + async def run_sse_consumers() -> None: + """Run SSE consumers using TaskGroup.""" + async with asyncio.TaskGroup() as tg: + for consumer in sse_consumers: + tg.create_task(consumer.run()) + + tasks = [ + asyncio.create_task(run_sse_consumers(), name="sse_consumers"), + asyncio.create_task(event_store_consumer.run(), name="event_store_consumer"), + asyncio.create_task(notification_service.run(), name="notification_service"), + ] + logger.info(f"Background services started ({len(sse_consumers)} SSE consumers)") + + try: yield + finally: + # Cancel all background tasks on shutdown + logger.info("Shutting down background services...") + for task in tasks: + task.cancel() + + # Wait for tasks to finish cancellation + await asyncio.gather(*tasks, return_exceptions=True) + logger.info("Background services stopped") diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index ad01da2a..d18ea80a 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -43,16 +43,16 @@ from app.db.repositories.user_settings_repository import UserSettingsRepository from app.dlq.manager import DLQManager from app.dlq.models import RetryPolicy, RetryStrategy -from app.domain.enums.kafka import GroupId, KafkaTopic +from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId, KafkaTopic from app.domain.saga.models import SagaConfig -from app.events.core import UnifiedProducer 
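For context on the dishka_lifespan.py hunk above: the new shape is "each DI-provided service exposes a blocking run(); the lifespan starts those run() calls as background tasks and cancels/awaits them on shutdown". Below is a minimal sketch of that pattern, not taken from this patch; RunnableService and background_services are illustrative names standing in for the concrete consumers resolved from the container.

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Protocol


class RunnableService(Protocol):
    async def run(self) -> None:  # blocks until the surrounding task is cancelled
        ...


@asynccontextmanager
async def background_services(*services: RunnableService) -> AsyncIterator[list[asyncio.Task[None]]]:
    # Start each service's run() as a named background task.
    tasks = [asyncio.create_task(svc.run(), name=type(svc).__name__) for svc in services]
    try:
        yield tasks
    finally:
        # Shutdown: cancel everything, then await so cleanup inside each run() can finish.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

Used as "async with background_services(consumer_a, consumer_b): ...", which mirrors the try/finally around the yield in the lifespan hunk above.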
+from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer from app.events.event_store import EventStore, create_event_store from app.events.event_store_consumer import EventStoreConsumer from app.events.schema.schema_registry import SchemaRegistryManager from app.infrastructure.kafka.topics import get_all_topics from app.services.admin import AdminEventsService, AdminSettingsService, AdminUserService from app.services.auth_service import AuthService -from app.services.coordinator.coordinator import ExecutionCoordinator +from app.services.coordinator.coordinator_logic import CoordinatorLogic from app.services.event_bus import EventBus, EventBusEvent from app.services.event_replay.replay_service import EventReplayService from app.services.event_service import EventService @@ -60,9 +60,10 @@ from app.services.grafana_alert_processor import GrafanaAlertProcessor from app.services.idempotency import IdempotencyConfig, IdempotencyManager from app.services.idempotency.idempotency_manager import create_idempotency_manager +from app.services.idempotency.middleware import IdempotentConsumerWrapper from app.services.idempotency.redis_repository import RedisIdempotencyRepository from app.services.k8s_worker.config import K8sWorkerConfig -from app.services.k8s_worker.worker import KubernetesWorker +from app.services.k8s_worker.worker_logic import K8sWorkerLogic from app.services.kafka_event_service import KafkaEventService from app.services.notification_service import NotificationService from app.services.pod_monitor.config import PodMonitorConfig @@ -70,13 +71,14 @@ from app.services.pod_monitor.monitor import PodMonitor from app.services.rate_limit_service import RateLimitService from app.services.replay_service import ReplayService -from app.services.saga import SagaOrchestrator +from app.services.result_processor.processor_logic import ProcessorLogic +from app.services.saga.saga_logic import SagaLogic from app.services.saga.saga_service import SagaService from app.services.saved_script_service import SavedScriptService -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge +from app.services.sse.event_router import SSEEventRouter from app.services.sse.redis_bus import SSERedisBus +from app.services.sse.sse_connection_registry import SSEConnectionRegistry from app.services.sse.sse_service import SSEService -from app.services.sse.sse_shutdown_manager import SSEShutdownManager from app.services.user_settings_service import UserSettingsService from app.settings import Settings @@ -243,26 +245,23 @@ async def get_event_store( ) @provide - async def get_event_store_consumer( + def get_event_store_consumer( self, event_store: EventStore, schema_registry: SchemaRegistryManager, settings: Settings, - kafka_producer: UnifiedProducer, logger: logging.Logger, event_metrics: EventMetrics, - ) -> AsyncIterator[EventStoreConsumer]: + ) -> EventStoreConsumer: topics = get_all_topics() - async with EventStoreConsumer( - event_store=event_store, - topics=list(topics), - schema_registry_manager=schema_registry, - settings=settings, - producer=kafka_producer, - logger=logger, - event_metrics=event_metrics, - ) as consumer: - yield consumer + return EventStoreConsumer( + event_store=event_store, + topics=list(topics), + schema_registry_manager=schema_registry, + settings=settings, + logger=logger, + event_metrics=event_metrics, + ) @provide async def get_event_bus( @@ -411,32 +410,63 @@ async def get_sse_redis_bus(self, redis_client: redis.Redis, logger: logging.Log 
yield bus @provide - async def get_sse_kafka_redis_bridge( + def get_sse_event_router( + self, + sse_redis_bus: SSERedisBus, + logger: logging.Logger, + ) -> SSEEventRouter: + return SSEEventRouter(sse_bus=sse_redis_bus, logger=logger) + + @provide + def get_sse_consumers( self, + router: SSEEventRouter, schema_registry: SchemaRegistryManager, settings: Settings, event_metrics: EventMetrics, - sse_redis_bus: SSERedisBus, logger: logging.Logger, - ) -> AsyncIterator[SSEKafkaRedisBridge]: - async with SSEKafkaRedisBridge( + ) -> list[UnifiedConsumer]: + """Create SSE consumer pool with routing handlers wired to SSEEventRouter.""" + topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.WEBSOCKET_GATEWAY]) + suffix = settings.KAFKA_GROUP_SUFFIX + consumers: list[UnifiedConsumer] = [] + + for i in range(settings.SSE_CONSUMER_POOL_SIZE): + dispatcher = EventDispatcher(logger=logger) + router.register_handlers(dispatcher) + + config = ConsumerConfig( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"sse-bridge-pool.{suffix}", + client_id=f"sse-bridge-{i}.{suffix}", + enable_auto_commit=True, + auto_offset_reset="latest", + max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, + session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, + heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + + consumer = UnifiedConsumer( + config=config, + dispatcher=dispatcher, schema_registry=schema_registry, settings=settings, - event_metrics=event_metrics, - sse_bus=sse_redis_bus, logger=logger, - ) as bridge: - yield bridge + event_metrics=event_metrics, + topics=topics, + ) + consumers.append(consumer) + + return consumers @provide(scope=Scope.REQUEST) - def get_sse_shutdown_manager( + def get_sse_connection_registry( self, - router: SSEKafkaRedisBridge, logger: logging.Logger, connection_metrics: ConnectionMetrics, - ) -> SSEShutdownManager: - return SSEShutdownManager( - router=router, + ) -> SSEConnectionRegistry: + return SSEConnectionRegistry( logger=logger, connection_metrics=connection_metrics, ) @@ -445,18 +475,18 @@ def get_sse_shutdown_manager( def get_sse_service( self, sse_repository: SSERepository, - router: SSEKafkaRedisBridge, + consumers: list[UnifiedConsumer], sse_redis_bus: SSERedisBus, - shutdown_manager: SSEShutdownManager, + connection_registry: SSEConnectionRegistry, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics, ) -> SSEService: return SSEService( repository=sse_repository, - router=router, + num_consumers=len(consumers), sse_bus=sse_redis_bus, - shutdown_manager=shutdown_manager, + connection_registry=connection_registry, settings=settings, logger=logger, connection_metrics=connection_metrics, @@ -550,25 +580,19 @@ def get_admin_settings_service( def get_notification_service( self, notification_repository: NotificationRepository, - kafka_event_service: KafkaEventService, event_bus: EventBus, - schema_registry: SchemaRegistryManager, sse_redis_bus: SSERedisBus, settings: Settings, logger: logging.Logger, notification_metrics: NotificationMetrics, - event_metrics: EventMetrics, ) -> NotificationService: return NotificationService( notification_repository=notification_repository, - event_service=kafka_event_service, event_bus=event_bus, - schema_registry_manager=schema_registry, sse_bus=sse_redis_bus, settings=settings, logger=logger, notification_metrics=notification_metrics, - event_metrics=event_metrics, ) @provide @@ -593,80 +617,26 @@ def _create_default_saga_config() 
-> SagaConfig: ) -# Standalone factory functions for lifecycle-managed services (eliminates duplication) -async def _provide_saga_orchestrator( - saga_repository: SagaRepository, - kafka_producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - resource_allocation_repository: ResourceAllocationRepository, - logger: logging.Logger, - event_metrics: EventMetrics, -) -> AsyncIterator[SagaOrchestrator]: - """Shared factory for SagaOrchestrator with lifecycle management.""" - async with SagaOrchestrator( - config=_create_default_saga_config(), - saga_repository=saga_repository, - producer=kafka_producer, - schema_registry_manager=schema_registry, - settings=settings, - event_store=event_store, - idempotency_manager=idempotency_manager, - resource_allocation_repository=resource_allocation_repository, - logger=logger, - event_metrics=event_metrics, - ) as orchestrator: - yield orchestrator - - -async def _provide_execution_coordinator( - kafka_producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - execution_repository: ExecutionRepository, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - event_metrics: EventMetrics, -) -> AsyncIterator[ExecutionCoordinator]: - """Shared factory for ExecutionCoordinator with lifecycle management.""" - async with ExecutionCoordinator( - producer=kafka_producer, - schema_registry_manager=schema_registry, - settings=settings, - event_store=event_store, - execution_repository=execution_repository, - idempotency_manager=idempotency_manager, - logger=logger, - coordinator_metrics=coordinator_metrics, - event_metrics=event_metrics, - ) as coordinator: - yield coordinator +# Standalone factory functions for services (no lifecycle - run() handles everything) + + class BusinessServicesProvider(Provider): scope = Scope.REQUEST - def __init__(self) -> None: - super().__init__() - # Register shared factory functions on instance (avoids warning about missing self) - self.provide(_provide_execution_coordinator) - @provide def get_saga_service( self, saga_repository: SagaRepository, execution_repository: ExecutionRepository, - saga_orchestrator: SagaOrchestrator, + saga_logic: SagaLogic, logger: logging.Logger, ) -> SagaService: return SagaService( saga_repo=saga_repository, execution_repo=execution_repository, - orchestrator=saga_orchestrator, + saga_logic=saga_logic, logger=logger, ) @@ -736,37 +706,141 @@ def get_admin_user_service( class CoordinatorProvider(Provider): scope = Scope.APP - def __init__(self) -> None: - super().__init__() - self.provide(_provide_execution_coordinator) + @provide + def get_coordinator_logic( + self, + kafka_producer: UnifiedProducer, + execution_repository: ExecutionRepository, + logger: logging.Logger, + coordinator_metrics: CoordinatorMetrics, + ) -> CoordinatorLogic: + return CoordinatorLogic( + producer=kafka_producer, + execution_repository=execution_repository, + logger=logger, + coordinator_metrics=coordinator_metrics, + ) + + @provide + def get_coordinator_consumer( + self, + logic: CoordinatorLogic, + schema_registry: SchemaRegistryManager, + settings: Settings, + idempotency_manager: IdempotencyManager, + logger: logging.Logger, + event_metrics: EventMetrics, + ) -> IdempotentConsumerWrapper: + """Create consumer with handlers wired to CoordinatorLogic.""" + # Create dispatcher and register handlers from 
logic + dispatcher = EventDispatcher(logger=logger) + logic.register_handlers(dispatcher) + + # Build consumer + consumer_config = ConsumerConfig( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"{GroupId.EXECUTION_COORDINATOR}.{settings.KAFKA_GROUP_SUFFIX}", + enable_auto_commit=False, + session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, + heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, + max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + max_poll_records=100, + fetch_max_wait_ms=500, + fetch_min_bytes=1, + ) + + topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.EXECUTION_COORDINATOR]) + consumer = UnifiedConsumer( + consumer_config, + dispatcher=dispatcher, + schema_registry=schema_registry, + settings=settings, + logger=logger, + event_metrics=event_metrics, + topics=topics, + ) + + return IdempotentConsumerWrapper( + consumer=consumer, + dispatcher=dispatcher, + idempotency_manager=idempotency_manager, + logger=logger, + default_key_strategy="event_based", + default_ttl_seconds=7200, + enable_for_all_handlers=True, + ) class K8sWorkerProvider(Provider): scope = Scope.APP @provide - async def get_kubernetes_worker( + def get_k8s_worker_logic( self, kafka_producer: UnifiedProducer, + settings: Settings, + logger: logging.Logger, + event_metrics: EventMetrics, + ) -> K8sWorkerLogic: + config = K8sWorkerConfig() + logic = K8sWorkerLogic( + config=config, + producer=kafka_producer, + settings=settings, + logger=logger, + event_metrics=event_metrics, + ) + # Initialize K8s clients synchronously (safe during DI setup) + logic.initialize() + return logic + + @provide + def get_k8s_worker_consumer( + self, + logic: K8sWorkerLogic, schema_registry: SchemaRegistryManager, settings: Settings, - event_store: EventStore, idempotency_manager: IdempotencyManager, logger: logging.Logger, event_metrics: EventMetrics, - ) -> AsyncIterator[KubernetesWorker]: - config = K8sWorkerConfig() - async with KubernetesWorker( - config=config, - producer=kafka_producer, - schema_registry_manager=schema_registry, - settings=settings, - event_store=event_store, - idempotency_manager=idempotency_manager, - logger=logger, - event_metrics=event_metrics, - ) as worker: - yield worker + ) -> IdempotentConsumerWrapper: + """Create consumer with handlers wired to K8sWorkerLogic.""" + # Create dispatcher and register handlers from logic + dispatcher = EventDispatcher(logger=logger) + logic.register_handlers(dispatcher) + + # Build consumer + consumer_config = ConsumerConfig( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"{logic.config.consumer_group}.{settings.KAFKA_GROUP_SUFFIX}", + enable_auto_commit=False, + session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, + heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, + max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + + topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.K8S_WORKER]) + consumer = UnifiedConsumer( + consumer_config, + dispatcher=dispatcher, + schema_registry=schema_registry, + settings=settings, + logger=logger, + event_metrics=event_metrics, + topics=topics, + ) + + return IdempotentConsumerWrapper( + consumer=consumer, + dispatcher=dispatcher, + idempotency_manager=idempotency_manager, + logger=logger, + default_key_strategy="content_hash", + default_ttl_seconds=3600, + enable_for_all_handlers=True, + ) class PodMonitorProvider(Provider): @@ -781,32 +855,102 @@ def 
get_event_mapper( return PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) @provide - async def get_pod_monitor( + def get_pod_monitor( self, kafka_event_service: KafkaEventService, k8s_clients: K8sClients, logger: logging.Logger, event_mapper: PodEventMapper, kubernetes_metrics: KubernetesMetrics, - ) -> AsyncIterator[PodMonitor]: + ) -> PodMonitor: config = PodMonitorConfig() - async with PodMonitor( - config=config, - kafka_event_service=kafka_event_service, - logger=logger, - k8s_clients=k8s_clients, - event_mapper=event_mapper, - kubernetes_metrics=kubernetes_metrics, - ) as monitor: - yield monitor + return PodMonitor( + config=config, + kafka_event_service=kafka_event_service, + logger=logger, + k8s_clients=k8s_clients, + event_mapper=event_mapper, + kubernetes_metrics=kubernetes_metrics, + ) class SagaOrchestratorProvider(Provider): scope = Scope.APP - def __init__(self) -> None: - super().__init__() - self.provide(_provide_saga_orchestrator) + @provide + def get_saga_logic( + self, + saga_repository: SagaRepository, + kafka_producer: UnifiedProducer, + resource_allocation_repository: ResourceAllocationRepository, + logger: logging.Logger, + event_metrics: EventMetrics, + ) -> SagaLogic: + logic = SagaLogic( + config=_create_default_saga_config(), + saga_repository=saga_repository, + producer=kafka_producer, + resource_allocation_repository=resource_allocation_repository, + logger=logger, + event_metrics=event_metrics, + ) + # Register default sagas + logic.register_default_sagas() + return logic + + @provide + def get_saga_consumer( + self, + logic: SagaLogic, + schema_registry: SchemaRegistryManager, + settings: Settings, + idempotency_manager: IdempotencyManager, + logger: logging.Logger, + event_metrics: EventMetrics, + ) -> IdempotentConsumerWrapper | None: + """Create consumer with handlers wired to SagaLogic.""" + # Get topics from registered sagas + topics = logic.get_trigger_topics() + if not topics: + logger.warning("No trigger events found in registered sagas") + return None + + # Create dispatcher and register handlers from logic + dispatcher = EventDispatcher(logger=logger) + logic.register_handlers(dispatcher) + + # Build consumer + consumer_config = ConsumerConfig( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"saga-{logic.config.name}.{settings.KAFKA_GROUP_SUFFIX}", + enable_auto_commit=False, + session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, + heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, + max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + + consumer = UnifiedConsumer( + config=consumer_config, + dispatcher=dispatcher, + schema_registry=schema_registry, + settings=settings, + logger=logger, + event_metrics=event_metrics, + topics=list(topics), + ) + + logger.info(f"Saga consumer configured for topics: {topics}") + + return IdempotentConsumerWrapper( + consumer=consumer, + dispatcher=dispatcher, + idempotency_manager=idempotency_manager, + logger=logger, + default_key_strategy="event_based", + default_ttl_seconds=7200, + enable_for_all_handlers=False, + ) class EventReplayProvider(Provider): @@ -828,3 +972,73 @@ def get_event_replay_service( settings=settings, logger=logger, ) + + +class ResultProcessorProvider(Provider): + scope = Scope.APP + + @provide + def get_processor_logic( + self, + execution_repo: ExecutionRepository, + kafka_producer: UnifiedProducer, + settings: Settings, + logger: logging.Logger, + execution_metrics: ExecutionMetrics, + ) -> 
ProcessorLogic: + return ProcessorLogic( + execution_repo=execution_repo, + producer=kafka_producer, + settings=settings, + logger=logger, + execution_metrics=execution_metrics, + ) + + @provide + def get_processor_consumer( + self, + logic: ProcessorLogic, + schema_registry: SchemaRegistryManager, + settings: Settings, + idempotency_manager: IdempotencyManager, + logger: logging.Logger, + event_metrics: EventMetrics, + ) -> IdempotentConsumerWrapper: + """Create consumer with handlers wired to ProcessorLogic.""" + # Create dispatcher and register handlers from logic + dispatcher = EventDispatcher(logger=logger) + logic.register_handlers(dispatcher) + + # Build consumer + consumer_config = ConsumerConfig( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"{GroupId.RESULT_PROCESSOR}.{settings.KAFKA_GROUP_SUFFIX}", + max_poll_records=1, + enable_auto_commit=True, + auto_offset_reset="earliest", + session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, + heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, + max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + + topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.RESULT_PROCESSOR]) + consumer = UnifiedConsumer( + consumer_config, + dispatcher=dispatcher, + schema_registry=schema_registry, + settings=settings, + logger=logger, + event_metrics=event_metrics, + topics=topics, + ) + + return IdempotentConsumerWrapper( + consumer=consumer, + dispatcher=dispatcher, + idempotency_manager=idempotency_manager, + logger=logger, + default_key_strategy="content_hash", + default_ttl_seconds=7200, + enable_for_all_handlers=True, + ) diff --git a/backend/app/events/core/consumer.py b/backend/app/events/core/consumer.py index bb37a134..8a051429 100644 --- a/backend/app/events/core/consumer.py +++ b/backend/app/events/core/consumer.py @@ -1,11 +1,7 @@ -import asyncio import logging from collections.abc import Awaitable, Callable -from datetime import datetime, timezone -from typing import Any from aiokafka import AIOKafkaConsumer, TopicPartition -from aiokafka.errors import KafkaError from opentelemetry.trace import SpanKind from app.core.metrics import EventMetrics @@ -17,38 +13,48 @@ from app.settings import Settings from .dispatcher import EventDispatcher -from .types import ConsumerConfig, ConsumerMetrics, ConsumerMetricsSnapshot, ConsumerState, ConsumerStatus +from .types import ConsumerConfig class UnifiedConsumer: + """Kafka consumer with framework-style run(). + + No loops in user code. Register handlers, call run(), handlers get called. 
+ + Usage: + dispatcher = EventDispatcher() + dispatcher.register(EventType.FOO, handle_foo) + + consumer = UnifiedConsumer(..., dispatcher=dispatcher) + await consumer.run() # Blocks, calls handlers when events arrive + """ + def __init__( self, config: ConsumerConfig, - event_dispatcher: EventDispatcher, + dispatcher: EventDispatcher, schema_registry: SchemaRegistryManager, settings: Settings, logger: logging.Logger, event_metrics: EventMetrics, + topics: list[KafkaTopic], + error_callback: Callable[[Exception, DomainEvent], Awaitable[None]] | None = None, ): self._config = config - self.logger = logger + self._dispatcher = dispatcher self._schema_registry = schema_registry - self._dispatcher = event_dispatcher - self._consumer: AIOKafkaConsumer | None = None - self._state = ConsumerState.STOPPED - self._metrics = ConsumerMetrics() self._event_metrics = event_metrics - self._error_callback: "Callable[[Exception, DomainEvent], Awaitable[None]] | None" = None - self._consume_task: asyncio.Task[None] | None = None - self._topic_prefix = settings.KAFKA_TOPIC_PREFIX - - async def start(self, topics: list[KafkaTopic]) -> None: - self._state = self._state if self._state != ConsumerState.STOPPED else ConsumerState.STARTING + self._topics = [f"{settings.KAFKA_TOPIC_PREFIX}{t}" for t in topics] + self._error_callback = error_callback + self.logger = logger + self._consumer: AIOKafkaConsumer | None = None - topic_strings = [f"{self._topic_prefix}{str(topic)}" for topic in topics] + async def run(self) -> None: + """Run the consumer. Blocks until stopped. Calls registered handlers.""" + tracer = get_tracer() self._consumer = AIOKafkaConsumer( - *topic_strings, + *self._topics, bootstrap_servers=self._config.bootstrap_servers, group_id=self._config.group_id, client_id=self._config.client_id, @@ -63,189 +69,58 @@ async def start(self, topics: list[KafkaTopic]) -> None: ) await self._consumer.start() - self._consume_task = asyncio.create_task(self._consume_loop()) - - self._state = ConsumerState.RUNNING - - self.logger.info(f"Consumer started for topics: {topic_strings}") - - async def stop(self) -> None: - self._state = ( - ConsumerState.STOPPING - if self._state not in (ConsumerState.STOPPED, ConsumerState.STOPPING) - else self._state - ) - - if self._consume_task: - self._consume_task.cancel() - await asyncio.gather(self._consume_task, return_exceptions=True) - self._consume_task = None - - await self._cleanup() - self._state = ConsumerState.STOPPED - - async def _cleanup(self) -> None: - if self._consumer: - await self._consumer.stop() - self._consumer = None - - async def _consume_loop(self) -> None: - self.logger.info(f"Consumer loop started for group {self._config.group_id}") - poll_count = 0 - message_count = 0 + self.logger.info(f"Consumer running for topics: {self._topics}") try: - while True: - if not self._consumer: - break - - poll_count += 1 - if poll_count % 100 == 0: # Log every 100 polls - self.logger.debug(f"Consumer loop active: polls={poll_count}, messages={message_count}") + async for msg in self._consumer: + if not msg.value: + continue try: - # Use getone() with timeout for single message consumption - msg = await asyncio.wait_for( - self._consumer.getone(), - timeout=0.1 - ) + event = await self._schema_registry.deserialize_event(msg.value, msg.topic) + + headers = {k: v.decode() if isinstance(v, bytes) else v for k, v in (msg.headers or [])} + ctx = extract_trace_context(headers) + + with tracer.start_as_current_span( + "kafka.consume", + context=ctx, + kind=SpanKind.CONSUMER, 
+ attributes={ + EventAttributes.KAFKA_TOPIC: msg.topic, + EventAttributes.KAFKA_PARTITION: msg.partition, + EventAttributes.KAFKA_OFFSET: msg.offset, + EventAttributes.EVENT_TYPE: event.event_type, + EventAttributes.EVENT_ID: event.event_id, + }, + ): + await self._dispatcher.dispatch(event) - message_count += 1 - self.logger.debug( - f"Message received from topic {msg.topic}, partition {msg.partition}, offset {msg.offset}" - ) - await self._process_message(msg) if not self._config.enable_auto_commit: await self._consumer.commit() - except asyncio.TimeoutError: - # No message available within timeout, continue polling - await asyncio.sleep(0.01) - except KafkaError as e: - self.logger.error(f"Consumer error: {e}") - self._metrics.processing_errors += 1 - - except asyncio.CancelledError: - self.logger.info(f"Consumer loop cancelled for group {self._config.group_id}") - - async def _process_message(self, message: Any) -> None: - """Process a ConsumerRecord from aiokafka.""" - topic = message.topic - if not topic: - self.logger.warning("Message with no topic received") - return - - raw_value = message.value - if not raw_value: - self.logger.warning(f"Empty message from topic {topic}") - return - - self.logger.debug(f"Deserializing message from topic {topic}, size={len(raw_value)} bytes") - event = await self._schema_registry.deserialize_event(raw_value, topic) - self.logger.info(f"Deserialized event: type={event.event_type}, id={event.event_id}") - - # Extract trace context from Kafka headers and start a consumer span - # aiokafka headers are list of tuples: [(key, value), ...] - header_list = message.headers or [] - headers: dict[str, str] = {} - for k, v in header_list: - headers[str(k)] = v.decode("utf-8") if isinstance(v, (bytes, bytearray)) else (v or "") - ctx = extract_trace_context(headers) - tracer = get_tracer() + self._event_metrics.record_kafka_message_consumed(msg.topic, self._config.group_id) - # Dispatch event through EventDispatcher - try: - self.logger.debug(f"Dispatching {event.event_type} to handlers") - partition_val = message.partition - offset_val = message.offset - part_attr = partition_val if partition_val is not None else -1 - off_attr = offset_val if offset_val is not None else -1 - with tracer.start_as_current_span( - name="kafka.consume", - context=ctx, - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: topic, - EventAttributes.KAFKA_PARTITION: part_attr, - EventAttributes.KAFKA_OFFSET: off_attr, - EventAttributes.EVENT_TYPE: event.event_type, - EventAttributes.EVENT_ID: event.event_id, - }, - ): - await self._dispatcher.dispatch(event) - self.logger.debug(f"Successfully dispatched {event.event_type}") - # Update metrics on successful dispatch - self._metrics.messages_consumed += 1 - self._metrics.bytes_consumed += len(raw_value) - self._metrics.last_message_time = datetime.now(timezone.utc) - # Record Kafka consumption metrics - self._event_metrics.record_kafka_message_consumed(topic=topic, consumer_group=self._config.group_id) - except Exception as e: - self.logger.error(f"Dispatcher error for event {event.event_type}: {e}") - self._metrics.processing_errors += 1 - # Record Kafka consumption error - self._event_metrics.record_kafka_consumption_error( - topic=topic, consumer_group=self._config.group_id, error_type=type(e).__name__ - ) - if self._error_callback: - await self._error_callback(e, event) - - def register_error_callback(self, callback: Callable[[Exception, DomainEvent], Awaitable[None]]) -> None: - self._error_callback = callback - 
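The EventStoreConsumer rewrite a little further down applies the same run()-style contract in batch form: aiokafka's getmany() waits on the broker (bounded by fetch_max_wait_ms / timeout_ms) and returns a dict of partition -> records, so no application-level buffer or flush timer is needed. A rough standalone sketch of that consumption loop follows; broker address, topic, and group id are placeholder assumptions, and the processing step is reduced to a print.

import asyncio

from aiokafka import AIOKafkaConsumer


async def consume_batches() -> None:
    consumer = AIOKafkaConsumer(
        "demo-topic",                        # placeholder topic
        bootstrap_servers="localhost:9092",  # placeholder broker
        group_id="demo-group",               # placeholder consumer group
        enable_auto_commit=False,
        auto_offset_reset="earliest",
        fetch_max_wait_ms=5000,              # broker-side batch timing, as in the patch
    )
    await consumer.start()
    try:
        while True:
            # Blocks on the fetch; returns {TopicPartition: [ConsumerRecord, ...]}.
            batch = await consumer.getmany(timeout_ms=5000, max_records=100)
            records = [msg for msgs in batch.values() for msg in msgs]
            if not records:
                continue
            # Process the whole batch, then commit offsets once.
            print(f"processing {len(records)} records")
            await consumer.commit()
    except asyncio.CancelledError:
        pass
    finally:
        await consumer.stop()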
- @property - def state(self) -> ConsumerState: - return self._state - - @property - def metrics(self) -> ConsumerMetrics: - return self._metrics - - @property - def consumer(self) -> AIOKafkaConsumer | None: - return self._consumer - - def get_status(self) -> ConsumerStatus: - return ConsumerStatus( - state=self._state, - group_id=self._config.group_id, - client_id=self._config.client_id, - metrics=ConsumerMetricsSnapshot( - messages_consumed=self._metrics.messages_consumed, - bytes_consumed=self._metrics.bytes_consumed, - consumer_lag=self._metrics.consumer_lag, - commit_failures=self._metrics.commit_failures, - processing_errors=self._metrics.processing_errors, - last_message_time=self._metrics.last_message_time, - last_updated=self._metrics.last_updated, - ), - ) + except Exception as e: + self.logger.error(f"Error processing message: {e}", exc_info=True) + self._event_metrics.record_kafka_consumption_error( + msg.topic, self._config.group_id, type(e).__name__ + ) + if self._error_callback: + await self._error_callback(e, event) - async def seek_to_beginning(self) -> None: - """Seek all assigned partitions to the beginning.""" - if not self._consumer: - self.logger.warning("Cannot seek: consumer not initialized") - return + finally: + await self._consumer.stop() + self.logger.info("Consumer stopped") - assignment = self._consumer.assignment() - if assignment: + async def seek_to_beginning(self) -> None: + if self._consumer and (assignment := self._consumer.assignment()): await self._consumer.seek_to_beginning(*assignment) async def seek_to_end(self) -> None: - """Seek all assigned partitions to the end.""" - if not self._consumer: - self.logger.warning("Cannot seek: consumer not initialized") - return - - assignment = self._consumer.assignment() - if assignment: + if self._consumer and (assignment := self._consumer.assignment()): await self._consumer.seek_to_end(*assignment) async def seek_to_offset(self, topic: str, partition: int, offset: int) -> None: - """Seek a specific partition to a specific offset.""" - if not self._consumer: - self.logger.warning("Cannot seek to offset: consumer not initialized") - return - - tp = TopicPartition(topic, partition) - self._consumer.seek(tp, offset) + if self._consumer: + self._consumer.seek(TopicPartition(topic, partition), offset) diff --git a/backend/app/events/event_store_consumer.py b/backend/app/events/event_store_consumer.py index c7b712ac..01764c82 100644 --- a/backend/app/events/event_store_consumer.py +++ b/backend/app/events/event_store_consumer.py @@ -1,21 +1,28 @@ import asyncio import logging +from aiokafka import AIOKafkaConsumer from opentelemetry.trace import SpanKind from app.core.metrics import EventMetrics from app.core.tracing.utils import trace_span -from app.domain.enums.events import EventType from app.domain.enums.kafka import GroupId, KafkaTopic from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer, create_dlq_error_handler from app.events.event_store import EventStore from app.events.schema.schema_registry import SchemaRegistryManager from app.settings import Settings class EventStoreConsumer: - """Consumes events from Kafka and stores them in MongoDB.""" + """Consumes events from Kafka and stores them in MongoDB. + + Uses Kafka's native batching via getmany() - no application-level buffering. + Kafka's fetch_max_wait_ms controls batch timing at the protocol level. + + Usage: + consumer = EventStoreConsumer(...) 
+ await consumer.run() # Blocks until cancelled + """ def __init__( self, @@ -25,32 +32,33 @@ def __init__( settings: Settings, logger: logging.Logger, event_metrics: EventMetrics, - producer: UnifiedProducer | None = None, group_id: GroupId = GroupId.EVENT_STORE_CONSUMER, batch_size: int = 100, - batch_timeout_seconds: float = 5.0, + batch_timeout_ms: int = 5000, ): + """Store dependencies. All work happens in run().""" self.event_store = event_store self.topics = topics self.settings = settings self.group_id = group_id self.batch_size = batch_size - self.batch_timeout = batch_timeout_seconds + self.batch_timeout_ms = batch_timeout_ms self.logger = logger self.event_metrics = event_metrics - self.consumer: UnifiedConsumer | None = None self.schema_registry_manager = schema_registry_manager - self.dispatcher = EventDispatcher(logger) - self.producer = producer # For DLQ handling - self._batch_buffer: list[DomainEvent] = [] - self._batch_lock = asyncio.Lock() - self._last_batch_time: float = 0.0 - self._batch_task: asyncio.Task[None] | None = None - - async def __aenter__(self) -> "EventStoreConsumer": - """Start consuming and storing events.""" - self._last_batch_time = asyncio.get_running_loop().time() - config = ConsumerConfig( + + async def run(self) -> None: + """Run the consumer. Blocks until cancelled. + + Creates consumer, starts it, runs batch loop, stops on cancellation. + Uses getmany() which blocks on Kafka's fetch - no polling, no timers. + """ + self.logger.info("Event store consumer starting...") + + topic_strings = [f"{self.settings.KAFKA_TOPIC_PREFIX}{topic}" for topic in self.topics] + + consumer = AIOKafkaConsumer( + *topic_strings, bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, group_id=f"{self.group_id}.{self.settings.KAFKA_GROUP_SUFFIX}", enable_auto_commit=False, @@ -59,102 +67,58 @@ async def __aenter__(self) -> "EventStoreConsumer": heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, + fetch_max_wait_ms=self.batch_timeout_ms, ) - self.consumer = UnifiedConsumer( - config, - event_dispatcher=self.dispatcher, - schema_registry=self.schema_registry_manager, - settings=self.settings, - logger=self.logger, - event_metrics=self.event_metrics, - ) - - # Register handler for all event types - store everything - for event_type in EventType: - self.dispatcher.register(event_type)(self._handle_event) - - # Register error callback - use DLQ if producer is available - if self.producer: - # Use DLQ handler with retry logic - dlq_handler = create_dlq_error_handler( - producer=self.producer, - original_topic="event-store", # Generic topic name for event store - logger=self.logger, - max_retries=3, - ) - self.consumer.register_error_callback(dlq_handler) - else: - # Fallback to simple logging - self.consumer.register_error_callback(self._handle_error_with_event) - - await self.consumer.start(self.topics) - - self._batch_task = asyncio.create_task(self._batch_processor()) - - self.logger.info(f"Event store consumer started for topics: {self.topics}") - return self + await consumer.start() + self.logger.info(f"Event store consumer initialized for topics: {topic_strings}") - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop consumer.""" - await self._flush_batch() - - if self._batch_task: - self._batch_task.cancel() - try: - await self._batch_task - except asyncio.CancelledError: - pass - - if 
self.consumer: - await self.consumer.stop() - - self.logger.info("Event store consumer stopped") - - async def _handle_event(self, event: DomainEvent) -> None: - """Handle incoming event from dispatcher.""" - self.logger.info(f"Event store received event: {event.event_type} - {event.event_id}") - - async with self._batch_lock: - self._batch_buffer.append(event) - - if len(self._batch_buffer) >= self.batch_size: - await self._flush_batch() - - async def _handle_error_with_event(self, error: Exception, event: DomainEvent) -> None: - """Handle processing errors with event context.""" - self.logger.error(f"Error processing event {event.event_id} ({event.event_type}): {error}", exc_info=True) - - async def _batch_processor(self) -> None: - """Periodically flush batches based on timeout.""" try: while True: - await asyncio.sleep(1) - - async with self._batch_lock: - time_since_last_batch = asyncio.get_running_loop().time() - self._last_batch_time - - if self._batch_buffer and time_since_last_batch >= self.batch_timeout: - await self._flush_batch() + # getmany() blocks until Kafka has data OR fetch_max_wait_ms expires + # This is NOT polling - it's async waiting on the network socket + batch_data = await consumer.getmany( + timeout_ms=self.batch_timeout_ms, + max_records=self.batch_size, + ) + + if not batch_data: + continue + + # Deserialize all messages in the batch + events: list[DomainEvent] = [] + for tp, messages in batch_data.items(): + for msg in messages: + try: + event = await self.schema_registry_manager.deserialize_event(msg.value, msg.topic) + events.append(event) + self.event_metrics.record_kafka_message_consumed( + topic=msg.topic, + consumer_group=str(self.group_id), + ) + except Exception as e: + self.logger.error(f"Failed to deserialize message from {tp}: {e}", exc_info=True) + + if events: + await self._store_batch(events) + await consumer.commit() except asyncio.CancelledError: - self.logger.info("Batch processor cancelled") - - async def _flush_batch(self) -> None: - if not self._batch_buffer: - return + self.logger.info("Event store consumer cancelled") + finally: + await consumer.stop() + self.logger.info("Event store consumer stopped") - batch = self._batch_buffer.copy() - self._batch_buffer.clear() - self._last_batch_time = asyncio.get_running_loop().time() + async def _store_batch(self, events: list[DomainEvent]) -> None: + """Store a batch of events.""" + self.logger.info(f"Storing batch of {len(events)} events") - self.logger.info(f"Event store flushing batch of {len(batch)} events") with trace_span( - name="event_store.flush_batch", + name="event_store.store_batch", kind=SpanKind.CONSUMER, - attributes={"events.batch.count": len(batch)}, + attributes={"events.batch.count": len(events)}, ): - results = await self.event_store.store_batch(batch) + results = await self.event_store.store_batch(events) self.logger.info( f"Stored event batch: total={results['total']}, " diff --git a/backend/app/services/coordinator/__init__.py b/backend/app/services/coordinator/__init__.py index b3890c9d..c3fd1ffb 100644 --- a/backend/app/services/coordinator/__init__.py +++ b/backend/app/services/coordinator/__init__.py @@ -1,9 +1,9 @@ -from app.services.coordinator.coordinator import ExecutionCoordinator +from app.services.coordinator.coordinator_logic import CoordinatorLogic from app.services.coordinator.queue_manager import QueueManager, QueuePriority from app.services.coordinator.resource_manager import ResourceAllocation, ResourceManager __all__ = [ - "ExecutionCoordinator", + 
"CoordinatorLogic", "QueueManager", "QueuePriority", "ResourceManager", diff --git a/backend/app/services/coordinator/coordinator.py b/backend/app/services/coordinator/coordinator_logic.py similarity index 70% rename from backend/app/services/coordinator/coordinator.py rename to backend/app/services/coordinator/coordinator_logic.py index 4bc09a69..528983d6 100644 --- a/backend/app/services/coordinator/coordinator.py +++ b/backend/app/services/coordinator/coordinator_logic.py @@ -1,17 +1,16 @@ import asyncio import logging import time -from collections.abc import Coroutine from typing import Any, TypeAlias from uuid import uuid4 -from app.core.metrics import CoordinatorMetrics, EventMetrics +from app.core.metrics import CoordinatorMetrics from app.db.repositories.execution_repository import ExecutionRepository from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId from app.domain.enums.storage import ExecutionErrorType from app.domain.events.typed import ( CreatePodCommandEvent, + DomainEvent, EventMetadata, ExecutionAcceptedEvent, ExecutionCancelledEvent, @@ -19,56 +18,37 @@ ExecutionFailedEvent, ExecutionRequestedEvent, ) -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import ( - SchemaRegistryManager, -) +from app.events.core import EventDispatcher, UnifiedProducer from app.services.coordinator.queue_manager import QueueManager, QueuePriority from app.services.coordinator.resource_manager import ResourceAllocation, ResourceManager -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper -from app.settings import Settings -EventHandler: TypeAlias = Coroutine[Any, Any, None] ExecutionMap: TypeAlias = dict[str, ResourceAllocation] -class ExecutionCoordinator: +class CoordinatorLogic: """ - Coordinates execution scheduling across the system. - - This service: - 1. Consumes ExecutionRequested events - 2. Manages execution queue with priority - 3. Enforces rate limits - 4. Allocates resources - 5. Publishes ExecutionStarted events for workers + Business logic for execution coordination. + + Handles: + - Execution request queuing and validation + - Resource allocation and management + - Scheduling loop for processing queued executions + - Event publishing (ExecutionAccepted, CreatePodCommand, ExecutionFailed) + + This class is stateful and must be instantiated once per coordinator instance. 
""" def __init__( self, producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, execution_repository: ExecutionRepository, - idempotency_manager: IdempotencyManager, logger: logging.Logger, coordinator_metrics: CoordinatorMetrics, - event_metrics: EventMetrics, - consumer_group: str = GroupId.EXECUTION_COORDINATOR, max_concurrent_scheduling: int = 10, scheduling_interval_seconds: float = 0.5, ): self.logger = logger self.metrics = coordinator_metrics - self._event_metrics = event_metrics - self._settings = settings - - # Kafka configuration - self.kafka_servers = self._settings.KAFKA_BOOTSTRAP_SERVERS - self.consumer_group = consumer_group # Components self.queue_manager = QueueManager( @@ -87,15 +67,9 @@ def __init__( total_gpu_count=0, ) - # Kafka components - self.consumer: UnifiedConsumer | None = None - self.idempotent_consumer: IdempotentConsumerWrapper | None = None - self.producer: UnifiedProducer = producer - - # Persistence via repositories + # Kafka producer (injected, lifecycle managed by DI) + self.producer = producer self.execution_repository = execution_repository - self.idempotency_manager = idempotency_manager - self._event_store = event_store # Scheduling self.max_concurrent_scheduling = max_concurrent_scheduling @@ -103,103 +77,30 @@ def __init__( self._scheduling_semaphore = asyncio.Semaphore(max_concurrent_scheduling) # State tracking - self._scheduling_task: asyncio.Task[None] | None = None self._active_executions: set[str] = set() self._execution_resources: ExecutionMap = {} - self._schema_registry_manager = schema_registry_manager - self.dispatcher = EventDispatcher(logger=self.logger) - - async def __aenter__(self) -> "ExecutionCoordinator": - """Start the coordinator service.""" - self.logger.info("Starting ExecutionCoordinator service...") - - self.logger.info("Queue manager initialized") - - await self.idempotency_manager.initialize() - - consumer_config = ConsumerConfig( - bootstrap_servers=self.kafka_servers, - group_id=f"{self.consumer_group}.{self._settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - max_poll_records=100, # Process max 100 messages at a time for flow control - fetch_max_wait_ms=500, # Wait max 500ms for data (reduces latency) - fetch_min_bytes=1, # Return immediately if any data available - ) - self.consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self.dispatcher, - schema_registry=self._schema_registry_manager, - settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) + def register_handlers(self, dispatcher: EventDispatcher) -> None: + """Register event handlers with the dispatcher.""" - # Register handlers with EventDispatcher BEFORE wrapping with idempotency - @self.dispatcher.register(EventType.EXECUTION_REQUESTED) + @dispatcher.register(EventType.EXECUTION_REQUESTED) async def handle_requested(event: ExecutionRequestedEvent) -> None: await self._route_execution_event(event) - @self.dispatcher.register(EventType.EXECUTION_COMPLETED) + @dispatcher.register(EventType.EXECUTION_COMPLETED) async def handle_completed(event: ExecutionCompletedEvent) -> None: await self._route_execution_result(event) - @self.dispatcher.register(EventType.EXECUTION_FAILED) + 
@dispatcher.register(EventType.EXECUTION_FAILED) async def handle_failed(event: ExecutionFailedEvent) -> None: await self._route_execution_result(event) - @self.dispatcher.register(EventType.EXECUTION_CANCELLED) + @dispatcher.register(EventType.EXECUTION_CANCELLED) async def handle_cancelled(event: ExecutionCancelledEvent) -> None: await self._route_execution_event(event) - self.idempotent_consumer = IdempotentConsumerWrapper( - consumer=self.consumer, - idempotency_manager=self.idempotency_manager, - dispatcher=self.dispatcher, - logger=self.logger, - default_key_strategy="event_based", # Use event ID for deduplication - default_ttl_seconds=7200, # 2 hours TTL for coordinator events - enable_for_all_handlers=True, # Enable idempotency for ALL handlers - ) - - self.logger.info("COORDINATOR: Event handlers registered with idempotency protection") - - await self.idempotent_consumer.start(list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.EXECUTION_COORDINATOR])) - - # Start scheduling task - self._scheduling_task = asyncio.create_task(self._scheduling_loop()) - - self.logger.info("ExecutionCoordinator service started successfully") - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop the coordinator service.""" - self.logger.info("Stopping ExecutionCoordinator service...") - - # Stop scheduling task - if self._scheduling_task: - self._scheduling_task.cancel() - try: - await self._scheduling_task - except asyncio.CancelledError: - pass - - # Stop consumer (idempotent wrapper only) - if self.idempotent_consumer: - await self.idempotent_consumer.stop() - - # Close idempotency manager - if hasattr(self, "idempotency_manager") and self.idempotency_manager: - await self.idempotency_manager.close() - - self.logger.info(f"ExecutionCoordinator service stopped. 
Active executions: {len(self._active_executions)}") - - async def _route_execution_event(self, event: ExecutionRequestedEvent | ExecutionCancelledEvent) -> None: - """Route execution events to appropriate handlers based on event type""" + async def _route_execution_event(self, event: DomainEvent) -> None: + """Route execution events to appropriate handlers based on event type.""" self.logger.info( f"COORDINATOR: Routing execution event - type: {event.event_type}, " f"id: {event.event_id}, " @@ -213,8 +114,8 @@ async def _route_execution_event(self, event: ExecutionRequestedEvent | Executio else: self.logger.debug(f"Ignoring execution event type: {event.event_type}") - async def _route_execution_result(self, event: ExecutionCompletedEvent | ExecutionFailedEvent) -> None: - """Route execution result events to appropriate handlers based on event type""" + async def _route_execution_result(self, event: DomainEvent) -> None: + """Route execution result events to appropriate handlers based on event type.""" if event.event_type == EventType.EXECUTION_COMPLETED: await self._handle_execution_completed(event) elif event.event_type == EventType.EXECUTION_FAILED: @@ -223,7 +124,7 @@ async def _route_execution_result(self, event: ExecutionCompletedEvent | Executi self.logger.debug(f"Ignoring execution result event type: {event.event_type}") async def _handle_execution_requested(self, event: ExecutionRequestedEvent) -> None: - """Handle execution requested event - add to queue for processing""" + """Handle execution requested event - add to queue for processing.""" self.logger.info(f"HANDLER CALLED: _handle_execution_requested for event {event.event_id}") start_time = time.time() @@ -261,7 +162,7 @@ async def _handle_execution_requested(self, event: ExecutionRequestedEvent) -> N self.metrics.record_coordinator_execution_scheduled("error") async def _handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> None: - """Handle execution cancelled event""" + """Handle execution cancelled event.""" execution_id = event.execution_id removed = await self.queue_manager.remove_execution(execution_id) @@ -277,7 +178,7 @@ async def _handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> N self.logger.info(f"Execution {execution_id} cancelled and removed from queue") async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: - """Handle execution completed event""" + """Handle execution completed event.""" execution_id = event.execution_id if execution_id in self._execution_resources: @@ -291,7 +192,7 @@ async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> N self.logger.info(f"Execution {execution_id} completed, resources released") async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: - """Handle execution failed event""" + """Handle execution failed event.""" execution_id = event.execution_id # Release resources @@ -303,8 +204,9 @@ async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: self._active_executions.discard(execution_id) self.metrics.update_coordinator_active_executions(len(self._active_executions)) - async def _scheduling_loop(self) -> None: - """Main scheduling loop""" + async def scheduling_loop(self) -> None: + """Main scheduling loop - processes queued executions.""" + self.logger.info("Scheduling loop started") try: while True: try: @@ -327,13 +229,13 @@ async def _scheduling_loop(self) -> None: self.logger.info("Scheduling loop cancelled") async def _schedule_execution(self, 
event: ExecutionRequestedEvent) -> None: - """Schedule a single execution""" + """Schedule a single execution.""" async with self._scheduling_semaphore: start_time = time.time() execution_id = event.execution_id # Atomic check-and-claim: no await between check and add prevents TOCTOU race - # when both eager scheduling (position=0) and _scheduling_loop try to schedule + # when both eager scheduling (position=0) and scheduling_loop try to schedule if execution_id in self._active_executions: self.logger.debug(f"Execution {execution_id} already claimed, skipping") return @@ -374,8 +276,8 @@ async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None: # Track metrics queue_time = start_time - event.timestamp.timestamp() - priority = getattr(event, "priority", QueuePriority.NORMAL.value) - self.metrics.record_coordinator_queue_time(queue_time, QueuePriority(priority).name) + priority = getattr(event, "priority", QueuePriority.NORMAL) + self.metrics.record_coordinator_queue_time(queue_time, priority.name) scheduling_duration = time.time() - start_time self.metrics.record_coordinator_scheduling_duration(scheduling_duration) @@ -417,7 +319,7 @@ async def _build_command_metadata(self, request: ExecutionRequestedEvent) -> Eve ) async def _publish_execution_started(self, request: ExecutionRequestedEvent) -> None: - """Send CreatePodCommandEvent to k8s-worker via SAGA_COMMANDS topic""" + """Send CreatePodCommandEvent to k8s-worker via SAGA_COMMANDS topic.""" metadata = await self._build_command_metadata(request) create_pod_cmd = CreatePodCommandEvent( @@ -440,8 +342,10 @@ async def _publish_execution_started(self, request: ExecutionRequestedEvent) -> await self.producer.produce(event_to_produce=create_pod_cmd, key=request.execution_id) - async def _publish_execution_accepted(self, request: ExecutionRequestedEvent, position: int, priority: int) -> None: - """Publish execution accepted event to notify that request was valid and queued""" + async def _publish_execution_accepted( + self, request: ExecutionRequestedEvent, position: int, priority: int + ) -> None: + """Publish execution accepted event to notify that request was valid and queued.""" self.logger.info(f"Publishing ExecutionAcceptedEvent for execution {request.execution_id}") event = ExecutionAcceptedEvent( @@ -456,7 +360,7 @@ async def _publish_execution_accepted(self, request: ExecutionRequestedEvent, po self.logger.info(f"ExecutionAcceptedEvent published for {request.execution_id}") async def _publish_queue_full(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish queue full event""" + """Publish queue full event.""" # Get queue stats for context queue_stats = await self.queue_manager.get_queue_stats() @@ -473,7 +377,7 @@ async def _publish_queue_full(self, request: ExecutionRequestedEvent, error: str await self.producer.produce(event_to_produce=event, key=request.execution_id) async def _publish_scheduling_failed(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish scheduling failed event""" + """Publish scheduling failed event.""" # Get resource stats for context resource_stats = await self.resource_manager.get_resource_stats() @@ -492,9 +396,8 @@ async def _publish_scheduling_failed(self, request: ExecutionRequestedEvent, err await self.producer.produce(event_to_produce=event, key=request.execution_id) async def get_status(self) -> dict[str, Any]: - """Get coordinator status""" + """Get coordinator status.""" return { - "scheduling_task_active": self._scheduling_task is not None 
and not self._scheduling_task.done(), "active_executions": len(self._active_executions), "queue_stats": await self.queue_manager.get_queue_stats(), "resource_stats": await self.resource_manager.get_resource_stats(), diff --git a/backend/app/services/idempotency/__init__.py b/backend/app/services/idempotency/__init__.py index 82af12f0..fc82d7d3 100644 --- a/backend/app/services/idempotency/__init__.py +++ b/backend/app/services/idempotency/__init__.py @@ -6,7 +6,7 @@ IdempotencyResult, create_idempotency_manager, ) -from app.services.idempotency.middleware import IdempotentConsumerWrapper, IdempotentEventHandler, idempotent_handler +from app.services.idempotency.middleware import IdempotentConsumerWrapper, IdempotentEventHandler __all__ = [ "IdempotencyManager", @@ -16,6 +16,5 @@ "IdempotencyKeyStrategy", "create_idempotency_manager", "IdempotentEventHandler", - "idempotent_handler", "IdempotentConsumerWrapper", ] diff --git a/backend/app/services/idempotency/middleware.py b/backend/app/services/idempotency/middleware.py index 689897d5..a6fc772f 100644 --- a/backend/app/services/idempotency/middleware.py +++ b/backend/app/services/idempotency/middleware.py @@ -2,17 +2,17 @@ import asyncio import logging -from typing import Any, Awaitable, Callable, Dict, Set +from collections.abc import Awaitable, Callable +from typing import Any from app.domain.enums.events import EventType -from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DomainEvent from app.events.core import EventDispatcher, UnifiedConsumer from app.services.idempotency.idempotency_manager import IdempotencyManager class IdempotentEventHandler: - """Wrapper for event handlers with idempotency support""" + """Wrapper for event handlers with idempotency support.""" def __init__( self, @@ -21,9 +21,8 @@ def __init__( logger: logging.Logger, key_strategy: str = "event_based", custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: Set[str] | None = None, + fields: set[str] | None = None, ttl_seconds: int | None = None, - cache_result: bool = True, on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, ): self.handler = handler @@ -33,22 +32,12 @@ def __init__( self.custom_key_func = custom_key_func self.fields = fields self.ttl_seconds = ttl_seconds - self.cache_result = cache_result self.on_duplicate = on_duplicate async def __call__(self, event: DomainEvent) -> None: - """Process event with idempotency check""" - self.logger.info( - f"IdempotentEventHandler called for event {event.event_type}, " - f"id={event.event_id}, handler={self.handler.__name__}" - ) - # Generate custom key if function provided - custom_key = None - if self.key_strategy == "custom" and self.custom_key_func: - custom_key = self.custom_key_func(event) + custom_key = self.custom_key_func(event) if self.key_strategy == "custom" and self.custom_key_func else None - # Check idempotency - idempotency_result = await self.idempotency_manager.check_and_reserve( + result = await self.idempotency_manager.check_and_reserve( event=event, key_strategy=self.key_strategy, custom_key=custom_key, @@ -56,224 +45,100 @@ async def __call__(self, event: DomainEvent) -> None: fields=self.fields, ) - if idempotency_result.is_duplicate: - # Handle duplicate - self.logger.info( - f"Duplicate event detected: {event.event_type} ({event.event_id}), status: {idempotency_result.status}" - ) - - # Call duplicate handler if provided + if result.is_duplicate: + self.logger.info(f"Duplicate event: {event.event_type} ({event.event_id})") if 
self.on_duplicate: if asyncio.iscoroutinefunction(self.on_duplicate): - await self.on_duplicate(event, idempotency_result) + await self.on_duplicate(event, result) else: - await asyncio.to_thread(self.on_duplicate, event, idempotency_result) - - # For duplicate, just return without error + await asyncio.to_thread(self.on_duplicate, event, result) return - # Not a duplicate, process the event try: - # Call the actual handler - it returns None await self.handler(event) - - # Mark as completed await self.idempotency_manager.mark_completed( event=event, key_strategy=self.key_strategy, custom_key=custom_key, fields=self.fields ) - except Exception as e: - # Mark as failed await self.idempotency_manager.mark_failed( event=event, error=str(e), key_strategy=self.key_strategy, custom_key=custom_key, fields=self.fields ) raise -def idempotent_handler( - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - key_strategy: str = "event_based", - custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: Set[str] | None = None, - ttl_seconds: int | None = None, - cache_result: bool = True, - on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, -) -> Callable[[Callable[[DomainEvent], Awaitable[None]]], Callable[[DomainEvent], Awaitable[None]]]: - """Decorator for making event handlers idempotent""" - - def decorator(func: Callable[[DomainEvent], Awaitable[None]]) -> Callable[[DomainEvent], Awaitable[None]]: - handler = IdempotentEventHandler( - handler=func, - idempotency_manager=idempotency_manager, - logger=logger, - key_strategy=key_strategy, - custom_key_func=custom_key_func, - fields=fields, - ttl_seconds=ttl_seconds, - cache_result=cache_result, - on_duplicate=on_duplicate, - ) - return handler # IdempotentEventHandler is already callable with the right signature - - return decorator +class IdempotentConsumerWrapper: + """Wrapper for UnifiedConsumer with automatic idempotency. + Usage: + dispatcher = EventDispatcher() + dispatcher.register(EventType.FOO, handle_foo) -class IdempotentConsumerWrapper: - """Wrapper for Kafka consumer with automatic idempotency""" + consumer = UnifiedConsumer(..., dispatcher=dispatcher) + wrapper = IdempotentConsumerWrapper(consumer, dispatcher, idempotency_manager, ...) 
+ await wrapper.run() # Handlers are wrapped with idempotency, then consumer runs + """ def __init__( self, consumer: UnifiedConsumer, - idempotency_manager: IdempotencyManager, dispatcher: EventDispatcher, + idempotency_manager: IdempotencyManager, logger: logging.Logger, default_key_strategy: str = "event_based", default_ttl_seconds: int = 3600, enable_for_all_handlers: bool = True, ): - self.consumer = consumer - self.idempotency_manager = idempotency_manager - self.dispatcher = dispatcher - self.logger = logger - self.default_key_strategy = default_key_strategy - self.default_ttl_seconds = default_ttl_seconds - self.enable_for_all_handlers = enable_for_all_handlers - self._original_handlers: Dict[EventType, list[Callable[[DomainEvent], Awaitable[None]]]] = {} - - def make_handlers_idempotent(self) -> None: - """Wrap all registered handlers with idempotency""" - self.logger.info( - f"make_handlers_idempotent called: enable_for_all={self.enable_for_all_handlers}, " - f"dispatcher={self.dispatcher is not None}" - ) - if not self.enable_for_all_handlers or not self.dispatcher: - self.logger.warning("Skipping handler wrapping - conditions not met") - return - - # Store original handlers using public API - self._original_handlers = self.dispatcher.get_all_handlers() - self.logger.info(f"Got {len(self._original_handlers)} event types with handlers to wrap") - - # Wrap each handler - for event_type, handlers in self._original_handlers.items(): - wrapped_handlers: list[Callable[[DomainEvent], Awaitable[None]]] = [] - for handler in handlers: - # Wrap with idempotency - IdempotentEventHandler is callable with the right signature - wrapped = IdempotentEventHandler( - handler=handler, - idempotency_manager=self.idempotency_manager, - logger=self.logger, - key_strategy=self.default_key_strategy, - ttl_seconds=self.default_ttl_seconds, + self._consumer = consumer + self._dispatcher = dispatcher + self._idempotency_manager = idempotency_manager + self._logger = logger + self._default_key_strategy = default_key_strategy + self._default_ttl_seconds = default_ttl_seconds + self._enable_for_all_handlers = enable_for_all_handlers + + async def run(self) -> None: + """Wrap handlers with idempotency, then run consumer.""" + if self._enable_for_all_handlers: + self._wrap_handlers() + self._logger.info("IdempotentConsumerWrapper running") + await self._consumer.run() + + def _wrap_handlers(self) -> None: + """Wrap all registered handlers with idempotency.""" + original_handlers = self._dispatcher.get_all_handlers() + + for event_type, handlers in original_handlers.items(): + wrapped: list[Callable[[DomainEvent], Awaitable[None]]] = [ + IdempotentEventHandler( + handler=h, + idempotency_manager=self._idempotency_manager, + logger=self._logger, + key_strategy=self._default_key_strategy, + ttl_seconds=self._default_ttl_seconds, ) - wrapped_handlers.append(wrapped) - - # Replace handlers using public API - self.logger.info( - f"Replacing {len(handlers)} handlers for {event_type} with {len(wrapped_handlers)} wrapped handlers" - ) - self.dispatcher.replace_handlers(event_type, wrapped_handlers) + for h in handlers + ] + self._dispatcher.replace_handlers(event_type, wrapped) - self.logger.info("Handler wrapping complete") - - def subscribe_idempotent_handler( + def register_idempotent_handler( self, - event_type: str, + event_type: EventType, handler: Callable[[DomainEvent], Awaitable[None]], key_strategy: str | None = None, custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: Set[str] | None = 
None, + fields: set[str] | None = None, ttl_seconds: int | None = None, - cache_result: bool = True, on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, ) -> None: - """Subscribe an idempotent handler for specific event type""" - # Create the idempotent handler wrapper - idempotent_wrapper = IdempotentEventHandler( + """Register an idempotent handler for an event type.""" + wrapped = IdempotentEventHandler( handler=handler, - idempotency_manager=self.idempotency_manager, - logger=self.logger, - key_strategy=key_strategy or self.default_key_strategy, + idempotency_manager=self._idempotency_manager, + logger=self._logger, + key_strategy=key_strategy or self._default_key_strategy, custom_key_func=custom_key_func, fields=fields, - ttl_seconds=ttl_seconds or self.default_ttl_seconds, - cache_result=cache_result, + ttl_seconds=ttl_seconds or self._default_ttl_seconds, on_duplicate=on_duplicate, ) - - # Create an async handler that processes the message - async def async_handler(message: Any) -> Any: - self.logger.info(f"IDEMPOTENT HANDLER CALLED for {event_type}") - - # Extract event from confluent-kafka Message - if not hasattr(message, "value"): - self.logger.error(f"Received non-Message object for {event_type}: {type(message)}") - return None - - # Debug log to check message details - self.logger.info( - f"Handler for {event_type} - Message type: {type(message)}, " - f"has key: {hasattr(message, 'key')}, " - f"has topic: {hasattr(message, 'topic')}" - ) - - raw_value = message.value - - # Debug the raw value - self.logger.info(f"Raw value extracted: {raw_value[:100] if raw_value else 'None or empty'}") - - # Handle tombstone messages (null value for log compaction) - if raw_value is None: - self.logger.warning(f"Received empty message for {event_type} - tombstone or consumed value") - return None - - # Handle empty messages - if not raw_value: - self.logger.warning(f"Received empty message for {event_type} - empty bytes") - return None - - try: - # Deserialize using schema registry if available - event = await self.consumer._schema_registry.deserialize_event(raw_value, message.topic) - if not event: - self.logger.error(f"Failed to deserialize event for {event_type}") - return None - - # Call the idempotent wrapper directly in async context - await idempotent_wrapper(event) - - self.logger.debug(f"Successfully processed {event_type} event: {event.event_id}") - return None - except Exception as e: - self.logger.error(f"Failed to process message for {event_type}: {e}", exc_info=True) - raise - - # Register with the dispatcher if available - if self.dispatcher: - # Create wrapper for EventDispatcher - async def dispatch_handler(event: DomainEvent) -> None: - await idempotent_wrapper(event) - - self.dispatcher.register(EventType(event_type))(dispatch_handler) - else: - # Fallback to direct consumer registration if no dispatcher - self.logger.error(f"No EventDispatcher available for registering idempotent handler for {event_type}") - - async def start(self, topics: list[KafkaTopic]) -> None: - """Start the consumer with idempotency""" - self.logger.info(f"IdempotentConsumerWrapper.start called with topics: {topics}") - # Make handlers idempotent before starting - self.make_handlers_idempotent() - - # Start the consumer with required topics parameter - await self.consumer.start(topics) - self.logger.info("IdempotentConsumerWrapper started successfully") - - async def stop(self) -> None: - """Stop the consumer""" - await self.consumer.stop() - - # Delegate other methods to the wrapped 
consumer - def __getattr__(self, name: str) -> Any: - return getattr(self.consumer, name) + self._dispatcher.register(event_type)(wrapped) diff --git a/backend/app/services/k8s_worker/__init__.py b/backend/app/services/k8s_worker/__init__.py index 31616a3b..8e61a7b4 100644 --- a/backend/app/services/k8s_worker/__init__.py +++ b/backend/app/services/k8s_worker/__init__.py @@ -1,9 +1,9 @@ from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.pod_builder import PodBuilder -from app.services.k8s_worker.worker import KubernetesWorker +from app.services.k8s_worker.worker_logic import K8sWorkerLogic __all__ = [ - "KubernetesWorker", + "K8sWorkerLogic", "PodBuilder", "K8sWorkerConfig", ] diff --git a/backend/app/services/k8s_worker/worker.py b/backend/app/services/k8s_worker/worker_logic.py similarity index 70% rename from backend/app/services/k8s_worker/worker.py rename to backend/app/services/k8s_worker/worker_logic.py index 7fd8f5aa..848f4c48 100644 --- a/backend/app/services/k8s_worker/worker.py +++ b/backend/app/services/k8s_worker/worker_logic.py @@ -11,7 +11,6 @@ from app.core.metrics import EventMetrics, ExecutionMetrics, KubernetesMetrics from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId from app.domain.enums.storage import ExecutionErrorType from app.domain.events.typed import ( CreatePodCommandEvent, @@ -21,41 +20,34 @@ ExecutionStartedEvent, PodCreatedEvent, ) -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import ( - SchemaRegistryManager, -) +from app.events.core import EventDispatcher, UnifiedProducer from app.runtime_registry import RUNTIME_REGISTRY -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.pod_builder import PodBuilder from app.settings import Settings -class KubernetesWorker: +class K8sWorkerLogic: """ - Worker service that creates Kubernetes pods from execution events. - - This service: - 1. Consumes ExecutionStarted events from Kafka - 2. Creates ConfigMaps with script content - 3. Creates Pods to execute the scripts - 4. Creates NetworkPolicies for security - 5. Publishes PodCreated events + Business logic for Kubernetes pod management. + + Handles: + - K8s client initialization + - Pod creation from command events + - Pod deletion (compensation) + - Image pre-puller daemonset management + - Event publishing (PodCreated, ExecutionFailed) + + This class is stateful and must be instantiated once per worker instance. 
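As with the coordinator, the consumer lifecycle and idempotency wrapping now live outside the class. A hedged composition sketch; deps again stands in for the injected dependencies, and the "content_hash" key strategy mirrors what the removed __aenter__ configured:

    import asyncio

    async def run_k8s_worker(deps) -> None:
        logic = K8sWorkerLogic(
            config=deps.k8s_config,
            producer=deps.producer,
            settings=deps.settings,
            logger=deps.logger,
            event_metrics=deps.event_metrics,
        )
        logic.initialize()                           # builds K8s clients, rejects the "default" namespace
        dispatcher = EventDispatcher(logger=deps.logger)
        logic.register_handlers(dispatcher)
        wrapper = IdempotentConsumerWrapper(
            deps.consumer, dispatcher, deps.idempotency_manager, deps.logger,
            default_key_strategy="content_hash",
        )
        async with asyncio.TaskGroup() as tg:
            tg.create_task(wrapper.run())
            tg.create_task(logic.ensure_daemonset_task())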
""" def __init__( - self, - config: K8sWorkerConfig, - producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - event_metrics: EventMetrics, + self, + config: K8sWorkerConfig, + producer: UnifiedProducer, + settings: Settings, + logger: logging.Logger, + event_metrics: EventMetrics, ): self._event_metrics = event_metrics self.logger = logger @@ -64,121 +56,35 @@ def __init__( self.config = config or K8sWorkerConfig() self._settings = settings - self.kafka_servers = self._settings.KAFKA_BOOTSTRAP_SERVERS - self._event_store = event_store - - # Kubernetes clients + # Kubernetes clients (initialized in initialize()) self.v1: k8s_client.CoreV1Api | None = None self.networking_v1: k8s_client.NetworkingV1Api | None = None self.apps_v1: k8s_client.AppsV1Api | None = None # Components self.pod_builder = PodBuilder(namespace=self.config.namespace, config=self.config) - self.consumer: UnifiedConsumer | None = None - self.idempotent_consumer: IdempotentConsumerWrapper | None = None - self.idempotency_manager: IdempotencyManager = idempotency_manager - self.dispatcher: EventDispatcher | None = None - self.producer: UnifiedProducer = producer + self.producer = producer # State tracking self._active_creations: set[str] = set() self._creation_semaphore = asyncio.Semaphore(self.config.max_concurrent_pods) - self._schema_registry_manager = schema_registry_manager - - async def __aenter__(self) -> "KubernetesWorker": - """Start the Kubernetes worker.""" - self.logger.info("Starting KubernetesWorker service...") - self.logger.info("DEBUG: About to initialize Kubernetes client") + def initialize(self) -> None: + """Initialize Kubernetes clients. Must be called before handling events.""" if self.config.namespace == "default": raise RuntimeError( "KubernetesWorker namespace 'default' is forbidden. Set K8S_NAMESPACE to a dedicated namespace." 
) - # Initialize Kubernetes client self._initialize_kubernetes_client() - self.logger.info("DEBUG: Kubernetes client initialized") - - self.logger.info("Using provided producer") - - self.logger.info("Idempotency manager provided") - - # Create consumer configuration - consumer_config = ConsumerConfig( - bootstrap_servers=self.kafka_servers, - group_id=f"{self.config.consumer_group}.{self._settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - # Create dispatcher and register handlers for saga commands - self.dispatcher = EventDispatcher(logger=self.logger) - self.dispatcher.register_handler(EventType.CREATE_POD_COMMAND, self._handle_create_pod_command_wrapper) - self.dispatcher.register_handler(EventType.DELETE_POD_COMMAND, self._handle_delete_pod_command_wrapper) - - # Create consumer with dispatcher - self.consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self.dispatcher, - schema_registry=self._schema_registry_manager, - settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - - # Wrap consumer with idempotency - use content hash for pod commands - self.idempotent_consumer = IdempotentConsumerWrapper( - consumer=self.consumer, - idempotency_manager=self.idempotency_manager, - dispatcher=self.dispatcher, - logger=self.logger, - default_key_strategy="content_hash", # Hash execution_id + script for deduplication - default_ttl_seconds=3600, # 1 hour TTL for pod creation events - enable_for_all_handlers=True, # Enable idempotency for all handlers - ) - - # Start the consumer with idempotency - topics from centralized config - await self.idempotent_consumer.start(list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.K8S_WORKER])) - - # Create daemonset for image pre-pulling - asyncio.create_task(self.ensure_image_pre_puller_daemonset()) - self.logger.info("Image pre-puller daemonset task scheduled") - self.logger.info("KubernetesWorker service started successfully") - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop the Kubernetes worker.""" - self.logger.info("Stopping KubernetesWorker service...") - - # Wait for active creations to complete - if self._active_creations: - self.logger.info(f"Waiting for {len(self._active_creations)} active pod creations to complete...") - timeout = 30 - start_time = time.time() - - while self._active_creations and (time.time() - start_time) < timeout: - await asyncio.sleep(1) - - if self._active_creations: - self.logger.warning(f"Timeout waiting for pod creations, {len(self._active_creations)} still active") - - # Stop the consumer (idempotent wrapper only) - if self.idempotent_consumer: - await self.idempotent_consumer.stop() - - # Close idempotency manager - await self.idempotency_manager.close() - - # Note: producer is managed by DI container, not stopped here - - self.logger.info("KubernetesWorker service stopped") + def register_handlers(self, dispatcher: EventDispatcher) -> None: + """Register event handlers with the dispatcher.""" + dispatcher.register_handler(EventType.CREATE_POD_COMMAND, self._handle_create_pod_command_wrapper) + dispatcher.register_handler(EventType.DELETE_POD_COMMAND, self._handle_delete_pod_command_wrapper) def _initialize_kubernetes_client(self) -> None: - """Initialize 
Kubernetes API clients""" + """Initialize Kubernetes API clients.""" try: # Load config if self.config.in_cluster: @@ -199,7 +105,6 @@ def _initialize_kubernetes_client(self) -> None: # Get the default configuration that was set by load_kube_config configuration = k8s_client.Configuration.get_default_copy() - # The certificate data should already be configured by load_kube_config # Log the configuration for debugging self.logger.info(f"Kubernetes API host: {configuration.host}") self.logger.info(f"SSL CA cert configured: {configuration.ssl_ca_cert is not None}") @@ -212,7 +117,9 @@ def _initialize_kubernetes_client(self) -> None: # Test connection with namespace-scoped operation _ = self.v1.list_namespaced_pod(namespace=self.config.namespace, limit=1) - self.logger.info(f"Successfully connected to Kubernetes API, namespace {self.config.namespace} accessible") + self.logger.info( + f"Successfully connected to Kubernetes API, namespace {self.config.namespace} accessible" + ) except Exception as e: self.logger.error(f"Failed to initialize Kubernetes client: {e}") @@ -221,17 +128,21 @@ def _initialize_kubernetes_client(self) -> None: async def _handle_create_pod_command_wrapper(self, event: DomainEvent) -> None: """Wrapper for handling CreatePodCommandEvent with type safety.""" assert isinstance(event, CreatePodCommandEvent) - self.logger.info(f"Processing create_pod_command for execution {event.execution_id} from saga {event.saga_id}") + self.logger.info( + f"Processing create_pod_command for execution {event.execution_id} from saga {event.saga_id}" + ) await self._handle_create_pod_command(event) async def _handle_delete_pod_command_wrapper(self, event: DomainEvent) -> None: """Wrapper for handling DeletePodCommandEvent.""" assert isinstance(event, DeletePodCommandEvent) - self.logger.info(f"Processing delete_pod_command for execution {event.execution_id} from saga {event.saga_id}") + self.logger.info( + f"Processing delete_pod_command for execution {event.execution_id} from saga {event.saga_id}" + ) await self._handle_delete_pod_command(event) async def _handle_create_pod_command(self, command: CreatePodCommandEvent) -> None: - """Handle create pod command from saga orchestrator""" + """Handle create pod command from saga orchestrator.""" execution_id = command.execution_id # Check if already processing @@ -243,7 +154,7 @@ async def _handle_create_pod_command(self, command: CreatePodCommandEvent) -> No asyncio.create_task(self._create_pod_for_execution(command)) async def _handle_delete_pod_command(self, command: DeletePodCommandEvent) -> None: - """Handle delete pod command from saga orchestrator (compensation)""" + """Handle delete pod command from saga orchestrator (compensation).""" execution_id = command.execution_id self.logger.info(f"Deleting pod for execution {execution_id} due to: {command.reason}") @@ -271,19 +182,19 @@ async def _handle_delete_pod_command(self, command: DeletePodCommandEvent) -> No except ApiException as e: if e.status == 404: - self.logger.warning(f"Resources for execution {execution_id} not found (may have already been deleted)") + self.logger.warning( + f"Resources for execution {execution_id} not found (may have already been deleted)" + ) else: self.logger.error(f"Failed to delete resources for execution {execution_id}: {e}") async def _create_pod_for_execution(self, command: CreatePodCommandEvent) -> None: - """Create pod for execution""" + """Create pod for execution.""" async with self._creation_semaphore: execution_id = command.execution_id 
self._active_creations.add(execution_id) self.metrics.update_k8s_active_creations(len(self._active_creations)) - # Queue depth is owned by the coordinator; do not modify here - start_time = time.time() try: @@ -328,7 +239,7 @@ async def _create_pod_for_execution(self, command: CreatePodCommandEvent) -> Non self.metrics.update_k8s_active_creations(len(self._active_creations)) async def _get_entrypoint_script(self) -> str: - """Get entrypoint script content""" + """Get entrypoint script content.""" entrypoint_path = Path("app/scripts/entrypoint.sh") if entrypoint_path.exists(): return await asyncio.to_thread(entrypoint_path.read_text) @@ -351,7 +262,7 @@ async def _get_entrypoint_script(self) -> str: """ async def _create_config_map(self, config_map: k8s_client.V1ConfigMap) -> None: - """Create ConfigMap in Kubernetes""" + """Create ConfigMap in Kubernetes.""" if not self.v1: raise RuntimeError("Kubernetes client not initialized") try: @@ -369,7 +280,7 @@ async def _create_config_map(self, config_map: k8s_client.V1ConfigMap) -> None: raise async def _create_pod(self, pod: k8s_client.V1Pod) -> None: - """Create Pod in Kubernetes""" + """Create Pod in Kubernetes.""" if not self.v1: raise RuntimeError("Kubernetes client not initialized") try: @@ -382,7 +293,7 @@ async def _create_pod(self, pod: k8s_client.V1Pod) -> None: raise async def _publish_execution_started(self, command: CreatePodCommandEvent, pod: k8s_client.V1Pod) -> None: - """Publish execution started event""" + """Publish execution started event.""" event = ExecutionStartedEvent( execution_id=command.execution_id, aggregate_id=command.execution_id, # Set aggregate_id to execution_id @@ -397,7 +308,7 @@ async def _publish_execution_started(self, command: CreatePodCommandEvent, pod: await self.producer.produce(event_to_produce=event) async def _publish_pod_created(self, command: CreatePodCommandEvent, pod: k8s_client.V1Pod) -> None: - """Publish pod created event""" + """Publish pod created event.""" event = PodCreatedEvent( execution_id=command.execution_id, pod_name=pod.metadata.name, @@ -411,7 +322,7 @@ async def _publish_pod_created(self, command: CreatePodCommandEvent, pod: k8s_cl await self.producer.produce(event_to_produce=event) async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, error: str) -> None: - """Publish pod creation failed event""" + """Publish pod creation failed event.""" event = ExecutionFailedEvent( execution_id=command.execution_id, error_type=ExecutionErrorType.SYSTEM_ERROR, @@ -427,20 +338,15 @@ async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, err return await self.producer.produce(event_to_produce=event) - async def get_status(self) -> dict[str, Any]: - """Get worker status""" - return { - "running": self.idempotent_consumer is not None, - "active_creations": len(self._active_creations), - "config": { - "namespace": self.config.namespace, - "max_concurrent_pods": self.config.max_concurrent_pods, - "enable_network_policies": True, - }, - } + async def ensure_daemonset_task(self) -> None: + """Ensure daemonset exists, then complete (not a loop).""" + await self.ensure_image_pre_puller_daemonset() + self.logger.info("Image pre-puller daemonset task completed") + # This task completes immediately after ensuring the daemonset + # The TaskGroup will keep running because the consumer task is still running async def ensure_image_pre_puller_daemonset(self) -> None: - """Ensure the runtime image pre-puller DaemonSet exists""" + """Ensure the runtime image pre-puller 
DaemonSet exists.""" if not self.apps_v1: self.logger.warning("Kubernetes AppsV1Api client not initialized. Skipping DaemonSet creation.") return @@ -454,7 +360,9 @@ async def ensure_image_pre_puller_daemonset(self) -> None: all_images = {config.image for lang in RUNTIME_REGISTRY.values() for config in lang.values()} for i, image_ref in enumerate(sorted(list(all_images))): - sanitized_image_ref = image_ref.split("/")[-1].replace(":", "-").replace(".", "-").replace("_", "-") + sanitized_image_ref = ( + image_ref.split("/")[-1].replace(":", "-").replace(".", "-").replace("_", "-") + ) self.logger.info(f"DAEMONSET: before: {image_ref} -> {sanitized_image_ref}") container_name = f"pull-{i}-{sanitized_image_ref}" init_containers.append( @@ -490,7 +398,10 @@ async def ensure_image_pre_puller_daemonset(self) -> None: ) self.logger.info(f"DaemonSet '{daemonset_name}' exists. Replacing to ensure it is up-to-date.") await asyncio.to_thread( - self.apps_v1.replace_namespaced_daemon_set, name=daemonset_name, namespace=namespace, body=manifest + self.apps_v1.replace_namespaced_daemon_set, + name=daemonset_name, + namespace=namespace, + body=manifest, ) self.logger.info(f"DaemonSet '{daemonset_name}' replaced successfully.") except ApiException as e: @@ -507,3 +418,24 @@ async def ensure_image_pre_puller_daemonset(self) -> None: self.logger.error(f"K8s API error applying DaemonSet '{daemonset_name}': {e.reason}", exc_info=True) except Exception as e: self.logger.error(f"Unexpected error applying image-puller DaemonSet: {e}", exc_info=True) + + async def wait_for_active_creations(self, timeout: float = 30) -> None: + """Wait for active pod creations to complete.""" + if self._active_creations: + self.logger.info(f"Waiting for {len(self._active_creations)} active pod creations...") + start_time = time.time() + while self._active_creations and (time.time() - start_time) < timeout: + await asyncio.sleep(1) + if self._active_creations: + self.logger.warning(f"Timeout, {len(self._active_creations)} pod creations still active") + + async def get_status(self) -> dict[str, Any]: + """Get worker status.""" + return { + "active_creations": len(self._active_creations), + "config": { + "namespace": self.config.namespace, + "max_concurrent_pods": self.config.max_concurrent_pods, + "enable_network_policies": True, + }, + } diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 13a48418..ffb3e10c 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -2,29 +2,19 @@ import logging from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta -from enum import auto from typing import Awaitable, Callable import httpx -from app.core.metrics import EventMetrics, NotificationMetrics +from app.core.metrics import NotificationMetrics from app.core.tracing.utils import add_span_attributes -from app.core.utils import StringEnum from app.db.repositories.notification_repository import NotificationRepository -from app.domain.enums.events import EventType -from app.domain.enums.kafka import GroupId from app.domain.enums.notification import ( NotificationChannel, NotificationSeverity, NotificationStatus, ) from app.domain.enums.user import UserRole -from app.domain.events.typed import ( - DomainEvent, - ExecutionCompletedEvent, - ExecutionFailedEvent, - ExecutionTimeoutEvent, -) from app.domain.notification import ( DomainNotification, DomainNotificationCreate, @@ -36,12 +26,8 @@ 
NotificationThrottledError, NotificationValidationError, ) -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.infrastructure.kafka.mappings import get_topic_for_event from app.schemas_pydantic.sse import RedisNotificationMessage from app.services.event_bus import EventBus -from app.services.kafka_event_service import KafkaEventService from app.services.sse.redis_bus import SSERedisBus from app.settings import Settings @@ -49,21 +35,12 @@ ENTITY_EXECUTION_TAG = "entity:execution" # Type aliases -type EventPayload = dict[str, object] type NotificationContext = dict[str, object] type ChannelHandler = Callable[[DomainNotification, DomainNotificationSubscription], Awaitable[None]] type SystemNotificationStats = dict[str, int] type SlackMessage = dict[str, object] -class ServiceState(StringEnum): - """Service lifecycle states.""" - - RUNNING = auto() - STOPPING = auto() - STOPPED = auto() - - @dataclass class ThrottleCache: """Manages notification throttling with time windows.""" @@ -98,11 +75,6 @@ async def check_throttle( self._entries[key].append(now) return False - async def clear(self) -> None: - """Clear all throttle entries.""" - async with self._lock: - self._entries.clear() - @dataclass(frozen=True) class SystemConfig: @@ -111,39 +83,37 @@ class SystemConfig: class NotificationService: + """Service for creating and managing notifications. + + This service handles: + - Creating notifications (user and system) + - Delivering notifications via channels (in-app, webhook, slack) + - Managing notification subscriptions + - Rate limiting via throttle cache + + Background tasks (pending notification processing, cleanup) are started + via the run() method, which should be called from app lifespan. + """ + def __init__( self, notification_repository: NotificationRepository, - event_service: KafkaEventService, event_bus: EventBus, - schema_registry_manager: SchemaRegistryManager, sse_bus: SSERedisBus, settings: Settings, logger: logging.Logger, notification_metrics: NotificationMetrics, - event_metrics: EventMetrics, ) -> None: self.repository = notification_repository - self.event_service = event_service self.event_bus = event_bus self.metrics = notification_metrics - self._event_metrics = event_metrics self.settings = settings - self.schema_registry_manager = schema_registry_manager self.sse_bus = sse_bus self.logger = logger - # State - self._state = ServiceState.RUNNING + # Throttle cache for rate limiting self._throttle_cache = ThrottleCache() - # Tasks - self._tasks: set[asyncio.Task[None]] = set() - - self._consumer: UnifiedConsumer | None = None - self._dispatcher: EventDispatcher | None = None - self._consumer_task: asyncio.Task[None] | None = None - # Channel handlers mapping self._channel_handlers: dict[NotificationChannel, ChannelHandler] = { NotificationChannel.IN_APP: self._send_in_app, @@ -151,105 +121,25 @@ def __init__( NotificationChannel.SLACK: self._send_slack, } - # Start background processors - self._start_background_tasks() - self.logger.info( "NotificationService initialized", - extra={ - "repository": type(notification_repository).__name__, - "event_service": type(event_service).__name__, - "schema_registry": type(schema_registry_manager).__name__, - }, + extra={"repository": type(notification_repository).__name__}, ) - @property - def state(self) -> ServiceState: - return self._state + async def run(self) -> None: + """Run background tasks. Blocks until cancelled. 
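With the consumer and lifecycle state machine removed below, run() is the only long-running entry point left, and the app lifespan that owns it is expected to cancel it on shutdown. A minimal sketch under that assumption (the lifespan hook and the notification_service variable are placeholders, not part of this patch):

    import asyncio
    import contextlib

    @contextlib.asynccontextmanager
    async def lifespan(app):
        task = asyncio.create_task(notification_service.run())
        try:
            yield
        finally:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task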
- async def shutdown(self) -> None: - """Shutdown notification service.""" - if self._state == ServiceState.STOPPED: - return - - self.logger.info("Shutting down notification service...") - self._state = ServiceState.STOPPING - - # Cancel all tasks - for task in self._tasks: - task.cancel() - - # Wait for cancellation - if self._tasks: - await asyncio.gather(*self._tasks, return_exceptions=True) - - # Stop consumer - if self._consumer: - await self._consumer.stop() - - # Clear cache - await self._throttle_cache.clear() - - self._state = ServiceState.STOPPED - self.logger.info("Notification service stopped") - - def _start_background_tasks(self) -> None: - """Start background processing tasks.""" - tasks = [ - asyncio.create_task(self._process_pending_notifications()), - asyncio.create_task(self._cleanup_old_notifications()), - ] - - for task in tasks: - self._tasks.add(task) - task.add_done_callback(self._tasks.discard) - - async def _subscribe_to_events(self) -> None: - """Subscribe to relevant events for notifications.""" - # Configure consumer for notification-relevant events - consumer_config = ConsumerConfig( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{GroupId.NOTIFICATION_SERVICE}.{self.settings.KAFKA_GROUP_SUFFIX}", - max_poll_records=10, - enable_auto_commit=True, - auto_offset_reset="latest", - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - execution_results_topic = get_topic_for_event(EventType.EXECUTION_COMPLETED) - - # Log topics for debugging - self.logger.info(f"Notification service will subscribe to topics: {execution_results_topic}") - - # Create dispatcher and register handlers for specific event types - self._dispatcher = EventDispatcher(logger=self.logger) - # Use a single handler for execution result events (simpler and less brittle) - self._dispatcher.register_handler(EventType.EXECUTION_COMPLETED, self._handle_execution_event) - self._dispatcher.register_handler(EventType.EXECUTION_FAILED, self._handle_execution_event) - self._dispatcher.register_handler(EventType.EXECUTION_TIMEOUT, self._handle_execution_event) - - # Create consumer with dispatcher - self._consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self._dispatcher, - schema_registry=self.schema_registry_manager, - settings=self.settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - - # Start consumer - await self._consumer.start([execution_results_topic]) - - # Start consumer task - self._consumer_task = asyncio.create_task(self._run_consumer()) - self._tasks.add(self._consumer_task) - self._consumer_task.add_done_callback(self._tasks.discard) - - self.logger.info("Notification service subscribed to execution events") + Runs: + - Pending notification processor (retries failed deliveries) + - Old notification cleanup (daily) + """ + self.logger.info("Starting NotificationService background tasks...") + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(self._process_pending_notifications()) + tg.create_task(self._cleanup_old_notifications()) + except* asyncio.CancelledError: + self.logger.info("NotificationService background tasks cancelled") async def create_notification( self, @@ -544,7 +434,7 @@ def _get_slack_color(self, priority: NotificationSeverity) -> str: async def _process_pending_notifications(self) -> None: 
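# --- Editor sketch, not part of the patch: one way an app lifespan can own the run()
# loop added above, assuming the service instance comes from DI. The helper names are
# illustrative; only NotificationService.run() is taken from this diff.
import asyncio
import contextlib

def start_notification_tasks(service: "NotificationService") -> asyncio.Task[None]:
    # run() blocks until cancelled, so the lifespan owns it as a background task.
    return asyncio.create_task(service.run())

async def stop_notification_tasks(task: asyncio.Task[None]) -> None:
    # Cancellation propagates into the TaskGroup inside run(), which logs and exits.
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task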
"""Process pending notifications in background.""" - while self._state == ServiceState.RUNNING: + while True: try: # Find pending notifications notifications = await self.repository.find_pending_notifications( @@ -553,130 +443,35 @@ async def _process_pending_notifications(self) -> None: # Process each notification for notification in notifications: - if self._state != ServiceState.RUNNING: - break await self._deliver_notification(notification) # Sleep between batches await asyncio.sleep(5) + except asyncio.CancelledError: + raise except Exception as e: self.logger.error(f"Error processing pending notifications: {e}") await asyncio.sleep(10) async def _cleanup_old_notifications(self) -> None: """Cleanup old notifications periodically.""" - while self._state == ServiceState.RUNNING: + while True: try: # Run cleanup once per day await asyncio.sleep(86400) # 24 hours - if self._state != ServiceState.RUNNING: - break - # Delete old notifications deleted_count = await self.repository.cleanup_old_notifications(self.settings.NOTIF_OLD_DAYS) self.logger.info(f"Cleaned up {deleted_count} old notifications") - except Exception as e: - self.logger.error(f"Error cleaning up old notifications: {e}") - - async def _run_consumer(self) -> None: - """Run the event consumer loop.""" - while self._state == ServiceState.RUNNING: - try: - # Consumer handles polling internally - await asyncio.sleep(1) except asyncio.CancelledError: - self.logger.info("Notification consumer task cancelled") - break + raise except Exception as e: - self.logger.error(f"Error in notification consumer loop: {e}") + self.logger.error(f"Error cleaning up old notifications: {e}") await asyncio.sleep(5) - async def _handle_execution_timeout_typed(self, event: ExecutionTimeoutEvent) -> None: - """Handle typed execution timeout event.""" - user_id = event.metadata.user_id - if not user_id: - self.logger.error("No user_id in event metadata") - return - - title = f"Execution Timeout: {event.execution_id}" - body = f"Your execution timed out after {event.timeout_seconds}s." - await self.create_notification( - user_id=user_id, - subject=title, - body=body, - severity=NotificationSeverity.HIGH, - tags=["execution", "timeout", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ), - ) - - async def _handle_execution_completed_typed(self, event: ExecutionCompletedEvent) -> None: - """Handle typed execution completed event.""" - user_id = event.metadata.user_id - if not user_id: - self.logger.error("No user_id in event metadata") - return - - title = f"Execution Completed: {event.execution_id}" - duration = event.resource_usage.execution_time_wall_seconds if event.resource_usage else 0.0 - body = f"Your execution completed successfully. Duration: {duration:.2f}s." 
- await self.create_notification( - user_id=user_id, - subject=title, - body=body, - severity=NotificationSeverity.MEDIUM, - tags=["execution", "completed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ), - ) - - async def _handle_execution_event(self, event: DomainEvent) -> None: - """Unified handler for execution result events.""" - try: - if isinstance(event, ExecutionCompletedEvent): - await self._handle_execution_completed_typed(event) - elif isinstance(event, ExecutionFailedEvent): - await self._handle_execution_failed_typed(event) - elif isinstance(event, ExecutionTimeoutEvent): - await self._handle_execution_timeout_typed(event) - else: - self.logger.warning(f"Unhandled execution event type: {event.event_type}") - except Exception as e: - self.logger.error(f"Error handling execution event: {e}", exc_info=True) - - async def _handle_execution_failed_typed(self, event: ExecutionFailedEvent) -> None: - """Handle typed execution failed event.""" - user_id = event.metadata.user_id - if not user_id: - self.logger.error("No user_id in event metadata") - return - - # Use model_dump to get all event data - event_data = event.model_dump( - exclude={"metadata", "event_type", "event_version", "timestamp", "aggregate_id", "topic"} - ) - - # Truncate stdout/stderr for notification context - event_data["stdout"] = event_data["stdout"][:200] - event_data["stderr"] = event_data["stderr"][:200] - - title = f"Execution Failed: {event.execution_id}" - body = f"Your execution failed: {event.error_message}" - await self.create_notification( - user_id=user_id, - subject=title, - body=body, - severity=NotificationSeverity.HIGH, - tags=["execution", "failed", ENTITY_EXECUTION_TAG, f"exec:{event.execution_id}"], - metadata=event_data, - ) - async def mark_as_read(self, user_id: str, notification_id: str) -> bool: """Mark notification as read.""" success = await self.repository.mark_as_read(notification_id, user_id) diff --git a/backend/app/services/pod_monitor/monitor.py b/backend/app/services/pod_monitor/monitor.py index ae95f6a7..c43d53ad 100644 --- a/backend/app/services/pod_monitor/monitor.py +++ b/backend/app/services/pod_monitor/monitor.py @@ -71,6 +71,10 @@ class PodMonitor: Watches pods with specific labels using the K8s watch API, maps Kubernetes events to application events, and publishes them to Kafka. Reconciles state when watch restarts (every watch_timeout_seconds or on error). + + Usage: + monitor = PodMonitor(...) + await monitor.run() # Blocks until cancelled """ def __init__( @@ -82,6 +86,7 @@ def __init__( event_mapper: PodEventMapper, kubernetes_metrics: KubernetesMetrics, ) -> None: + """Store dependencies. All work happens in run().""" self.logger = logger self.config = config @@ -99,76 +104,59 @@ def __init__( self._reconnect_attempts: int = 0 self._last_resource_version: ResourceVersion | None = None - # Task - self._watch_task: asyncio.Task[None] | None = None - # Metrics self._metrics = kubernetes_metrics - async def __aenter__(self) -> "PodMonitor": - """Start the pod monitor.""" - self.logger.info("Starting PodMonitor service...") + async def run(self) -> None: + """Run the monitor. Blocks until cancelled. + + Verifies K8s connectivity, runs watch loop, cleans up on exit. 
+ """ + self.logger.info("PodMonitor starting...") # Verify K8s connectivity await asyncio.to_thread(self._v1.get_api_resources) self.logger.info("Successfully connected to Kubernetes API") - # Start watch task - self._watch_task = asyncio.create_task(self._watch_loop()) - - self.logger.info("PodMonitor service started successfully") - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop the pod monitor.""" - self.logger.info("Stopping PodMonitor service...") - - if self._watch_task: - self._watch_task.cancel() - try: - await self._watch_task - except asyncio.CancelledError: - pass - - if self._watch: - self._watch.stop() - - self._tracked_pods.clear() - self._event_mapper.clear_cache() - self.logger.info("PodMonitor service stopped") + try: + await self._watch_loop() + except asyncio.CancelledError: + self.logger.info("PodMonitor cancelled") + finally: + if self._watch: + self._watch.stop() + self._tracked_pods.clear() + self._event_mapper.clear_cache() + self.logger.info("PodMonitor stopped") async def _watch_loop(self) -> None: """Main watch loop - reconciles on each restart.""" - try: - while True: - try: - # Reconcile before starting watch (catches missed events) - if self.config.enable_state_reconciliation: - await self._reconcile() - - self._reconnect_attempts = 0 - await self._run_watch() - - except ApiException as e: - if e.status == 410: # Resource version expired - self.logger.warning("Resource version expired, resetting watch") - self._last_resource_version = None - self._metrics.record_pod_monitor_watch_error(ErrorType.RESOURCE_VERSION_EXPIRED) - else: - self.logger.error(f"API error in watch: {e}") - self._metrics.record_pod_monitor_watch_error(ErrorType.API_ERROR) - await self._backoff() - - except asyncio.CancelledError: - raise - - except Exception as e: - self.logger.error(f"Unexpected error in watch: {e}", exc_info=True) - self._metrics.record_pod_monitor_watch_error(ErrorType.UNEXPECTED) - await self._backoff() + while True: + try: + # Reconcile before starting watch (catches missed events) + if self.config.enable_state_reconciliation: + await self._reconcile() + + self._reconnect_attempts = 0 + await self._run_watch() + + except ApiException as e: + if e.status == 410: # Resource version expired + self.logger.warning("Resource version expired, resetting watch") + self._last_resource_version = None + self._metrics.record_pod_monitor_watch_error(ErrorType.RESOURCE_VERSION_EXPIRED) + else: + self.logger.error(f"API error in watch: {e}") + self._metrics.record_pod_monitor_watch_error(ErrorType.API_ERROR) + await self._backoff() - except asyncio.CancelledError: - self.logger.info("Watch loop cancelled") + except asyncio.CancelledError: + raise + + except Exception as e: + self.logger.error(f"Unexpected error in watch: {e}", exc_info=True) + self._metrics.record_pod_monitor_watch_error(ErrorType.UNEXPECTED) + await self._backoff() async def _run_watch(self) -> None: """Run a single watch session.""" @@ -360,7 +348,6 @@ async def get_status(self) -> StatusDict: "tracked_pods": len(self._tracked_pods), "reconnect_attempts": self._reconnect_attempts, "last_resource_version": self._last_resource_version, - "watch_task_active": self._watch_task is not None and not self._watch_task.done(), "config": { "namespace": self.config.namespace, "label_selector": self.config.label_selector, diff --git a/backend/app/services/result_processor/__init__.py b/backend/app/services/result_processor/__init__.py index e3907fa1..de62dc32 100644 --- 
a/backend/app/services/result_processor/__init__.py +++ b/backend/app/services/result_processor/__init__.py @@ -1,8 +1,7 @@ -from app.services.result_processor.processor import ResultProcessor, ResultProcessorConfig +from app.services.result_processor.processor_logic import ProcessorLogic from app.services.result_processor.resource_cleaner import ResourceCleaner __all__ = [ - "ResultProcessor", - "ResultProcessorConfig", + "ProcessorLogic", "ResourceCleaner", ] diff --git a/backend/app/services/result_processor/processor.py b/backend/app/services/result_processor/processor_logic.py similarity index 58% rename from backend/app/services/result_processor/processor.py rename to backend/app/services/result_processor/processor_logic.py index dddb9eb7..3bae92cb 100644 --- a/backend/app/services/result_processor/processor.py +++ b/backend/app/services/result_processor/processor_logic.py @@ -1,15 +1,10 @@ import logging -from enum import auto -from typing import Any -from pydantic import BaseModel, ConfigDict, Field - -from app.core.metrics import EventMetrics, ExecutionMetrics -from app.core.utils import StringEnum +from app.core.metrics import ExecutionMetrics from app.db.repositories.execution_repository import ExecutionRepository from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId, KafkaTopic +from app.domain.enums.kafka import GroupId from app.domain.enums.storage import ExecutionErrorType, StorageType from app.domain.events.typed import ( DomainEvent, @@ -21,138 +16,43 @@ ResultStoredEvent, ) from app.domain.execution import ExecutionNotFoundError, ExecutionResultDomain -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper +from app.events.core import EventDispatcher, UnifiedProducer from app.settings import Settings -class ProcessingState(StringEnum): - """Processing state enumeration.""" - - IDLE = auto() - PROCESSING = auto() - STOPPED = auto() - - -class ResultProcessorConfig(BaseModel): - """Configuration for result processor.""" +class ProcessorLogic: + """ + Business logic for result processing. - model_config = ConfigDict(frozen=True) + Handles: + - Processing execution completion events + - Storing results in database + - Publishing ResultStored/ResultFailed events + - Recording metrics - consumer_group: GroupId = Field(default=GroupId.RESULT_PROCESSOR) - topics: list[KafkaTopic] = Field( - default_factory=lambda: list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.RESULT_PROCESSOR]) - ) - result_topic: KafkaTopic = Field(default=KafkaTopic.EXECUTION_RESULTS) - batch_size: int = Field(default=10) - processing_timeout: int = Field(default=300) - - -class ResultProcessor: - """Service for processing execution completion events and storing results.""" + This class is stateful and must be instantiated once per processor instance. 
+ """ def __init__( - self, - execution_repo: ExecutionRepository, - producer: UnifiedProducer, - schema_registry: SchemaRegistryManager, - settings: Settings, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - execution_metrics: ExecutionMetrics, - event_metrics: EventMetrics, + self, + execution_repo: ExecutionRepository, + producer: UnifiedProducer, + settings: Settings, + logger: logging.Logger, + execution_metrics: ExecutionMetrics, ) -> None: - """Initialize the result processor.""" - self.config = ResultProcessorConfig() self._execution_repo = execution_repo self._producer = producer - self._schema_registry = schema_registry self._settings = settings self._metrics = execution_metrics - self._event_metrics = event_metrics - self._idempotency_manager: IdempotencyManager = idempotency_manager - self._state = ProcessingState.IDLE - self._consumer: IdempotentConsumerWrapper | None = None - self._dispatcher: EventDispatcher | None = None self.logger = logger - async def __aenter__(self) -> "ResultProcessor": - """Start the result processor.""" - self.logger.info("Starting ResultProcessor...") - - # Initialize idempotency manager (safe to call multiple times) - await self._idempotency_manager.initialize() - self.logger.info("Idempotency manager initialized for ResultProcessor") - - self._dispatcher = self._create_dispatcher() - self._consumer = await self._create_consumer() - self._state = ProcessingState.PROCESSING - self.logger.info("ResultProcessor started successfully with idempotency protection") - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop the result processor.""" - self.logger.info("Stopping ResultProcessor...") - self._state = ProcessingState.STOPPED - - if self._consumer: - await self._consumer.stop() - - await self._idempotency_manager.close() - # Note: producer is managed by DI container, not stopped here - self.logger.info("ResultProcessor stopped") - - def _create_dispatcher(self) -> EventDispatcher: - """Create and configure event dispatcher with handlers.""" - dispatcher = EventDispatcher(logger=self.logger) - - # Register handlers for specific event types + def register_handlers(self, dispatcher: EventDispatcher) -> None: + """Register event handlers with the dispatcher.""" dispatcher.register_handler(EventType.EXECUTION_COMPLETED, self._handle_completed_wrapper) dispatcher.register_handler(EventType.EXECUTION_FAILED, self._handle_failed_wrapper) dispatcher.register_handler(EventType.EXECUTION_TIMEOUT, self._handle_timeout_wrapper) - return dispatcher - - async def _create_consumer(self) -> IdempotentConsumerWrapper: - """Create and configure idempotent Kafka consumer.""" - consumer_config = ConsumerConfig( - bootstrap_servers=self._settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{self.config.consumer_group}.{self._settings.KAFKA_GROUP_SUFFIX}", - max_poll_records=1, - enable_auto_commit=True, - auto_offset_reset="earliest", - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - # Create consumer with schema registry and dispatcher - if not self._dispatcher: - raise RuntimeError("Event dispatcher not initialized") - - base_consumer = UnifiedConsumer( - consumer_config, - event_dispatcher=self._dispatcher, - schema_registry=self._schema_registry, - settings=self._settings, - 
logger=self.logger, - event_metrics=self._event_metrics, - ) - wrapper = IdempotentConsumerWrapper( - consumer=base_consumer, - idempotency_manager=self._idempotency_manager, - dispatcher=self._dispatcher, - logger=self.logger, - default_key_strategy="content_hash", - default_ttl_seconds=7200, - enable_for_all_handlers=True, - ) - await wrapper.start(self.config.topics) - return wrapper - # Wrappers accepting DomainEvent to satisfy dispatcher typing async def _handle_completed_wrapper(self, event: DomainEvent) -> None: @@ -169,7 +69,6 @@ async def _handle_timeout_wrapper(self, event: DomainEvent) -> None: async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: """Handle execution completed event.""" - exec_obj = await self._execution_repo.get_execution(event.execution_id) if exec_obj is None: raise ExecutionNotFoundError(event.execution_id) @@ -213,7 +112,6 @@ async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: async def _handle_failed(self, event: ExecutionFailedEvent) -> None: """Handle execution failed event.""" - # Fetch execution to get language and version for metrics exec_obj = await self._execution_repo.get_execution(event.execution_id) if exec_obj is None: @@ -242,7 +140,6 @@ async def _handle_failed(self, event: ExecutionFailedEvent) -> None: async def _handle_timeout(self, event: ExecutionTimeoutEvent) -> None: """Handle execution timeout event.""" - exec_obj = await self._execution_repo.get_execution(event.execution_id) if exec_obj is None: raise ExecutionNotFoundError(event.execution_id) @@ -273,7 +170,6 @@ async def _handle_timeout(self, event: ExecutionTimeoutEvent) -> None: async def _publish_result_stored(self, result: ExecutionResultDomain) -> None: """Publish result stored event.""" - size_bytes = len(result.stdout) + len(result.stderr) event = ResultStoredEvent( execution_id=result.execution_id, @@ -290,7 +186,6 @@ async def _publish_result_stored(self, result: ExecutionResultDomain) -> None: async def _publish_result_failed(self, execution_id: str, error_message: str) -> None: """Publish result processing failed event.""" - event = ResultFailedEvent( execution_id=execution_id, error=error_message, @@ -301,10 +196,3 @@ async def _publish_result_failed(self, execution_id: str, error_message: str) -> ) await self._producer.produce(event_to_produce=event, key=execution_id) - - async def get_status(self) -> dict[str, Any]: - """Get processor status.""" - return { - "state": self._state, - "consumer_active": self._consumer is not None, - } diff --git a/backend/app/services/saga/__init__.py b/backend/app/services/saga/__init__.py index ec47a201..a06e11a8 100644 --- a/backend/app/services/saga/__init__.py +++ b/backend/app/services/saga/__init__.py @@ -12,11 +12,11 @@ RemoveFromQueueCompensation, ValidateExecutionStep, ) -from app.services.saga.saga_orchestrator import SagaOrchestrator +from app.services.saga.saga_logic import SagaLogic from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep __all__ = [ - "SagaOrchestrator", + "SagaLogic", "SagaConfig", "SagaState", "SagaInstance", diff --git a/backend/app/services/saga/saga_orchestrator.py b/backend/app/services/saga/saga_logic.py similarity index 76% rename from backend/app/services/saga/saga_orchestrator.py rename to backend/app/services/saga/saga_logic.py index 4c293165..caed3457 100644 --- a/backend/app/services/saga/saga_orchestrator.py +++ b/backend/app/services/saga/saga_logic.py @@ -10,34 +10,37 @@ from app.core.tracing.utils import get_tracer from 
app.db.repositories.resource_allocation_repository import ResourceAllocationRepository from app.db.repositories.saga_repository import SagaRepository +from app.domain.enums.events import EventType +from app.domain.enums.kafka import KafkaTopic from app.domain.enums.saga import SagaState from app.domain.events.typed import DomainEvent, EventMetadata, SagaCancelledEvent from app.domain.saga.models import Saga, SagaConfig -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager +from app.events.core import EventDispatcher, UnifiedProducer from app.infrastructure.kafka.mappings import get_topic_for_event -from app.services.idempotency import IdempotentConsumerWrapper -from app.services.idempotency.idempotency_manager import IdempotencyManager -from app.settings import Settings from .base_saga import BaseSaga from .execution_saga import ExecutionSaga from .saga_step import SagaContext -class SagaOrchestrator: - """Orchestrates saga execution and compensation""" +class SagaLogic: + """ + Business logic for saga orchestration. + + Handles: + - Saga registration and management + - Event handling and saga triggering + - Saga execution and compensation + - Timeout checking + + This class is stateful and must be instantiated once per orchestrator instance. + """ def __init__( self, config: SagaConfig, saga_repository: SagaRepository, producer: UnifiedProducer, - schema_registry_manager: SchemaRegistryManager, - settings: Settings, - event_store: EventStore, - idempotency_manager: IdempotencyManager, resource_allocation_repository: ResourceAllocationRepository, logger: logging.Logger, event_metrics: EventMetrics, @@ -45,118 +48,49 @@ def __init__( self.config = config self._sagas: dict[str, type[BaseSaga]] = {} self._running_instances: dict[str, Saga] = {} - self._consumer: IdempotentConsumerWrapper | None = None - self._idempotency_manager: IdempotencyManager = idempotency_manager self._producer = producer - self._schema_registry_manager = schema_registry_manager - self._settings = settings - self._event_store = event_store self._repo: SagaRepository = saga_repository self._alloc_repo: ResourceAllocationRepository = resource_allocation_repository - self._tasks: list[asyncio.Task[None]] = [] self.logger = logger self._event_metrics = event_metrics def register_saga(self, saga_class: type[BaseSaga]) -> None: + """Register a saga class.""" self._sagas[saga_class.get_name()] = saga_class self.logger.info(f"Registered saga: {saga_class.get_name()}") - def _register_default_sagas(self) -> None: + def register_default_sagas(self) -> None: + """Register the default sagas.""" self.register_saga(ExecutionSaga) self.logger.info("Registered default sagas") - async def __aenter__(self) -> "SagaOrchestrator": - """Start the saga orchestrator.""" - self.logger.info(f"Starting saga orchestrator: {self.config.name}") - - self._register_default_sagas() - - await self._start_consumer() - - timeout_task = asyncio.create_task(self._check_timeouts()) - self._tasks.append(timeout_task) - - self.logger.info("Saga orchestrator started") - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop the saga orchestrator.""" - self.logger.info("Stopping saga orchestrator...") - - if self._consumer: - await self._consumer.stop() - - await self._idempotency_manager.close() - - for task in self._tasks: - if not task.done(): - 
task.cancel() - - if self._tasks: - await asyncio.gather(*self._tasks, return_exceptions=True) - - self.logger.info("Saga orchestrator stopped") - - async def _start_consumer(self) -> None: - self.logger.info(f"Registered sagas: {list(self._sagas.keys())}") - topics = set() - event_types_to_register = set() - + def get_trigger_topics(self) -> set[KafkaTopic]: + """Get all topics that trigger sagas.""" + topics: set[KafkaTopic] = set() for saga_class in self._sagas.values(): trigger_event_types = saga_class.get_trigger_events() - self.logger.info(f"Saga {saga_class.get_name()} triggers on event types: {trigger_event_types}") - - # Convert event types to topics for subscription for event_type in trigger_event_types: topic = get_topic_for_event(event_type) topics.add(topic) - event_types_to_register.add(event_type) - self.logger.debug(f"Event type {event_type} maps to topic {topic}") - - if not topics: - self.logger.warning("No trigger events found in registered sagas") - return - - consumer_config = ConsumerConfig( - bootstrap_servers=self._settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"saga-{self.config.name}.{self._settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - session_timeout_ms=self._settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self._settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self._settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self._settings.KAFKA_REQUEST_TIMEOUT_MS, - ) + return topics - dispatcher = EventDispatcher(logger=self.logger) - for event_type in event_types_to_register: - dispatcher.register_handler(event_type, self._handle_event) + def get_trigger_event_types(self) -> set[EventType]: + """Get all event types that trigger sagas.""" + event_types: set[EventType] = set() + for saga_class in self._sagas.values(): + trigger_event_types = saga_class.get_trigger_events() + event_types.update(trigger_event_types) + return event_types + + def register_handlers(self, dispatcher: EventDispatcher) -> None: + """Register event handlers with the dispatcher.""" + event_types = self.get_trigger_event_types() + for event_type in event_types: + dispatcher.register_handler(event_type, self.handle_event) self.logger.info(f"Registered handler for event type: {event_type}") - base_consumer = UnifiedConsumer( - config=consumer_config, - event_dispatcher=dispatcher, - schema_registry=self._schema_registry_manager, - settings=self._settings, - logger=self.logger, - event_metrics=self._event_metrics, - ) - self._consumer = IdempotentConsumerWrapper( - consumer=base_consumer, - idempotency_manager=self._idempotency_manager, - dispatcher=dispatcher, - logger=self.logger, - default_key_strategy="event_based", - default_ttl_seconds=7200, - enable_for_all_handlers=False, - ) - - assert self._consumer is not None - await self._consumer.start(list(topics)) - - self.logger.info(f"Saga consumer started for topics: {topics}") - - async def _handle_event(self, event: DomainEvent) -> None: - """Handle incoming event""" + async def handle_event(self, event: DomainEvent) -> None: + """Handle incoming event.""" self.logger.info(f"Saga orchestrator handling event: type={event.event_type}, id={event.event_id}") try: saga_triggered = False @@ -186,7 +120,7 @@ def _should_trigger_saga(self, saga_class: type[BaseSaga], event: DomainEvent) - return should_trigger async def _start_saga(self, saga_name: str, trigger_event: DomainEvent) -> str | None: - """Start a new saga instance""" + """Start a new saga instance.""" self.logger.info(f"Starting saga {saga_name} for 
event {trigger_event.event_type}") saga_class = self._sagas.get(saga_name) if not saga_class: @@ -241,7 +175,7 @@ async def _execute_saga( context: SagaContext, trigger_event: DomainEvent, ) -> None: - """Execute saga steps""" + """Execute saga steps.""" tracer = get_tracer() try: # Get saga steps @@ -302,7 +236,7 @@ async def _execute_saga( await self._fail_saga(instance, str(e)) async def _compensate_saga(self, instance: Saga, context: SagaContext) -> None: - """Execute compensation steps""" + """Execute compensation steps.""" self.logger.info(f"Starting compensation for saga {instance.saga_id}") # Only update state if not already cancelled @@ -313,14 +247,18 @@ async def _compensate_saga(self, instance: Saga, context: SagaContext) -> None: # Execute compensations in reverse order for compensation in reversed(context.compensations): try: - self.logger.info(f"Executing compensation: {compensation.name} for saga {instance.saga_id}") + self.logger.info( + f"Executing compensation: {compensation.name} for saga {instance.saga_id}" + ) success = await compensation.compensate(context) if success: instance.compensated_steps.append(compensation.name) else: - self.logger.error(f"Compensation {compensation.name} failed for saga {instance.saga_id}") + self.logger.error( + f"Compensation {compensation.name} failed for saga {instance.saga_id}" + ) except Exception as e: self.logger.error(f"Error in compensation {compensation.name}: {e}", exc_info=True) @@ -336,7 +274,7 @@ async def _compensate_saga(self, instance: Saga, context: SagaContext) -> None: await self._fail_saga(instance, "Saga compensated due to failure") async def _complete_saga(self, instance: Saga) -> None: - """Mark saga as completed""" + """Mark saga as completed.""" instance.state = SagaState.COMPLETED instance.completed_at = datetime.now(UTC) await self._save_saga(instance) @@ -347,7 +285,7 @@ async def _complete_saga(self, instance: Saga) -> None: self.logger.info(f"Saga {instance.saga_id} completed successfully") async def _fail_saga(self, instance: Saga, error_message: str) -> None: - """Mark saga as failed""" + """Mark saga as failed.""" instance.state = SagaState.FAILED instance.error_message = error_message instance.completed_at = datetime.now(UTC) @@ -358,8 +296,8 @@ async def _fail_saga(self, instance: Saga, error_message: str) -> None: self.logger.error(f"Saga {instance.saga_id} failed: {error_message}") - async def _check_timeouts(self) -> None: - """Check for saga timeouts""" + async def check_timeouts_loop(self) -> None: + """Check for saga timeouts (runs until cancelled).""" try: while True: # Check every 30 seconds @@ -383,12 +321,12 @@ async def _check_timeouts(self) -> None: self.logger.info("Timeout checker cancelled") async def _save_saga(self, instance: Saga) -> None: - """Persist saga through repository""" + """Persist saga through repository.""" instance.updated_at = datetime.now(UTC) await self._repo.upsert_saga(instance) async def get_saga_status(self, saga_id: str) -> Saga | None: - """Get saga instance status""" + """Get saga instance status.""" # Check memory first if saga_id in self._running_instances: return self._running_instances[saga_id] @@ -396,19 +334,12 @@ async def get_saga_status(self, saga_id: str) -> Saga | None: return await self._repo.get_saga(saga_id) async def get_execution_sagas(self, execution_id: str) -> list[Saga]: - """Get all sagas for an execution, sorted by created_at descending (newest first)""" + """Get all sagas for an execution, sorted by created_at descending (newest first).""" 
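# --- Editor sketch, not part of the patch: how a host process might wire SagaLogic now
# that the Kafka consumer lives outside the orchestrator. `build_consumer` is a
# hypothetical factory; register_default_sagas(), register_handlers(),
# get_trigger_topics() and check_timeouts_loop() come from this diff, and consumer.run()
# mirrors the consumer API used in the test changes further below.
import asyncio

async def run_saga_host(saga_logic: "SagaLogic", dispatcher: "EventDispatcher", build_consumer) -> None:
    saga_logic.register_default_sagas()
    saga_logic.register_handlers(dispatcher)          # one handler per trigger event type
    consumer = build_consumer(dispatcher=dispatcher, topics=list(saga_logic.get_trigger_topics()))
    async with asyncio.TaskGroup() as tg:             # both tasks are cancelled together on shutdown
        tg.create_task(consumer.run())
        tg.create_task(saga_logic.check_timeouts_loop())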
result = await self._repo.get_sagas_by_execution(execution_id) return result.sagas async def cancel_saga(self, saga_id: str) -> bool: - """Cancel a running saga and trigger compensation. - - Args: - saga_id: The ID of the saga to cancel - - Returns: - True if cancelled successfully, False otherwise - """ + """Cancel a running saga and trigger compensation.""" try: # Get saga instance saga_instance = await self.get_saga_status(saga_id) @@ -499,11 +430,7 @@ async def cancel_saga(self, saga_id: str) -> bool: return False async def _publish_saga_cancelled_event(self, saga_instance: Saga) -> None: - """Publish saga cancelled event. - - Args: - saga_instance: The cancelled saga instance - """ + """Publish saga cancelled event.""" try: cancelled_by = saga_instance.context_data.get("user_id") if saga_instance.context_data else None metadata = EventMetadata( diff --git a/backend/app/services/saga/saga_service.py b/backend/app/services/saga/saga_service.py index 5ed6e4e3..302e2891 100644 --- a/backend/app/services/saga/saga_service.py +++ b/backend/app/services/saga/saga_service.py @@ -9,7 +9,7 @@ ) from app.domain.saga.models import Saga, SagaFilter, SagaListResult from app.schemas_pydantic.user import User -from app.services.saga import SagaOrchestrator +from app.services.saga.saga_logic import SagaLogic class SagaService: @@ -19,12 +19,12 @@ def __init__( self, saga_repo: SagaRepository, execution_repo: ExecutionRepository, - orchestrator: SagaOrchestrator, + saga_logic: SagaLogic, logger: logging.Logger, ): self.saga_repo = saga_repo self.execution_repo = execution_repo - self.orchestrator = orchestrator + self._saga_logic = saga_logic self.logger = logger self.logger.info( @@ -32,7 +32,7 @@ def __init__( extra={ "saga_repo": type(saga_repo).__name__, "execution_repo": type(execution_repo).__name__, - "orchestrator": type(orchestrator).__name__, + "saga_logic": type(saga_logic).__name__, }, ) @@ -137,8 +137,8 @@ async def cancel_saga(self, saga_id: str, user: User) -> bool: if saga.state not in [SagaState.RUNNING, SagaState.CREATED]: raise SagaInvalidStateError(saga_id, str(saga.state), "cancel") - # Use orchestrator to cancel - success = await self.orchestrator.cancel_saga(saga_id) + # Use saga logic to cancel + success = await self._saga_logic.cancel_saga(saga_id) if success: self.logger.info( "User cancelled saga", @@ -163,8 +163,8 @@ async def get_saga_status_from_orchestrator(self, saga_id: str, user: User) -> S """Get saga status from orchestrator with fallback to database.""" self.logger.debug("Getting live saga status", extra={"saga_id": saga_id}) - # Try orchestrator first for live status - saga = await self.orchestrator.get_saga_status(saga_id) + # Try saga logic first for live status + saga = await self._saga_logic.get_saga_status(saga_id) if saga: # Check access if not await self.check_execution_access(saga.execution_id, user): diff --git a/backend/app/services/sse/event_router.py b/backend/app/services/sse/event_router.py new file mode 100644 index 00000000..c2c6ef81 --- /dev/null +++ b/backend/app/services/sse/event_router.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import logging + +from app.domain.enums.events import EventType +from app.domain.events.typed import DomainEvent +from app.events.core import EventDispatcher +from app.services.sse.redis_bus import SSERedisBus + +# Events that should be routed to SSE clients +SSE_RELEVANT_EVENTS: frozenset[EventType] = frozenset([ + EventType.EXECUTION_REQUESTED, + EventType.EXECUTION_QUEUED, + 
EventType.EXECUTION_STARTED, + EventType.EXECUTION_RUNNING, + EventType.EXECUTION_COMPLETED, + EventType.EXECUTION_FAILED, + EventType.EXECUTION_TIMEOUT, + EventType.EXECUTION_CANCELLED, + EventType.RESULT_STORED, + EventType.POD_CREATED, + EventType.POD_SCHEDULED, + EventType.POD_RUNNING, + EventType.POD_SUCCEEDED, + EventType.POD_FAILED, + EventType.POD_TERMINATED, + EventType.POD_DELETED, +]) + + +class SSEEventRouter: + """Routes domain events to Redis channels for SSE delivery. + + Stateless service that extracts execution_id from events and publishes + them to Redis via SSERedisBus. Each execution_id has its own channel. + """ + + def __init__(self, sse_bus: SSERedisBus, logger: logging.Logger) -> None: + self._sse_bus = sse_bus + self._logger = logger + + async def route_event(self, event: DomainEvent) -> None: + """Route an event to Redis for SSE delivery.""" + data = event.model_dump() + execution_id = data.get("execution_id") + + if not execution_id: + self._logger.debug(f"Event {event.event_type} has no execution_id") + return + + try: + await self._sse_bus.publish_event(execution_id, event) + self._logger.info(f"Published {event.event_type} to Redis for {execution_id}") + except Exception as e: + self._logger.error( + f"Failed to publish {event.event_type} to Redis for {execution_id}: {e}", + exc_info=True, + ) + + def register_handlers(self, dispatcher: EventDispatcher) -> None: + """Register routing handlers for all relevant event types.""" + for event_type in SSE_RELEVANT_EVENTS: + dispatcher.register_handler(event_type, self.route_event) diff --git a/backend/app/services/sse/kafka_redis_bridge.py b/backend/app/services/sse/kafka_redis_bridge.py deleted file mode 100644 index 2a6b8c64..00000000 --- a/backend/app/services/sse/kafka_redis_bridge.py +++ /dev/null @@ -1,134 +0,0 @@ -from __future__ import annotations - -import asyncio -import logging - -from app.core.metrics import EventMetrics -from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId -from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.sse.redis_bus import SSERedisBus -from app.settings import Settings - - -class SSEKafkaRedisBridge: - """ - Bridges Kafka events to Redis channels for SSE delivery. 
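# --- Editor sketch, not part of the patch: because the new SSEEventRouter is stateless,
# a consumer host only has to register it on a dispatcher. The helper below is an
# illustrative wiring function; SSEEventRouter, EventDispatcher and SSERedisBus are the
# classes used in the new module above.
import logging

def build_sse_dispatcher(sse_bus: "SSERedisBus", logger: logging.Logger) -> "EventDispatcher":
    dispatcher = EventDispatcher(logger=logger)
    # Registers route_event for every entry in SSE_RELEVANT_EVENTS.
    SSEEventRouter(sse_bus=sse_bus, logger=logger).register_handlers(dispatcher)
    return dispatcher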
- - - Consumes relevant Kafka topics using a small consumer pool - - Deserializes events and publishes them to Redis via SSERedisBus - - Keeps no in-process buffers; delivery to clients is via Redis only - """ - - def __init__( - self, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_metrics: EventMetrics, - sse_bus: SSERedisBus, - logger: logging.Logger, - ) -> None: - self.schema_registry = schema_registry - self.settings = settings - self.event_metrics = event_metrics - self.sse_bus = sse_bus - self.logger = logger - - self.num_consumers = settings.SSE_CONSUMER_POOL_SIZE - self.consumers: list[UnifiedConsumer] = [] - - async def __aenter__(self) -> "SSEKafkaRedisBridge": - """Start the SSE Kafka→Redis bridge.""" - self.logger.info(f"Starting SSE Kafka→Redis bridge with {self.num_consumers} consumers") - - # Phase 1: Build all consumers and track them immediately (no I/O) - self.consumers = [self._build_consumer(i) for i in range(self.num_consumers)] - - # Phase 2: Start all in parallel - already tracked in self.consumers for cleanup - topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.WEBSOCKET_GATEWAY]) - await asyncio.gather(*[c.start(topics) for c in self.consumers]) - - self.logger.info("SSE Kafka→Redis bridge started successfully") - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop the SSE Kafka→Redis bridge.""" - self.logger.info("Stopping SSE Kafka→Redis bridge") - await asyncio.gather(*[c.stop() for c in self.consumers], return_exceptions=True) - self.consumers.clear() - self.logger.info("SSE Kafka→Redis bridge stopped") - - def _build_consumer(self, consumer_index: int) -> UnifiedConsumer: - """Build a consumer instance without starting it.""" - suffix = self.settings.KAFKA_GROUP_SUFFIX - config = ConsumerConfig( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"sse-bridge-pool.{suffix}", - client_id=f"sse-bridge-{consumer_index}.{suffix}", - enable_auto_commit=True, - auto_offset_reset="latest", - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - dispatcher = EventDispatcher(logger=self.logger) - self._register_routing_handlers(dispatcher) - - return UnifiedConsumer( - config=config, - event_dispatcher=dispatcher, - schema_registry=self.schema_registry, - settings=self.settings, - logger=self.logger, - event_metrics=self.event_metrics, - ) - - def _register_routing_handlers(self, dispatcher: EventDispatcher) -> None: - """Publish relevant events to Redis channels keyed by execution_id.""" - relevant_events = [ - EventType.EXECUTION_REQUESTED, - EventType.EXECUTION_QUEUED, - EventType.EXECUTION_STARTED, - EventType.EXECUTION_RUNNING, - EventType.EXECUTION_COMPLETED, - EventType.EXECUTION_FAILED, - EventType.EXECUTION_TIMEOUT, - EventType.EXECUTION_CANCELLED, - EventType.RESULT_STORED, - EventType.POD_CREATED, - EventType.POD_SCHEDULED, - EventType.POD_RUNNING, - EventType.POD_SUCCEEDED, - EventType.POD_FAILED, - EventType.POD_TERMINATED, - EventType.POD_DELETED, - ] - - async def route_event(event: DomainEvent) -> None: - data = event.model_dump() - execution_id = data.get("execution_id") - if not execution_id: - self.logger.debug(f"Event {event.event_type} has no execution_id") - return - try: - await self.sse_bus.publish_event(execution_id, event) - 
self.logger.info(f"Published {event.event_type} to Redis for {execution_id}") - except Exception as e: - self.logger.error( - f"Failed to publish {event.event_type} to Redis for {execution_id}: {e}", - exc_info=True, - ) - - for et in relevant_events: - dispatcher.register_handler(et, route_event) - - def get_stats(self) -> dict[str, int]: - return { - "num_consumers": len(self.consumers), - "active_executions": 0, - "total_buffers": 0, - } diff --git a/backend/app/services/sse/sse_connection_registry.py b/backend/app/services/sse/sse_connection_registry.py new file mode 100644 index 00000000..575d13dc --- /dev/null +++ b/backend/app/services/sse/sse_connection_registry.py @@ -0,0 +1,57 @@ +import asyncio +import logging +from typing import Dict, Set + +from app.core.metrics import ConnectionMetrics + + +class SSEConnectionRegistry: + """ + Tracks active SSE connections. + + Simple registry for connection tracking and metrics. + Shutdown is handled via task cancellation, not explicit shutdown orchestration. + """ + + def __init__( + self, + logger: logging.Logger, + connection_metrics: ConnectionMetrics, + ): + self.logger = logger + self.metrics = connection_metrics + + # Track active connections by execution + self._active_connections: Dict[str, Set[str]] = {} # execution_id -> connection_ids + self._lock = asyncio.Lock() + + self.logger.info("SSEConnectionRegistry initialized") + + async def register_connection(self, execution_id: str, connection_id: str) -> None: + """Register a new SSE connection.""" + async with self._lock: + if execution_id not in self._active_connections: + self._active_connections[execution_id] = set() + + self._active_connections[execution_id].add(connection_id) + self.logger.debug("Registered SSE connection", extra={"connection_id": connection_id}) + self.metrics.increment_sse_connections("executions") + + async def unregister_connection(self, execution_id: str, connection_id: str) -> None: + """Unregister an SSE connection.""" + async with self._lock: + if execution_id in self._active_connections: + self._active_connections[execution_id].discard(connection_id) + if not self._active_connections[execution_id]: + del self._active_connections[execution_id] + + self.logger.debug("Unregistered SSE connection", extra={"connection_id": connection_id}) + self.metrics.decrement_sse_connections("executions") + + def get_connection_count(self) -> int: + """Get total number of active connections.""" + return sum(len(conns) for conns in self._active_connections.values()) + + def get_execution_count(self) -> int: + """Get number of executions with active connections.""" + return len(self._active_connections) diff --git a/backend/app/services/sse/sse_service.py b/backend/app/services/sse/sse_service.py index 3af70fb0..cf1cfcdf 100644 --- a/backend/app/services/sse/sse_service.py +++ b/backend/app/services/sse/sse_service.py @@ -8,7 +8,7 @@ from app.db.repositories.sse_repository import SSERepository from app.domain.enums.events import EventType from app.domain.enums.sse import SSEControlEvent, SSENotificationEvent -from app.domain.sse import SSEHealthDomain +from app.domain.sse import ShutdownStatus, SSEHealthDomain from app.schemas_pydantic.execution import ExecutionResult from app.schemas_pydantic.sse import ( RedisNotificationMessage, @@ -16,9 +16,8 @@ SSEExecutionEventData, SSENotificationEventData, ) -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge from app.services.sse.redis_bus import SSERedisBus -from app.services.sse.sse_shutdown_manager import 
SSEShutdownManager +from app.services.sse.sse_connection_registry import SSEConnectionRegistry from app.settings import Settings @@ -34,17 +33,17 @@ class SSEService: def __init__( self, repository: SSERepository, - router: SSEKafkaRedisBridge, + num_consumers: int, sse_bus: SSERedisBus, - shutdown_manager: SSEShutdownManager, + connection_registry: SSEConnectionRegistry, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics, ) -> None: self.repository = repository - self.router = router + self._num_consumers = num_consumers self.sse_bus = sse_bus - self.shutdown_manager = shutdown_manager + self.connection_registry = connection_registry self.settings = settings self.logger = logger self.metrics = connection_metrics @@ -53,17 +52,7 @@ def __init__( async def create_execution_stream(self, execution_id: str, user_id: str) -> AsyncGenerator[Dict[str, Any], None]: connection_id = f"sse_{execution_id}_{datetime.now(timezone.utc).timestamp()}" - shutdown_event = await self.shutdown_manager.register_connection(execution_id, connection_id) - if shutdown_event is None: - yield self._format_sse_event( - SSEExecutionEventData( - event_type=SSEControlEvent.ERROR, - execution_id=execution_id, - timestamp=datetime.now(timezone.utc), - error="Server is shutting down", - ) - ) - return + await self.connection_registry.register_connection(execution_id, connection_id) subscription = None try: @@ -108,7 +97,6 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn async for event_data in self._stream_events_redis( execution_id, subscription, - shutdown_event, include_heartbeat=False, ): yield event_data @@ -116,30 +104,18 @@ async def create_execution_stream(self, execution_id: str, user_id: str) -> Asyn finally: if subscription is not None: await asyncio.shield(subscription.close()) - await asyncio.shield(self.shutdown_manager.unregister_connection(execution_id, connection_id)) + await asyncio.shield(self.connection_registry.unregister_connection(execution_id, connection_id)) self.logger.info("SSE connection closed", extra={"execution_id": execution_id}) async def _stream_events_redis( self, execution_id: str, subscription: Any, - shutdown_event: asyncio.Event, include_heartbeat: bool = True, ) -> AsyncGenerator[Dict[str, Any], None]: + """Stream events from Redis subscription until terminal event or cancellation.""" last_heartbeat = datetime.now(timezone.utc) while True: - if shutdown_event.is_set(): - yield self._format_sse_event( - SSEExecutionEventData( - event_type=SSEControlEvent.SHUTDOWN, - execution_id=execution_id, - timestamp=datetime.now(timezone.utc), - message="Server is shutting down", - grace_period=30, - ) - ) - break - now = datetime.now(timezone.utc) if include_heartbeat and (now - last_heartbeat).total_seconds() >= self.heartbeat_interval: yield self._format_sse_event( @@ -196,6 +172,7 @@ async def _build_sse_event_from_redis(self, execution_id: str, msg: RedisSSEMess ) async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[str, Any], None]: + """Stream notifications until cancelled.""" subscription = None try: @@ -224,7 +201,7 @@ async def create_notification_stream(self, user_id: str) -> AsyncGenerator[Dict[ ) last_heartbeat = datetime.now(timezone.utc) - while not self.shutdown_manager.is_shutting_down(): + while True: # Heartbeat now = datetime.now(timezone.utc) if (now - last_heartbeat).total_seconds() >= self.heartbeat_interval: @@ -259,15 +236,22 @@ async def create_notification_stream(self, user_id: 
str) -> AsyncGenerator[Dict[ await asyncio.shield(subscription.close()) async def get_health_status(self) -> SSEHealthDomain: - router_stats = self.router.get_stats() + """Get SSE service health status.""" + active_connections = self.connection_registry.get_connection_count() return SSEHealthDomain( - status="draining" if self.shutdown_manager.is_shutting_down() else "healthy", + status="healthy", kafka_enabled=True, - active_connections=router_stats["active_executions"], - active_executions=router_stats["active_executions"], - active_consumers=router_stats["num_consumers"], + active_connections=active_connections, + active_executions=self.connection_registry.get_execution_count(), + active_consumers=self._num_consumers, max_connections_per_user=5, - shutdown=self.shutdown_manager.get_shutdown_status(), + shutdown=ShutdownStatus( + phase="ready", + initiated=False, + complete=False, + active_connections=active_connections, + draining_connections=0, + ), timestamp=datetime.now(timezone.utc), ) diff --git a/backend/app/services/sse/sse_shutdown_manager.py b/backend/app/services/sse/sse_shutdown_manager.py deleted file mode 100644 index 63f22eb4..00000000 --- a/backend/app/services/sse/sse_shutdown_manager.py +++ /dev/null @@ -1,300 +0,0 @@ -import asyncio -import logging -import time -from enum import Enum -from typing import Any, Dict, Set - -from app.core.metrics import ConnectionMetrics -from app.domain.sse import ShutdownStatus - - -class ShutdownPhase(Enum): - """Phases of SSE shutdown process""" - - READY = "ready" - NOTIFYING = "notifying" # Notify connections of impending shutdown - DRAINING = "draining" # Wait for connections to close gracefully - CLOSING = "closing" # Force close remaining connections - COMPLETE = "complete" - - -class SSEShutdownManager: - """ - Manages graceful shutdown of SSE connections. - - Works alongside the SSEKafkaRedisBridge to: - - Track active SSE connections - - Notify clients about shutdown - - Coordinate graceful disconnection - - Ensure clean resource cleanup - - The router handles Kafka consumer shutdown while this - manager handles SSE client connection lifecycle. 
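# --- Editor sketch, not part of the patch: the register/unregister pairing that
# create_execution_stream above relies on, shown in isolation. The function name and the
# streaming placeholder are illustrative; register_connection() and
# unregister_connection() come from the new SSEConnectionRegistry module above.
async def stream_with_tracking(registry: "SSEConnectionRegistry", execution_id: str, connection_id: str) -> None:
    await registry.register_connection(execution_id, connection_id)
    try:
        ...  # stream Redis-delivered events to the client until a terminal event
    finally:
        # Always unregister so get_connection_count()/get_execution_count() stay accurate.
        await registry.unregister_connection(execution_id, connection_id)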
- """ - - def __init__( - self, - router: Any, - logger: logging.Logger, - connection_metrics: ConnectionMetrics, - drain_timeout: float = 30.0, - notification_timeout: float = 5.0, - force_close_timeout: float = 10.0, - ): - self._router = router - self.logger = logger - self.drain_timeout = drain_timeout - self.notification_timeout = notification_timeout - self.force_close_timeout = force_close_timeout - self.metrics = connection_metrics - - self._phase = ShutdownPhase.READY - self._shutdown_initiated = False - self._shutdown_complete = False - self._shutdown_start_time: float | None = None - - # Track active connections by execution - self._active_connections: Dict[str, Set[str]] = {} # execution_id -> connection_ids - self._connection_callbacks: Dict[str, asyncio.Event] = {} # connection_id -> shutdown event - self._draining_connections: Set[str] = set() - - # Synchronization - self._lock = asyncio.Lock() - self._shutdown_event = asyncio.Event() - self._drain_complete_event = asyncio.Event() - - # Phase transition events for external coordination - self.initiated_event = asyncio.Event() # Set when shutdown initiated - self.notifying_event = asyncio.Event() # Set when entering notifying phase - - self.logger.info( - "SSEShutdownManager initialized", - extra={"drain_timeout": drain_timeout, "notification_timeout": notification_timeout}, - ) - - async def register_connection(self, execution_id: str, connection_id: str) -> asyncio.Event | None: - """ - Register a new SSE connection. - - Returns: - Shutdown event for the connection to monitor, or None if rejected - """ - async with self._lock: - if self._shutdown_initiated: - self.logger.warning( - "Rejecting new SSE connection during shutdown", - extra={"execution_id": execution_id, "connection_id": connection_id}, - ) - return None - - if execution_id not in self._active_connections: - self._active_connections[execution_id] = set() - - self._active_connections[execution_id].add(connection_id) - - # Create shutdown event for this connection - shutdown_event = asyncio.Event() - self._connection_callbacks[connection_id] = shutdown_event - - self.logger.debug("Registered SSE connection", extra={"connection_id": connection_id}) - self.metrics.increment_sse_connections("executions") - - return shutdown_event - - async def unregister_connection(self, execution_id: str, connection_id: str) -> None: - """Unregister an SSE connection""" - async with self._lock: - if execution_id in self._active_connections: - self._active_connections[execution_id].discard(connection_id) - if not self._active_connections[execution_id]: - del self._active_connections[execution_id] - - self._connection_callbacks.pop(connection_id, None) - self._draining_connections.discard(connection_id) - - self.logger.debug("Unregistered SSE connection", extra={"connection_id": connection_id}) - self.metrics.decrement_sse_connections("executions") - - # Check if all connections are drained - if self._shutdown_initiated and not self._active_connections: - self._drain_complete_event.set() - - async def initiate_shutdown(self) -> None: - """Initiate graceful shutdown of all SSE connections""" - async with self._lock: - if self._shutdown_initiated: - self.logger.warning("SSE shutdown already initiated") - return - - self._shutdown_initiated = True - self._shutdown_start_time = time.time() - self._phase = ShutdownPhase.DRAINING - - total_connections = sum(len(conns) for conns in self._active_connections.values()) - self.logger.info(f"Initiating SSE shutdown with {total_connections} 
active connections") - - self.metrics.update_sse_draining_connections(total_connections) - - # Start shutdown process - self._shutdown_event.set() - - # Execute shutdown phases - try: - await self._execute_shutdown() - except Exception as e: - self.logger.error(f"Error during SSE shutdown: {e}") - raise - finally: - self._shutdown_complete = True - self._phase = ShutdownPhase.COMPLETE - - async def _execute_shutdown(self) -> None: - """Execute the shutdown process in phases""" - - # Phase 1: Stop accepting new connections (already done by setting _shutdown_initiated) - phase_start = time.time() - self.logger.info("Phase 1: Stopped accepting new SSE connections") - - # Phase 2: Notify connections about shutdown - await self._notify_connections() - self.metrics.update_sse_shutdown_duration(time.time() - phase_start, "notify") - - # Phase 3: Drain connections gracefully - phase_start = time.time() - await self._drain_connections() - self.metrics.update_sse_shutdown_duration(time.time() - phase_start, "drain") - - # Phase 4: Force close remaining connections - phase_start = time.time() - await self._force_close_connections() - self.metrics.update_sse_shutdown_duration(time.time() - phase_start, "force_close") - - # Total shutdown duration - if self._shutdown_start_time is not None: - total_duration = time.time() - self._shutdown_start_time - self.metrics.update_sse_shutdown_duration(total_duration, "total") - self.logger.info(f"SSE shutdown complete in {total_duration:.2f}s") - else: - self.logger.info("SSE shutdown complete") - - async def _notify_connections(self) -> None: - """Notify all active connections about shutdown""" - self._phase = ShutdownPhase.NOTIFYING - - async with self._lock: - active_count = sum(len(conns) for conns in self._active_connections.values()) - connection_events = list(self._connection_callbacks.values()) - self._draining_connections = set(self._connection_callbacks.keys()) - - self.logger.info(f"Phase 2: Notifying {active_count} connections about shutdown") - self.metrics.update_sse_draining_connections(active_count) - - # Trigger shutdown events for all connections - # The connections will see this and send shutdown message to clients - for event in connection_events: - event.set() - - # Give connections time to send shutdown messages - await asyncio.sleep(self.notification_timeout) - - self.logger.info("Shutdown notification phase complete") - - async def _drain_connections(self) -> None: - """Wait for connections to close gracefully""" - self._phase = ShutdownPhase.DRAINING - - async with self._lock: - remaining = sum(len(conns) for conns in self._active_connections.values()) - - self.logger.info(f"Phase 3: Draining {remaining} connections (timeout: {self.drain_timeout}s)") - self.metrics.update_sse_draining_connections(remaining) - - start_time = time.time() - check_interval = 0.5 - last_count = remaining - - while remaining > 0 and (time.time() - start_time) < self.drain_timeout: - # Wait for connections to close - try: - await asyncio.wait_for(self._drain_complete_event.wait(), timeout=check_interval) - break # All connections drained - except asyncio.TimeoutError: - pass - - # Update metrics - async with self._lock: - remaining = sum(len(conns) for conns in self._active_connections.values()) - - if remaining < last_count: - self.logger.info(f"Connections remaining: {remaining}") - self.metrics.update_sse_draining_connections(remaining) - last_count = remaining - - if remaining == 0: - self.logger.info("All connections drained gracefully") - else: - 
self.logger.warning(f"{remaining} connections still active after drain timeout") - - async def _force_close_connections(self) -> None: - """Force close any remaining connections""" - self._phase = ShutdownPhase.CLOSING - - async with self._lock: - remaining_count = sum(len(conns) for conns in self._active_connections.values()) - - if remaining_count == 0: - self.logger.info("Phase 4: No connections to force close") - return - - self.logger.warning(f"Phase 4: Force closing {remaining_count} connections") - self.metrics.update_sse_draining_connections(remaining_count) - - # Clear all tracking - connections will be forcibly terminated - self._active_connections.clear() - self._connection_callbacks.clear() - self._draining_connections.clear() - - # Router lifecycle is managed by DI container - - self.metrics.update_sse_draining_connections(0) - self.logger.info("Force close phase complete") - - def is_shutting_down(self) -> bool: - """Check if shutdown is in progress""" - return self._shutdown_initiated - - def get_shutdown_status(self) -> ShutdownStatus: - """Get current shutdown status""" - duration = None - if self._shutdown_start_time: - duration = time.time() - self._shutdown_start_time - - return ShutdownStatus( - phase=self._phase.value, - initiated=self._shutdown_initiated, - complete=self._shutdown_complete, - active_connections=sum(len(conns) for conns in self._active_connections.values()), - draining_connections=len(self._draining_connections), - duration=duration, - ) - - async def wait_for_shutdown(self, timeout: float | None = None) -> bool: - """ - Wait for shutdown to complete. - - Returns: - True if shutdown completed, False if timeout - """ - if not self._shutdown_initiated: - return True - - try: - await asyncio.wait_for(self._wait_for_complete(), timeout=timeout) - return True - except asyncio.TimeoutError: - return False - - async def _wait_for_complete(self) -> None: - """Wait for shutdown to complete""" - while not self._shutdown_complete: - await asyncio.sleep(0.1) diff --git a/backend/tests/e2e/test_k8s_worker_create_pod.py b/backend/tests/e2e/test_k8s_worker_create_pod.py index c43bb2e5..5d95a931 100644 --- a/backend/tests/e2e/test_k8s_worker_create_pod.py +++ b/backend/tests/e2e/test_k8s_worker_create_pod.py @@ -5,11 +5,8 @@ from app.core.metrics import EventMetrics from app.domain.events.typed import CreatePodCommandEvent, EventMetadata from app.events.core import UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency import IdempotencyManager from app.services.k8s_worker.config import K8sWorkerConfig -from app.services.k8s_worker.worker import KubernetesWorker +from app.services.k8s_worker.worker_logic import K8sWorkerLogic from app.settings import Settings from dishka import AsyncContainer from kubernetes.client.rest import ApiException @@ -25,27 +22,27 @@ async def test_worker_creates_configmap_and_pod( ) -> None: ns = test_settings.K8S_NAMESPACE - schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager) - store: EventStore = await scope.get(EventStore) producer: UnifiedProducer = await scope.get(UnifiedProducer) - idem: IdempotencyManager = await scope.get(IdempotencyManager) event_metrics: EventMetrics = await scope.get(EventMetrics) cfg = K8sWorkerConfig(namespace=ns, max_concurrent_pods=1) - worker = KubernetesWorker( + logic = K8sWorkerLogic( config=cfg, producer=producer, - schema_registry_manager=schema, settings=test_settings, - 
event_store=store, - idempotency_manager=idem, logger=_test_logger, event_metrics=event_metrics, ) - # Initialize k8s clients using worker's own method - worker._initialize_kubernetes_client() # noqa: SLF001 - if worker.v1 is None: + # Initialize k8s clients using logic's own method + try: + logic.initialize() + except RuntimeError as e: + if "default" in str(e): + pytest.skip("K8S_NAMESPACE is set to 'default', which is forbidden") + raise + + if logic.v1 is None: pytest.skip("Kubernetes cluster not available") exec_id = uuid.uuid4().hex[:8] @@ -68,27 +65,27 @@ async def test_worker_creates_configmap_and_pod( ) # Build and create ConfigMap + Pod - cm = worker.pod_builder.build_config_map( + cm = logic.pod_builder.build_config_map( command=cmd, script_content=cmd.script, - entrypoint_content=await worker._get_entrypoint_script(), # noqa: SLF001 + entrypoint_content=await logic._get_entrypoint_script(), # noqa: SLF001 ) try: - await worker._create_config_map(cm) # noqa: SLF001 + await logic._create_config_map(cm) # noqa: SLF001 except ApiException as e: if e.status in (403, 404): pytest.skip(f"Insufficient permissions or namespace not found: {e}") raise - pod = worker.pod_builder.build_pod_manifest(cmd) - await worker._create_pod(pod) # noqa: SLF001 + pod = logic.pod_builder.build_pod_manifest(cmd) + await logic._create_pod(pod) # noqa: SLF001 # Verify resources exist - got_cm = worker.v1.read_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) + got_cm = logic.v1.read_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) assert got_cm is not None - got_pod = worker.v1.read_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) + got_pod = logic.v1.read_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) assert got_pod is not None # Cleanup - worker.v1.delete_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) - worker.v1.delete_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) + logic.v1.delete_namespaced_pod(name=f"executor-{exec_id}", namespace=ns) + logic.v1.delete_namespaced_config_map(name=f"script-{exec_id}", namespace=ns) diff --git a/backend/tests/integration/events/test_consume_roundtrip.py b/backend/tests/integration/events/test_consume_roundtrip.py index 40b0d490..ff8d5554 100644 --- a/backend/tests/integration/events/test_consume_roundtrip.py +++ b/backend/tests/integration/events/test_consume_roundtrip.py @@ -50,8 +50,11 @@ async def _handle(_event: DomainEvent) -> None: settings=test_settings, logger=_test_logger, event_metrics=event_metrics, + topics=[KafkaTopic.EXECUTION_EVENTS], ) - await consumer.start([KafkaTopic.EXECUTION_EVENTS]) + + # Start consumer as background task + consumer_task = asyncio.create_task(consumer.run()) try: # Produce a request event @@ -62,4 +65,6 @@ async def _handle(_event: DomainEvent) -> None: # Wait for the handler to be called await asyncio.wait_for(received.wait(), timeout=10.0) finally: - await consumer.stop() + consumer_task.cancel() + with pytest.raises(asyncio.CancelledError): + await consumer_task diff --git a/backend/tests/integration/events/test_consumer_lifecycle.py b/backend/tests/integration/events/test_consumer_lifecycle.py index 8272e772..5ee140bf 100644 --- a/backend/tests/integration/events/test_consumer_lifecycle.py +++ b/backend/tests/integration/events/test_consumer_lifecycle.py @@ -1,3 +1,4 @@ +import asyncio import logging import pytest @@ -15,28 +16,41 @@ @pytest.mark.asyncio -async def test_consumer_start_status_seek_and_stop( +async def test_consumer_run_and_cancel( 
schema_registry: SchemaRegistryManager, event_metrics: EventMetrics, consumer_config: ConsumerConfig, test_settings: Settings, ) -> None: + """Test consumer run() blocks until cancelled and seek methods work.""" disp = EventDispatcher(logger=_test_logger) - c = UnifiedConsumer( + consumer = UnifiedConsumer( consumer_config, - event_dispatcher=disp, + dispatcher=disp, schema_registry=schema_registry, settings=test_settings, logger=_test_logger, event_metrics=event_metrics, + topics=[KafkaTopic.EXECUTION_EVENTS], ) - await c.start([KafkaTopic.EXECUTION_EVENTS]) + + # Track when consumer is running + consumer_started = asyncio.Event() + + async def run_with_signal() -> None: + consumer_started.set() + await consumer.run() + + task = asyncio.create_task(run_with_signal()) + try: - st = c.get_status() - assert st.state == "running" - # Exercise seek functions; don't force specific partition offsets - await c.seek_to_beginning() - await c.seek_to_end() - # No need to sleep; just ensure we can call seek APIs while running + # Wait for consumer to start + await asyncio.wait_for(consumer_started.wait(), timeout=5.0) + + # Exercise seek functions while consumer is running + await consumer.seek_to_beginning() + await consumer.seek_to_end() finally: - await c.stop() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task diff --git a/backend/tests/integration/events/test_event_dispatcher.py b/backend/tests/integration/events/test_event_dispatcher.py index 0940a88e..92d6544e 100644 --- a/backend/tests/integration/events/test_event_dispatcher.py +++ b/backend/tests/integration/events/test_event_dispatcher.py @@ -53,8 +53,11 @@ async def h2(_e: DomainEvent) -> None: settings=test_settings, logger=_test_logger, event_metrics=event_metrics, + topics=[KafkaTopic.EXECUTION_EVENTS], ) - await consumer.start([KafkaTopic.EXECUTION_EVENTS]) + + # Start consumer as background task + consumer_task = asyncio.create_task(consumer.run()) # Produce a request event via DI producer: UnifiedProducer = await scope.get(UnifiedProducer) @@ -64,4 +67,6 @@ async def h2(_e: DomainEvent) -> None: try: await asyncio.wait_for(asyncio.gather(h1_called.wait(), h2_called.wait()), timeout=10.0) finally: - await consumer.stop() + consumer_task.cancel() + with pytest.raises(asyncio.CancelledError): + await consumer_task diff --git a/backend/tests/integration/idempotency/test_consumer_idempotent.py b/backend/tests/integration/idempotency/test_consumer_idempotent.py index 658a553e..fe8dfc5a 100644 --- a/backend/tests/integration/idempotency/test_consumer_idempotent.py +++ b/backend/tests/integration/idempotency/test_consumer_idempotent.py @@ -61,11 +61,12 @@ async def handle(_ev: DomainEvent) -> None: # Real consumer with idempotent wrapper base = UnifiedConsumer( consumer_config, - event_dispatcher=disp, + dispatcher=disp, schema_registry=schema_registry, settings=test_settings, logger=_test_logger, event_metrics=event_metrics, + topics=[KafkaTopic.EXECUTION_EVENTS], ) wrapper = IdempotentConsumerWrapper( consumer=base, @@ -76,10 +77,14 @@ async def handle(_ev: DomainEvent) -> None: logger=_test_logger, ) - await wrapper.start([KafkaTopic.EXECUTION_EVENTS]) + # Start wrapper as background task + wrapper_task = asyncio.create_task(wrapper.run()) + try: # Await the future directly - true async, no polling await asyncio.wait_for(handled_future, timeout=10.0) assert seen["n"] >= 1 finally: - await wrapper.stop() + wrapper_task.cancel() + with pytest.raises(asyncio.CancelledError): + await wrapper_task diff --git 
a/backend/tests/integration/idempotency/test_decorator_idempotent.py b/backend/tests/integration/idempotency/test_decorator_idempotent.py deleted file mode 100644 index 65e5b8b8..00000000 --- a/backend/tests/integration/idempotency/test_decorator_idempotent.py +++ /dev/null @@ -1,52 +0,0 @@ -import logging - -import pytest -from app.domain.events.typed import DomainEvent -from app.services.idempotency.idempotency_manager import IdempotencyManager -from app.services.idempotency.middleware import idempotent_handler -from dishka import AsyncContainer - -from tests.helpers import make_execution_requested_event - -_test_logger = logging.getLogger("test.idempotency.decorator_idempotent") - - -pytestmark = [pytest.mark.integration] - - -@pytest.mark.asyncio -async def test_decorator_blocks_duplicate_event(scope: AsyncContainer) -> None: - idm: IdempotencyManager = await scope.get(IdempotencyManager) - - calls = {"n": 0} - - @idempotent_handler(idempotency_manager=idm, key_strategy="event_based", logger=_test_logger) - async def h(ev: DomainEvent) -> None: - calls["n"] += 1 - - ev = make_execution_requested_event(execution_id="exec-deco-1") - - await h(ev) - await h(ev) # duplicate - assert calls["n"] == 1 - - -@pytest.mark.asyncio -async def test_decorator_custom_key_blocks(scope: AsyncContainer) -> None: - idm: IdempotencyManager = await scope.get(IdempotencyManager) - - calls = {"n": 0} - - def fixed_key(_ev: DomainEvent) -> str: - return "fixed-key" - - @idempotent_handler(idempotency_manager=idm, key_strategy="custom", custom_key_func=fixed_key, logger=_test_logger) - async def h(ev: DomainEvent) -> None: - calls["n"] += 1 - - e1 = make_execution_requested_event(execution_id="exec-deco-2a") - e2 = make_execution_requested_event(execution_id="exec-deco-2b") - - await h(e1) - await h(e2) # different event ids but same custom key - assert calls["n"] == 1 diff --git a/backend/tests/integration/idempotency/test_idempotency.py b/backend/tests/integration/idempotency/test_idempotency.py index 69ff00cb..25f60111 100644 --- a/backend/tests/integration/idempotency/test_idempotency.py +++ b/backend/tests/integration/idempotency/test_idempotency.py @@ -11,7 +11,7 @@ from app.domain.events.typed import DomainEvent from app.domain.idempotency import IdempotencyRecord, IdempotencyStatus from app.services.idempotency.idempotency_manager import IdempotencyConfig, IdempotencyManager -from app.services.idempotency.middleware import IdempotentEventHandler, idempotent_handler +from app.services.idempotency.middleware import IdempotentEventHandler from app.services.idempotency.redis_repository import RedisIdempotencyRepository from tests.helpers import make_execution_requested_event @@ -338,38 +338,6 @@ async def on_duplicate(event: DomainEvent, result: Any) -> None: assert duplicate_events[0][0] == real_event assert duplicate_events[0][1].is_duplicate is True - @pytest.mark.asyncio - async def test_decorator_integration(self, idempotency_manager: IdempotencyManager) -> None: - """Test the @idempotent_handler decorator""" - processed_events: list[DomainEvent] = [] - - @idempotent_handler( - idempotency_manager=idempotency_manager, - key_strategy="content_hash", - ttl_seconds=300, - logger=_test_logger, - ) - async def my_handler(event: DomainEvent) -> None: - processed_events.append(event) - - # Process same event twice - real_event = make_execution_requested_event(execution_id="decor-1") - await my_handler(real_event) - await my_handler(real_event) - - # Should only process once - assert len(processed_events) == 
1 - - # Create event with same ID and same content for content hash match - similar_event = make_execution_requested_event( - execution_id=real_event.execution_id, - script=real_event.script, - ) - - # Should still be blocked (content hash) - await my_handler(similar_event) - assert len(processed_events) == 1 # Still only one - @pytest.mark.asyncio async def test_custom_key_function(self, idempotency_manager: IdempotencyManager) -> None: """Test handler with custom key function""" diff --git a/backend/tests/integration/result_processor/test_result_processor.py b/backend/tests/integration/result_processor/test_result_processor.py index 2e62554f..df31989a 100644 --- a/backend/tests/integration/result_processor/test_result_processor.py +++ b/backend/tests/integration/result_processor/test_result_processor.py @@ -1,5 +1,6 @@ import asyncio import logging +import uuid import pytest from app.core.database_context import Database @@ -7,7 +8,7 @@ from app.db.repositories.execution_repository import ExecutionRepository from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus -from app.domain.enums.kafka import KafkaTopic +from app.domain.enums.kafka import GroupId, KafkaTopic from app.domain.events.typed import EventMetadata, ExecutionCompletedEvent, ResultStoredEvent from app.domain.execution import DomainExecutionCreate from app.domain.execution.models import ResourceUsageDomain @@ -15,7 +16,8 @@ from app.events.core.dispatcher import EventDispatcher from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.idempotency import IdempotencyManager -from app.services.result_processor.processor import ResultProcessor +from app.services.idempotency.middleware import IdempotentConsumerWrapper +from app.services.result_processor import ProcessorLogic from app.settings import Settings from dishka import AsyncContainer @@ -49,7 +51,7 @@ async def test_result_processor_persists_and_emits( producer: UnifiedProducer = await scope.get(UnifiedProducer) idem: IdempotencyManager = await scope.get(IdempotencyManager) - # Create a base execution to satisfy ResultProcessor lookup + # Create a base execution to satisfy ProcessorLogic lookup created = await repo.create_execution(DomainExecutionCreate( script="print('x')", user_id="u1", @@ -59,34 +61,79 @@ async def test_result_processor_persists_and_emits( )) execution_id = created.execution_id - # Build and start the processor - processor = ResultProcessor( + # Build the ProcessorLogic and wire up the consumer + logic = ProcessorLogic( execution_repo=repo, producer=producer, - schema_registry=schema_registry, settings=test_settings, - idempotency_manager=idem, logger=_test_logger, execution_metrics=execution_metrics, + ) + + # Create dispatcher and register handlers + processor_dispatcher = EventDispatcher(logger=_test_logger) + logic.register_handlers(processor_dispatcher) + + # Create consumer config with unique group id + processor_consumer_config = ConsumerConfig( + bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"{GroupId.RESULT_PROCESSOR}.test.{uuid.uuid4().hex[:8]}", + max_poll_records=1, + enable_auto_commit=True, + auto_offset_reset="earliest", + session_timeout_ms=test_settings.KAFKA_SESSION_TIMEOUT_MS, + heartbeat_interval_ms=test_settings.KAFKA_HEARTBEAT_INTERVAL_MS, + max_poll_interval_ms=test_settings.KAFKA_MAX_POLL_INTERVAL_MS, + request_timeout_ms=test_settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + + # Create processor consumer + 
processor_consumer = UnifiedConsumer( + processor_consumer_config, + dispatcher=processor_dispatcher, + schema_registry=schema_registry, + settings=test_settings, + logger=_test_logger, event_metrics=event_metrics, + topics=[KafkaTopic.EXECUTION_COMPLETED, KafkaTopic.EXECUTION_FAILED, KafkaTopic.EXECUTION_TIMEOUT], + ) + + # Wrap with idempotency + processor_wrapper = IdempotentConsumerWrapper( + consumer=processor_consumer, + dispatcher=processor_dispatcher, + idempotency_manager=idem, + logger=_test_logger, + default_key_strategy="content_hash", + default_ttl_seconds=7200, + enable_for_all_handlers=True, ) - # Setup a small consumer to capture ResultStoredEvent - dispatcher = EventDispatcher(logger=_test_logger) + # Setup a separate consumer to capture ResultStoredEvent + stored_dispatcher = EventDispatcher(logger=_test_logger) stored_received = asyncio.Event() - @dispatcher.register(EventType.RESULT_STORED) + @stored_dispatcher.register(EventType.RESULT_STORED) async def _stored(event: ResultStoredEvent) -> None: if event.execution_id == execution_id: stored_received.set() + stored_consumer_config = ConsumerConfig( + bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"test.result_stored.{uuid.uuid4().hex[:8]}", + max_poll_records=1, + enable_auto_commit=True, + auto_offset_reset="earliest", + ) + stored_consumer = UnifiedConsumer( - consumer_config, - dispatcher, + stored_consumer_config, + stored_dispatcher, schema_registry=schema_registry, settings=test_settings, logger=_test_logger, event_metrics=event_metrics, + topics=[KafkaTopic.EXECUTION_RESULTS], ) # Produce the event BEFORE starting consumers (auto_offset_reset="earliest" will read it) @@ -106,19 +153,29 @@ async def _stored(event: ResultStoredEvent) -> None: ) await producer.produce(evt, key=execution_id) - # Start consumers after producing - await stored_consumer.start([KafkaTopic.EXECUTION_RESULTS]) + # Start consumers as background tasks + processor_task = asyncio.create_task(processor_wrapper.run()) + stored_task = asyncio.create_task(stored_consumer.run()) try: - async with processor: - # Await the ResultStoredEvent - signals that processing is complete - await asyncio.wait_for(stored_received.wait(), timeout=12.0) - - # Now verify DB persistence - should be done since event was emitted - doc = await db.get_collection("executions").find_one({"execution_id": execution_id}) - assert doc is not None, f"Execution {execution_id} not found in DB after ResultStoredEvent" - assert doc.get("status") == ExecutionStatus.COMPLETED, ( - f"Expected COMPLETED status, got {doc.get('status')}" - ) + # Await the ResultStoredEvent - signals that processing is complete + await asyncio.wait_for(stored_received.wait(), timeout=12.0) + + # Now verify DB persistence - should be done since event was emitted + doc = await db.get_collection("executions").find_one({"execution_id": execution_id}) + assert doc is not None, f"Execution {execution_id} not found in DB after ResultStoredEvent" + assert doc.get("status") == ExecutionStatus.COMPLETED, ( + f"Expected COMPLETED status, got {doc.get('status')}" + ) finally: - await stored_consumer.stop() + # Cancel and cleanup both consumers + processor_task.cancel() + stored_task.cancel() + try: + await processor_task + except asyncio.CancelledError: + pass + try: + await stored_task + except asyncio.CancelledError: + pass diff --git a/backend/tests/integration/services/coordinator/test_execution_coordinator.py b/backend/tests/integration/services/coordinator/test_execution_coordinator.py 
index c3d3ed61..c45e300f 100644 --- a/backend/tests/integration/services/coordinator/test_execution_coordinator.py +++ b/backend/tests/integration/services/coordinator/test_execution_coordinator.py @@ -1,5 +1,5 @@ import pytest -from app.services.coordinator.coordinator import ExecutionCoordinator +from app.services.coordinator.coordinator_logic import CoordinatorLogic from dishka import AsyncContainer from tests.helpers import make_execution_requested_event @@ -8,11 +8,11 @@ @pytest.mark.asyncio async def test_handle_requested_and_schedule(scope: AsyncContainer) -> None: - coord: ExecutionCoordinator = await scope.get(ExecutionCoordinator) + logic: CoordinatorLogic = await scope.get(CoordinatorLogic) ev = make_execution_requested_event(execution_id="e-real-1") # Handler now schedules immediately - no polling needed - await coord._handle_execution_requested(ev) # noqa: SLF001 + await logic._handle_execution_requested(ev) # noqa: SLF001 # Execution should be active immediately after handler returns - assert "e-real-1" in coord._active_executions # noqa: SLF001 + assert "e-real-1" in logic._active_executions # noqa: SLF001 diff --git a/backend/tests/integration/services/sse/test_partitioned_event_router.py b/backend/tests/integration/services/sse/test_partitioned_event_router.py index af620341..fd8b046f 100644 --- a/backend/tests/integration/services/sse/test_partitioned_event_router.py +++ b/backend/tests/integration/services/sse/test_partitioned_event_router.py @@ -3,69 +3,39 @@ from uuid import uuid4 import pytest -from app.core.metrics import EventMetrics from app.events.core import EventDispatcher -from app.events.schema.schema_registry import SchemaRegistryManager from app.schemas_pydantic.sse import RedisSSEMessage -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge +from app.services.sse.event_router import SSEEventRouter from app.services.sse.redis_bus import SSERedisBus -from app.settings import Settings from tests.helpers import make_execution_requested_event pytestmark = [pytest.mark.integration, pytest.mark.redis] -_test_logger = logging.getLogger("test.services.sse.partitioned_event_router_integration") +_test_logger = logging.getLogger("test.services.sse.event_router_integration") @pytest.mark.asyncio -async def test_router_bridges_to_redis( +async def test_event_router_bridges_to_redis( sse_redis_bus: SSERedisBus, - schema_registry: SchemaRegistryManager, - event_metrics: EventMetrics, - test_settings: Settings, ) -> None: - router = SSEKafkaRedisBridge( - schema_registry=schema_registry, - settings=test_settings, - event_metrics=event_metrics, - sse_bus=sse_redis_bus, - logger=_test_logger, - ) + """Test that SSEEventRouter routes events to Redis correctly.""" + router = SSEEventRouter(sse_bus=sse_redis_bus, logger=_test_logger) + + # Register handlers with dispatcher disp = EventDispatcher(logger=_test_logger) - router._register_routing_handlers(disp) + router.register_handlers(disp) # Open Redis subscription for our execution id execution_id = f"e-{uuid4().hex[:8]}" subscription = await sse_redis_bus.open_subscription(execution_id) + # Create and route an event ev = make_execution_requested_event(execution_id=execution_id) handler = disp.get_handlers(ev.event_type)[0] await handler(ev) - # Await the subscription directly - true async, no polling + # Await the subscription - verify event arrived in Redis msg = await asyncio.wait_for(subscription.get(RedisSSEMessage), timeout=2.0) assert msg is not None assert str(msg.event_type) == str(ev.event_type) - - 
-@pytest.mark.asyncio -async def test_router_start_and_stop( - sse_redis_bus: SSERedisBus, - schema_registry: SchemaRegistryManager, - event_metrics: EventMetrics, - test_settings: Settings, -) -> None: - test_settings.SSE_CONSUMER_POOL_SIZE = 1 - router = SSEKafkaRedisBridge( - schema_registry=schema_registry, - settings=test_settings, - event_metrics=event_metrics, - sse_bus=sse_redis_bus, - logger=_test_logger, - ) - - async with router: - assert router.get_stats()["num_consumers"] == 1 - - assert router.get_stats()["num_consumers"] == 0 diff --git a/backend/tests/unit/services/idempotency/test_middleware.py b/backend/tests/unit/services/idempotency/test_middleware.py index e3f69ece..4d0e6b2f 100644 --- a/backend/tests/unit/services/idempotency/test_middleware.py +++ b/backend/tests/unit/services/idempotency/test_middleware.py @@ -42,7 +42,6 @@ def idempotent_event_handler( idempotency_manager=mock_idempotency_manager, key_strategy="event_based", ttl_seconds=3600, - cache_result=True, logger=_test_logger ) diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index d775fd94..bfbf5e50 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -146,24 +146,36 @@ def make_pod_monitor( @pytest.mark.asyncio -async def test_start_and_stop_lifecycle( +async def test_run_and_cancel_lifecycle( event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, pod_monitor_config: PodMonitorConfig, ) -> None: + """Test that run() blocks until cancelled and cleans up on cancellation.""" pod_monitor_config.enable_state_reconciliation = False spy = SpyMapper() pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, event_mapper=spy) # type: ignore[arg-type] - # Replace _watch_loop to avoid real watch loop - async def _quick_watch() -> None: - return None + # Track when watch_loop is entered + watch_started = asyncio.Event() - pm._watch_loop = _quick_watch # type: ignore[method-assign] + async def _blocking_watch() -> None: + watch_started.set() + await asyncio.sleep(10) - async with pm: - assert pm._watch_task is not None + pm._watch_loop = _blocking_watch # type: ignore[method-assign] + # Start run() as a task + task = asyncio.create_task(pm.run()) + + # Wait until we're actually in the watch loop + await asyncio.wait_for(watch_started.wait(), timeout=1.0) + + # Cancel it - run() catches CancelledError and exits gracefully + task.cancel() + await task # Should complete without raising (graceful shutdown) + + # Verify cleanup happened assert spy.cleared is True @@ -470,8 +482,9 @@ async def mock_run_watch() -> None: pm._run_watch = mock_run_watch # type: ignore[method-assign] - # watch_loop catches CancelledError and exits gracefully (doesn't propagate) - await pm._watch_loop() + # watch_loop propagates CancelledError (correct behavior for structured concurrency) + with pytest.raises(asyncio.CancelledError): + await pm._watch_loop() assert len(watch_count) == 3 @@ -500,8 +513,9 @@ async def mock_backoff() -> None: pm._run_watch = mock_run_watch # type: ignore[method-assign] pm._backoff = mock_backoff # type: ignore[method-assign] - # watch_loop catches CancelledError and exits gracefully - await pm._watch_loop() + # watch_loop propagates CancelledError + with pytest.raises(asyncio.CancelledError): + await pm._watch_loop() # Resource version should be reset on 410 assert 
pm._last_resource_version is None @@ -532,18 +546,19 @@ async def mock_backoff() -> None: pm._run_watch = mock_run_watch # type: ignore[method-assign] pm._backoff = mock_backoff # type: ignore[method-assign] - # watch_loop catches CancelledError and exits gracefully - await pm._watch_loop() + # watch_loop propagates CancelledError + with pytest.raises(asyncio.CancelledError): + await pm._watch_loop() assert backoff_count == 1 @pytest.mark.asyncio -async def test_pod_monitor_context_manager_lifecycle( +async def test_pod_monitor_run_lifecycle( event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, pod_monitor_config: PodMonitorConfig, ) -> None: - """Test PodMonitor lifecycle via async context manager.""" + """Test PodMonitor lifecycle via run() method.""" pod_monitor_config.enable_state_reconciliation = False mock_v1 = FakeV1Api() @@ -568,29 +583,53 @@ async def test_pod_monitor_context_manager_lifecycle( kubernetes_metrics=kubernetes_metrics, ) - async with monitor: - assert monitor._watch_task is not None - assert monitor._clients is mock_k8s_clients - assert monitor._v1 is mock_v1 + # Verify DI wiring + assert monitor._clients is mock_k8s_clients + assert monitor._v1 is mock_v1 + + # Track when watch_loop is entered + watch_started = asyncio.Event() + + async def _blocking_watch() -> None: + watch_started.set() + await asyncio.sleep(10) + + monitor._watch_loop = _blocking_watch # type: ignore[method-assign] + + # Start and cancel - run() exits gracefully on cancel + task = asyncio.create_task(monitor.run()) + await asyncio.wait_for(watch_started.wait(), timeout=1.0) + task.cancel() + await task # Should complete without raising @pytest.mark.asyncio -async def test_stop_with_tasks( +async def test_cleanup_on_cancel( event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, pod_monitor_config: PodMonitorConfig, ) -> None: - """Test cleanup of tasks on context exit.""" + """Test cleanup of tracked pods on cancellation.""" pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - # Replace _watch_loop to avoid real watch and add tracked pods - async def _quick_watch() -> None: + watch_started = asyncio.Event() + + # Replace _watch_loop to add tracked pods and wait + async def _blocking_watch() -> None: pm._tracked_pods = {"pod1"} + watch_started.set() + await asyncio.sleep(10) + + pm._watch_loop = _blocking_watch # type: ignore[method-assign] - pm._watch_loop = _quick_watch # type: ignore[method-assign] + task = asyncio.create_task(pm.run()) + await asyncio.wait_for(watch_started.wait(), timeout=1.0) + assert "pod1" in pm._tracked_pods - async with pm: - assert pm._watch_task is not None + # Cancel - run() exits gracefully + task.cancel() + await task # Should complete without raising + # Cleanup should have cleared tracked pods assert len(pm._tracked_pods) == 0 @@ -667,8 +706,9 @@ async def mock_backoff() -> None: pm._run_watch = mock_run_watch # type: ignore[method-assign] pm._backoff = mock_backoff # type: ignore[method-assign] - # watch_loop catches CancelledError and exits gracefully - await pm._watch_loop() + # watch_loop propagates CancelledError + with pytest.raises(asyncio.CancelledError): + await pm._watch_loop() assert backoff_count == 1 @@ -735,8 +775,9 @@ async def mock_run_watch() -> None: pm._reconcile = mock_reconcile # type: ignore[method-assign] pm._run_watch = mock_run_watch # type: ignore[method-assign] - # watch_loop catches CancelledError and exits 
gracefully - await pm._watch_loop() + # watch_loop propagates CancelledError + with pytest.raises(asyncio.CancelledError): + await pm._watch_loop() # Reconcile should be called before each watch restart assert reconcile_count == 2 diff --git a/backend/tests/unit/services/result_processor/test_processor.py b/backend/tests/unit/services/result_processor/test_processor.py index c13fe0ab..e230e5a9 100644 --- a/backend/tests/unit/services/result_processor/test_processor.py +++ b/backend/tests/unit/services/result_processor/test_processor.py @@ -2,49 +2,27 @@ from unittest.mock import MagicMock import pytest -from app.core.metrics import EventMetrics, ExecutionMetrics +from app.core.metrics import ExecutionMetrics from app.domain.enums.events import EventType -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId, KafkaTopic -from app.services.result_processor.processor import ResultProcessor, ResultProcessorConfig +from app.events.core import EventDispatcher +from app.services.result_processor.processor_logic import ProcessorLogic pytestmark = pytest.mark.unit _test_logger = logging.getLogger("test.services.result_processor.processor") -class TestResultProcessorConfig: - def test_default_values(self) -> None: - config = ResultProcessorConfig() - assert config.consumer_group == GroupId.RESULT_PROCESSOR - # Topics should match centralized CONSUMER_GROUP_SUBSCRIPTIONS mapping - assert set(config.topics) == CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.RESULT_PROCESSOR] - assert KafkaTopic.EXECUTION_EVENTS in config.topics - assert config.result_topic == KafkaTopic.EXECUTION_RESULTS - assert config.batch_size == 10 - assert config.processing_timeout == 300 - - def test_custom_values(self) -> None: - config = ResultProcessorConfig(batch_size=20, processing_timeout=600) - assert config.batch_size == 20 - assert config.processing_timeout == 600 - - -def test_create_dispatcher_registers_handlers( - execution_metrics: ExecutionMetrics, event_metrics: EventMetrics -) -> None: - rp = ResultProcessor( +def test_register_handlers_registers_expected_event_types(execution_metrics: ExecutionMetrics) -> None: + logic = ProcessorLogic( execution_repo=MagicMock(), producer=MagicMock(), - schema_registry=MagicMock(), settings=MagicMock(), - idempotency_manager=MagicMock(), logger=_test_logger, execution_metrics=execution_metrics, - event_metrics=event_metrics, ) - dispatcher = rp._create_dispatcher() - assert dispatcher is not None + dispatcher = EventDispatcher(logger=_test_logger) + logic.register_handlers(dispatcher) + assert EventType.EXECUTION_COMPLETED in dispatcher._handlers assert EventType.EXECUTION_FAILED in dispatcher._handlers assert EventType.EXECUTION_TIMEOUT in dispatcher._handlers - diff --git a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py index 35e71820..848cc21d 100644 --- a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py +++ b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py @@ -1,5 +1,4 @@ import logging -from unittest.mock import MagicMock import pytest from app.core.metrics import EventMetrics @@ -10,13 +9,9 @@ from app.domain.events.typed import DomainEvent, ExecutionRequestedEvent from app.domain.saga.models import Saga, SagaConfig from app.events.core import UnifiedProducer -from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency.idempotency_manager import IdempotencyManager 
from app.services.saga.base_saga import BaseSaga -from app.services.saga.saga_orchestrator import SagaOrchestrator +from app.services.saga.saga_logic import SagaLogic from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep -from app.settings import Settings from tests.helpers import make_execution_requested_event @@ -52,23 +47,6 @@ async def produce( return None -class _FakeIdem(IdempotencyManager): - """Fake IdempotencyManager for testing.""" - - def __init__(self) -> None: - pass # Skip parent __init__ - - async def close(self) -> None: - return None - - -class _FakeStore(EventStore): - """Fake EventStore for testing.""" - - def __init__(self) -> None: - pass # Skip parent __init__ - - class _FakeAlloc(ResourceAllocationRepository): """Fake ResourceAllocationRepository for testing.""" @@ -100,15 +78,11 @@ def get_steps(self) -> list[SagaStep[ExecutionRequestedEvent]]: return [_StepOK()] -def _orch(event_metrics: EventMetrics) -> SagaOrchestrator: - return SagaOrchestrator( +def _logic(event_metrics: EventMetrics) -> SagaLogic: + return SagaLogic( config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), saga_repository=_FakeRepo(), producer=_FakeProd(), - schema_registry_manager=MagicMock(spec=SchemaRegistryManager), - settings=MagicMock(spec=Settings), - event_store=_FakeStore(), - idempotency_manager=_FakeIdem(), resource_allocation_repository=_FakeAlloc(), logger=_test_logger, event_metrics=event_metrics, @@ -117,33 +91,29 @@ def _orch(event_metrics: EventMetrics) -> SagaOrchestrator: @pytest.mark.asyncio async def test_min_success_flow(event_metrics: EventMetrics) -> None: - orch = _orch(event_metrics) - orch.register_saga(_Saga) + logic = _logic(event_metrics) + logic.register_saga(_Saga) # Handle the event - await orch._handle_event(make_execution_requested_event(execution_id="e")) + await logic.handle_event(make_execution_requested_event(execution_id="e")) # basic sanity; deep behavior covered by integration - assert len(orch._sagas) > 0 + assert len(logic._sagas) > 0 # noqa: SLF001 @pytest.mark.asyncio async def test_should_trigger_and_existing_short_circuit(event_metrics: EventMetrics) -> None: fake_repo = _FakeRepo() - orch = SagaOrchestrator( + logic = SagaLogic( config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), saga_repository=fake_repo, producer=_FakeProd(), - schema_registry_manager=MagicMock(spec=SchemaRegistryManager), - settings=MagicMock(spec=Settings), - event_store=_FakeStore(), - idempotency_manager=_FakeIdem(), resource_allocation_repository=_FakeAlloc(), logger=_test_logger, event_metrics=event_metrics, ) - orch.register_saga(_Saga) - assert orch._should_trigger_saga(_Saga, make_execution_requested_event(execution_id="e")) is True + logic.register_saga(_Saga) + assert logic._should_trigger_saga(_Saga, make_execution_requested_event(execution_id="e")) is True # noqa: SLF001 # Existing short-circuit returns existing ID s = Saga(saga_id="sX", saga_name="s", execution_id="e", state=SagaState.RUNNING) fake_repo.existing[("e", "s")] = s - sid = await orch._start_saga("s", make_execution_requested_event(execution_id="e")) + sid = await logic._start_saga("s", make_execution_requested_event(execution_id="e")) # noqa: SLF001 assert sid == "sX" diff --git a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py index 15e3ff9f..a1204957 100644 --- 
a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py +++ b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py @@ -1,19 +1,15 @@ import logging -from unittest.mock import MagicMock import pytest -from app.core.metrics import EventMetrics from app.domain.enums.events import EventType from app.domain.events.typed import DomainEvent, EventMetadata, ExecutionStartedEvent from app.events.core import EventDispatcher -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge +from app.services.sse.event_router import SSEEventRouter from app.services.sse.redis_bus import SSERedisBus -from app.settings import Settings pytestmark = pytest.mark.unit -_test_logger = logging.getLogger("test.services.sse.kafka_redis_bridge") +_test_logger = logging.getLogger("test.services.sse.event_router") class _FakeBus(SSERedisBus): @@ -31,23 +27,16 @@ def _make_metadata() -> EventMetadata: @pytest.mark.asyncio -async def test_register_and_route_events_without_kafka() -> None: - # Build the bridge but don't call start(); directly test routing handlers +async def test_event_router_registers_and_routes_events() -> None: + """Test that SSEEventRouter registers handlers and routes events to Redis.""" fake_bus = _FakeBus() - mock_settings = MagicMock(spec=Settings) - mock_settings.KAFKA_BOOTSTRAP_SERVERS = "kafka:9092" - mock_settings.SSE_CONSUMER_POOL_SIZE = 1 - - bridge = SSEKafkaRedisBridge( - schema_registry=MagicMock(spec=SchemaRegistryManager), - settings=mock_settings, - event_metrics=MagicMock(spec=EventMetrics), - sse_bus=fake_bus, - logger=_test_logger, - ) + router = SSEEventRouter(sse_bus=fake_bus, logger=_test_logger) + # Register handlers with dispatcher disp = EventDispatcher(_test_logger) - bridge._register_routing_handlers(disp) + router.register_handlers(disp) + + # Verify handler was registered handlers = disp.get_handlers(EventType.EXECUTION_STARTED) assert len(handlers) > 0 @@ -59,6 +48,3 @@ async def test_register_and_route_events_without_kafka() -> None: # Proper event is published await h(ExecutionStartedEvent(execution_id="exec-123", pod_name="p", metadata=_make_metadata())) assert fake_bus.published and fake_bus.published[-1][0] == "exec-123" - - s = bridge.get_stats() - assert s["num_consumers"] == 0 diff --git a/backend/tests/unit/services/sse/test_shutdown_manager.py b/backend/tests/unit/services/sse/test_shutdown_manager.py deleted file mode 100644 index 7f9ab3c2..00000000 --- a/backend/tests/unit/services/sse/test_shutdown_manager.py +++ /dev/null @@ -1,97 +0,0 @@ -import asyncio -import logging - -import pytest -from app.core.metrics import ConnectionMetrics -from app.services.sse.sse_shutdown_manager import SSEShutdownManager - -pytestmark = pytest.mark.unit - -_test_logger = logging.getLogger("test.services.sse.shutdown_manager") - - -class _FakeRouter: - """Fake router for testing.""" - - def __init__(self) -> None: - self.stopped = False - - async def __aenter__(self) -> "_FakeRouter": - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - self.stopped = True - - -@pytest.mark.asyncio -async def test_shutdown_graceful_notify_and_drain(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager( - router=_FakeRouter(), - logger=_test_logger, - connection_metrics=connection_metrics, - drain_timeout=1.0, - notification_timeout=0.01, - force_close_timeout=0.1, - ) - - # Register two connections and arrange that they unregister when notified 
- ev1 = await mgr.register_connection("e1", "c1") - ev2 = await mgr.register_connection("e1", "c2") - assert ev1 is not None and ev2 is not None - - async def on_shutdown(event: asyncio.Event, cid: str) -> None: - await asyncio.wait_for(event.wait(), timeout=0.5) - await mgr.unregister_connection("e1", cid) - - t1 = asyncio.create_task(on_shutdown(ev1, "c1")) - t2 = asyncio.create_task(on_shutdown(ev2, "c2")) - - await mgr.initiate_shutdown() - done = await mgr.wait_for_shutdown(timeout=1.0) - assert done is True - status = mgr.get_shutdown_status() - assert status.phase == "complete" - await asyncio.gather(t1, t2) - - -@pytest.mark.asyncio -async def test_shutdown_force_close_and_rejects_new(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager( - router=_FakeRouter(), - logger=_test_logger, - connection_metrics=connection_metrics, - drain_timeout=0.01, - notification_timeout=0.01, - force_close_timeout=0.01, - ) - - # Register a connection but never unregister -> force close path - ev = await mgr.register_connection("e1", "c1") - assert ev is not None - - # Initiate shutdown - await mgr.initiate_shutdown() - assert mgr.is_shutting_down() is True - status = mgr.get_shutdown_status() - assert status.draining_connections == 0 - - # New connections should be rejected - ev2 = await mgr.register_connection("e2", "c2") - assert ev2 is None - - -@pytest.mark.asyncio -async def test_get_shutdown_status_transitions(connection_metrics: ConnectionMetrics) -> None: - m = SSEShutdownManager( - router=_FakeRouter(), - logger=_test_logger, - connection_metrics=connection_metrics, - drain_timeout=0.01, - notification_timeout=0.0, - force_close_timeout=0.0, - ) - st0 = m.get_shutdown_status() - assert st0.phase == "ready" - await m.initiate_shutdown() - st1 = m.get_shutdown_status() - assert st1.phase in ("draining", "complete", "closing", "notifying") diff --git a/backend/tests/unit/services/sse/test_sse_connection_registry.py b/backend/tests/unit/services/sse/test_sse_connection_registry.py new file mode 100644 index 00000000..2b5f1de3 --- /dev/null +++ b/backend/tests/unit/services/sse/test_sse_connection_registry.py @@ -0,0 +1,76 @@ +import logging + +import pytest +from app.core.metrics import ConnectionMetrics +from app.services.sse.sse_connection_registry import SSEConnectionRegistry + +pytestmark = pytest.mark.unit + +_test_logger = logging.getLogger("test.services.sse.connection_registry") + + +@pytest.mark.asyncio +async def test_register_and_unregister(connection_metrics: ConnectionMetrics) -> None: + """Test basic connection registration and unregistration.""" + registry = SSEConnectionRegistry( + logger=_test_logger, + connection_metrics=connection_metrics, + ) + + # Initially empty + assert registry.get_connection_count() == 0 + assert registry.get_execution_count() == 0 + + # Register connections + await registry.register_connection("exec-1", "conn-1") + assert registry.get_connection_count() == 1 + assert registry.get_execution_count() == 1 + + await registry.register_connection("exec-1", "conn-2") + assert registry.get_connection_count() == 2 + assert registry.get_execution_count() == 1 # Same execution + + await registry.register_connection("exec-2", "conn-3") + assert registry.get_connection_count() == 3 + assert registry.get_execution_count() == 2 + + # Unregister + await registry.unregister_connection("exec-1", "conn-1") + assert registry.get_connection_count() == 2 + assert registry.get_execution_count() == 2 + + await registry.unregister_connection("exec-1", 
"conn-2") + assert registry.get_connection_count() == 1 + assert registry.get_execution_count() == 1 # exec-1 removed + + await registry.unregister_connection("exec-2", "conn-3") + assert registry.get_connection_count() == 0 + assert registry.get_execution_count() == 0 + + +@pytest.mark.asyncio +async def test_unregister_nonexistent(connection_metrics: ConnectionMetrics) -> None: + """Test unregistering a connection that doesn't exist.""" + registry = SSEConnectionRegistry( + logger=_test_logger, + connection_metrics=connection_metrics, + ) + + # Should not raise + await registry.unregister_connection("nonexistent", "conn-1") + assert registry.get_connection_count() == 0 + + +@pytest.mark.asyncio +async def test_duplicate_registration(connection_metrics: ConnectionMetrics) -> None: + """Test registering the same connection twice.""" + registry = SSEConnectionRegistry( + logger=_test_logger, + connection_metrics=connection_metrics, + ) + + await registry.register_connection("exec-1", "conn-1") + await registry.register_connection("exec-1", "conn-1") # Duplicate + + # Set behavior - duplicates ignored + assert registry.get_connection_count() == 1 diff --git a/backend/tests/unit/services/sse/test_sse_service.py b/backend/tests/unit/services/sse/test_sse_service.py index 48ff1751..4174ee57 100644 --- a/backend/tests/unit/services/sse/test_sse_service.py +++ b/backend/tests/unit/services/sse/test_sse_service.py @@ -11,11 +11,10 @@ from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus from app.domain.execution import DomainExecution, ResourceUsageDomain -from app.domain.sse import ShutdownStatus, SSEExecutionStatusDomain, SSEHealthDomain -from app.services.sse.kafka_redis_bridge import SSEKafkaRedisBridge +from app.domain.sse import SSEExecutionStatusDomain, SSEHealthDomain from app.services.sse.redis_bus import SSERedisBus, SSERedisSubscription +from app.services.sse.sse_connection_registry import SSEConnectionRegistry from app.services.sse.sse_service import SSEService -from app.services.sse.sse_shutdown_manager import SSEShutdownManager from app.settings import Settings from pydantic import BaseModel @@ -77,45 +76,27 @@ async def get_execution(self, execution_id: str) -> DomainExecution | None: # n return self.exec_for_result -class _FakeShutdown(SSEShutdownManager): - def __init__(self) -> None: +class _FakeRegistry(SSEConnectionRegistry): + """Fake registry that tracks registrations without real metrics.""" + + def __init__(self, active_connections: int = 0, active_executions: int = 0) -> None: # Skip parent __init__ - self._evt = asyncio.Event() - self._initiated = False + self._fake_connection_count = active_connections + self._fake_execution_count = active_executions self.registered: list[tuple[str, str]] = [] self.unregistered: list[tuple[str, str]] = [] - async def register_connection(self, execution_id: str, connection_id: str) -> asyncio.Event: + async def register_connection(self, execution_id: str, connection_id: str) -> None: self.registered.append((execution_id, connection_id)) - return self._evt async def unregister_connection(self, execution_id: str, connection_id: str) -> None: self.unregistered.append((execution_id, connection_id)) - def is_shutting_down(self) -> bool: - return self._initiated - - def get_shutdown_status(self) -> ShutdownStatus: - return ShutdownStatus( - phase="ready", - initiated=self._initiated, - complete=False, - active_connections=0, - draining_connections=0, - ) - - def initiate(self) -> None: - 
self._initiated = True - self._evt.set() + def get_connection_count(self) -> int: + return self._fake_connection_count - -class _FakeRouter(SSEKafkaRedisBridge): - def __init__(self) -> None: - # Skip parent __init__ - pass - - def get_stats(self) -> dict[str, int | bool]: - return {"num_consumers": 3, "active_executions": 2, "is_running": True, "total_buffers": 0} + def get_execution_count(self) -> int: + return self._fake_execution_count def _make_fake_settings() -> Settings: @@ -133,8 +114,8 @@ def _decode(evt: dict[str, Any]) -> dict[str, Any]: async def test_execution_stream_closes_on_failed_event(connection_metrics: ConnectionMetrics) -> None: repo = _FakeRepo() bus = _FakeBus() - sm = _FakeShutdown() - svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, + registry = _FakeRegistry() + svc = SSEService(repository=repo, num_consumers=3, sse_bus=bus, connection_registry=registry, settings=_make_fake_settings(), logger=_test_logger, connection_metrics=connection_metrics) agen = svc.create_execution_stream("exec-1", user_id="u1") @@ -177,8 +158,8 @@ async def test_execution_stream_result_stored_includes_result_payload(connection exit_code=0, ) bus = _FakeBus() - sm = _FakeShutdown() - svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, + registry = _FakeRegistry() + svc = SSEService(repository=repo, num_consumers=3, sse_bus=bus, connection_registry=registry, settings=_make_fake_settings(), logger=_test_logger, connection_metrics=connection_metrics) agen = svc.create_execution_stream("exec-2", user_id="u1") @@ -200,10 +181,10 @@ async def test_execution_stream_result_stored_includes_result_payload(connection async def test_notification_stream_connected_and_heartbeat_and_message(connection_metrics: ConnectionMetrics) -> None: repo = _FakeRepo() bus = _FakeBus() - sm = _FakeShutdown() + registry = _FakeRegistry() settings = _make_fake_settings() settings.SSE_HEARTBEAT_INTERVAL = 0 # emit immediately - svc = SSEService(repository=repo, router=_FakeRouter(), sse_bus=bus, shutdown_manager=sm, settings=settings, + svc = SSEService(repository=repo, num_consumers=3, sse_bus=bus, connection_registry=registry, settings=settings, logger=_test_logger, connection_metrics=connection_metrics) agen = svc.create_notification_stream("u1") @@ -232,19 +213,18 @@ async def test_notification_stream_connected_and_heartbeat_and_message(connectio notif = await agen.__anext__() assert _decode(notif)["event_type"] == "notification" - # Stop the stream by initiating shutdown and advancing once more (loop checks flag) - sm.initiate() - # It may loop until it sees the flag; push a None to release get(timeout) - await bus.notif_sub.push(None) - # Give the generator a chance to observe the flag and finish - with pytest.raises(StopAsyncIteration): - await asyncio.wait_for(agen.__anext__(), timeout=0.2) + # Stream runs until cancelled - cancel the generator + await agen.aclose() @pytest.mark.asyncio async def test_health_status_shape(connection_metrics: ConnectionMetrics) -> None: - svc = SSEService(repository=_FakeRepo(), router=_FakeRouter(), sse_bus=_FakeBus(), shutdown_manager=_FakeShutdown(), + # Create registry with 2 active connections and 2 executions for testing + registry = _FakeRegistry(active_connections=2, active_executions=2) + svc = SSEService(repository=_FakeRepo(), num_consumers=3, sse_bus=_FakeBus(), connection_registry=registry, settings=_make_fake_settings(), logger=_test_logger, connection_metrics=connection_metrics) h = await 
svc.get_health_status() assert isinstance(h, SSEHealthDomain) - assert h.active_consumers == 3 and h.active_executions == 2 + assert h.active_consumers == 3 + assert h.active_connections == 2 + assert h.active_executions == 2 diff --git a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py b/backend/tests/unit/services/sse/test_sse_shutdown_manager.py deleted file mode 100644 index 54ab54f3..00000000 --- a/backend/tests/unit/services/sse/test_sse_shutdown_manager.py +++ /dev/null @@ -1,87 +0,0 @@ -import asyncio -import logging - -import pytest -from app.core.metrics import ConnectionMetrics -from app.services.sse.sse_shutdown_manager import SSEShutdownManager - -pytestmark = pytest.mark.unit - -_test_logger = logging.getLogger("test.services.sse.sse_shutdown_manager") - - -class _FakeRouter: - """Fake router for testing.""" - - def __init__(self) -> None: - self.stopped = False - - async def __aenter__(self) -> "_FakeRouter": - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - self.stopped = True - - -@pytest.mark.asyncio -async def test_register_unregister_and_shutdown_flow(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager( - router=_FakeRouter(), - logger=_test_logger, - connection_metrics=connection_metrics, - drain_timeout=0.5, - notification_timeout=0.1, - force_close_timeout=0.1, - ) - - # Register two connections - e1 = await mgr.register_connection("exec-1", "c1") - e2 = await mgr.register_connection("exec-1", "c2") - assert e1 is not None and e2 is not None - - # Start shutdown - it will block waiting for connections to drain - shutdown_task = asyncio.create_task(mgr.initiate_shutdown()) - - # Give shutdown task a chance to start and enter drain phase - await asyncio.sleep(0) # Yield control once - - # Simulate clients acknowledging and disconnecting - e1.set() - await mgr.unregister_connection("exec-1", "c1") - e2.set() - await mgr.unregister_connection("exec-1", "c2") - - # Now shutdown can complete - await shutdown_task - assert mgr.get_shutdown_status().complete is True - - -@pytest.mark.asyncio -async def test_reject_new_connection_during_shutdown(connection_metrics: ConnectionMetrics) -> None: - mgr = SSEShutdownManager( - router=_FakeRouter(), - logger=_test_logger, - connection_metrics=connection_metrics, - drain_timeout=0.5, - notification_timeout=0.01, - force_close_timeout=0.01, - ) - # Pre-register one active connection - shutdown will block waiting for it - e = await mgr.register_connection("e", "c0") - assert e is not None - - # Start shutdown task - it sets _shutdown_initiated immediately then blocks on drain - shutdown_task = asyncio.create_task(mgr.initiate_shutdown()) - - # Yield control so shutdown task can start and set _shutdown_initiated - await asyncio.sleep(0) - - # Shutdown is now initiated (blocking on drain), new registrations should be rejected - assert mgr.is_shutting_down() is True - denied = await mgr.register_connection("e", "c1") - assert denied is None - - # Clean up - disconnect the blocking connection so shutdown can complete - e.set() - await mgr.unregister_connection("e", "c0") - await shutdown_task diff --git a/backend/workers/run_coordinator.py b/backend/workers/run_coordinator.py index 77346b3a..0f71f42b 100644 --- a/backend/workers/run_coordinator.py +++ b/backend/workers/run_coordinator.py @@ -9,7 +9,8 @@ from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId from app.events.schema.schema_registry import SchemaRegistryManager, 
initialize_event_schemas -from app.services.coordinator.coordinator import ExecutionCoordinator +from app.services.coordinator.coordinator_logic import CoordinatorLogic +from app.services.idempotency.middleware import IdempotentConsumerWrapper from app.settings import Settings from beanie import init_beanie @@ -27,8 +28,8 @@ async def run_coordinator(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - coordinator = await container.get(ExecutionCoordinator) + consumer = await container.get(IdempotentConsumerWrapper) + logic = await container.get(CoordinatorLogic) # Shutdown event - signal handlers just set this shutdown_event = asyncio.Event() @@ -36,16 +37,33 @@ async def run_coordinator(settings: Settings) -> None: for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, shutdown_event.set) - logger.info("ExecutionCoordinator started and running") + logger.info("ExecutionCoordinator initialized, starting run...") + + async def run_coordinator_tasks() -> None: + """Run consumer and scheduling loop using TaskGroup.""" + async with asyncio.TaskGroup() as tg: + tg.create_task(consumer.run()) + tg.create_task(logic.scheduling_loop()) try: - # Wait for shutdown signal - while not shutdown_event.is_set(): - await asyncio.sleep(60) - status = await coordinator.get_status() - logger.info(f"Coordinator status: {status}") + # Run coordinator until shutdown signal + run_task = asyncio.create_task(run_coordinator_tasks()) + shutdown_task = asyncio.create_task(shutdown_event.wait()) + + done, pending = await asyncio.wait( + [run_task, shutdown_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Cancel remaining tasks + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + finally: - # Container cleanup stops everything logger.info("Initiating graceful shutdown...") await container.close() diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 657785f8..ea16a46a 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -9,7 +9,8 @@ from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.k8s_worker.worker import KubernetesWorker +from app.services.idempotency.middleware import IdempotentConsumerWrapper +from app.services.k8s_worker.worker_logic import K8sWorkerLogic from app.settings import Settings from beanie import init_beanie @@ -27,8 +28,8 @@ async def run_kubernetes_worker(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - worker = await container.get(KubernetesWorker) + consumer = await container.get(IdempotentConsumerWrapper) + logic = await container.get(K8sWorkerLogic) # Shutdown event - signal handlers just set this shutdown_event = asyncio.Event() @@ -36,17 +37,36 @@ async def run_kubernetes_worker(settings: Settings) -> None: for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, shutdown_event.set) - logger.info("KubernetesWorker started and running") + logger.info("KubernetesWorker initialized, starting run...") + + async def run_worker_tasks() -> None: + """Run consumer and daemonset setup using TaskGroup.""" 
+ async with asyncio.TaskGroup() as tg: + tg.create_task(consumer.run()) + tg.create_task(logic.ensure_daemonset_task()) try: - # Wait for shutdown signal - while not shutdown_event.is_set(): - await asyncio.sleep(60) - status = await worker.get_status() - logger.info(f"Kubernetes worker status: {status}") + # Run worker until shutdown signal + run_task = asyncio.create_task(run_worker_tasks()) + shutdown_task = asyncio.create_task(shutdown_event.wait()) + + done, pending = await asyncio.wait( + [run_task, shutdown_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Cancel remaining tasks + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + finally: - # Container cleanup stops everything logger.info("Initiating graceful shutdown...") + # Wait for active pod creations to complete + await logic.wait_for_active_creations() await container.close() diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index 36e9f7f7..4549148f 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -13,8 +13,6 @@ from app.settings import Settings from beanie import init_beanie -RECONCILIATION_LOG_INTERVAL: int = 60 - async def run_pod_monitor(settings: Settings) -> None: """Run the pod monitor service.""" @@ -29,7 +27,6 @@ async def run_pod_monitor(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers monitor = await container.get(PodMonitor) # Shutdown event - signal handlers just set this @@ -38,16 +35,27 @@ async def run_pod_monitor(settings: Settings) -> None: for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, shutdown_event.set) - logger.info("PodMonitor started and running") + logger.info("PodMonitor initialized, starting run...") try: - # Wait for shutdown signal - while not shutdown_event.is_set(): - await asyncio.sleep(RECONCILIATION_LOG_INTERVAL) - status = await monitor.get_status() - logger.info(f"Pod monitor status: {status}") + # Run monitor until shutdown signal + run_task = asyncio.create_task(monitor.run()) + shutdown_task = asyncio.create_task(shutdown_event.wait()) + + done, pending = await asyncio.wait( + [run_task, shutdown_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Cancel remaining tasks + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + finally: - # Container cleanup stops everything logger.info("Initiating graceful shutdown...") await container.close() diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index 5431b011..6325fc35 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -1,25 +1,20 @@ import asyncio import logging import signal -from contextlib import AsyncExitStack from app.core.container import create_result_processor_container from app.core.logging import setup_logger -from app.core.metrics import EventMetrics, ExecutionMetrics from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS -from app.db.repositories.execution_repository import ExecutionRepository from app.domain.enums.kafka import GroupId -from app.events.core import UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency import IdempotencyManager -from app.services.result_processor.processor import 
ProcessingState, ResultProcessor +from app.services.idempotency.middleware import IdempotentConsumerWrapper from app.settings import Settings from beanie import init_beanie from pymongo.asynchronous.mongo_client import AsyncMongoClient async def run_result_processor(settings: Settings) -> None: + """Run the result processor service.""" db_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 @@ -27,26 +22,9 @@ async def run_result_processor(settings: Settings) -> None: await init_beanie(database=db_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) container = create_result_processor_container(settings) - producer = await container.get(UnifiedProducer) - schema_registry = await container.get(SchemaRegistryManager) - idempotency_manager = await container.get(IdempotencyManager) - execution_repo = await container.get(ExecutionRepository) - execution_metrics = await container.get(ExecutionMetrics) - event_metrics = await container.get(EventMetrics) logger = await container.get(logging.Logger) - logger.info(f"Beanie ODM initialized with {len(ALL_DOCUMENTS)} document models") - - # ResultProcessor is manually created (not from DI), so we own its lifecycle - processor = ResultProcessor( - execution_repo=execution_repo, - producer=producer, - schema_registry=schema_registry, - settings=settings, - idempotency_manager=idempotency_manager, - logger=logger, - execution_metrics=execution_metrics, - event_metrics=event_metrics, - ) + + consumer = await container.get(IdempotentConsumerWrapper) # Shutdown event - signal handlers just set this shutdown_event = asyncio.Event() @@ -54,21 +32,30 @@ async def run_result_processor(settings: Settings) -> None: for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, shutdown_event.set) - # We own the processor, so we use async with to manage its lifecycle - async with AsyncExitStack() as stack: - stack.callback(db_client.close) - stack.push_async_callback(container.close) - await stack.enter_async_context(processor) + logger.info("ResultProcessor consumer initialized, starting run...") - logger.info("ResultProcessor started and running") + try: + # Run consumer until shutdown signal + run_task = asyncio.create_task(consumer.run()) + shutdown_task = asyncio.create_task(shutdown_event.wait()) + + done, pending = await asyncio.wait( + [run_task, shutdown_task], + return_when=asyncio.FIRST_COMPLETED, + ) - # Wait for shutdown signal or service to stop - while processor._state == ProcessingState.PROCESSING and not shutdown_event.is_set(): - await asyncio.sleep(60) - status = await processor.get_status() - logger.info(f"ResultProcessor status: {status}") + # Cancel remaining tasks + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + finally: logger.info("Initiating graceful shutdown...") + await container.close() + await db_client.close() def main() -> None: diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 8027e2e4..87b4754d 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -9,7 +9,8 @@ from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.saga import SagaOrchestrator +from app.services.idempotency.middleware import IdempotentConsumerWrapper +from 
app.services.saga.saga_logic import SagaLogic from app.settings import Settings from beanie import init_beanie @@ -27,8 +28,14 @@ async def run_saga_orchestrator(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) - # Services are already started by the DI container providers - await container.get(SagaOrchestrator) + consumer = await container.get(IdempotentConsumerWrapper | None) + logic = await container.get(SagaLogic) + + # Handle case where no sagas have triggers + if consumer is None: + logger.warning("No consumer provided (no saga triggers), exiting") + await container.close() + return # Shutdown event - signal handlers just set this shutdown_event = asyncio.Event() @@ -36,14 +43,33 @@ async def run_saga_orchestrator(settings: Settings) -> None: for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, shutdown_event.set) - logger.info("Saga orchestrator started and running") + logger.info(f"SagaOrchestrator initialized for saga: {logic.config.name}, starting run...") + + async def run_orchestrator_tasks() -> None: + """Run consumer and timeout checker using TaskGroup.""" + async with asyncio.TaskGroup() as tg: + tg.create_task(consumer.run()) + tg.create_task(logic.check_timeouts_loop()) try: - # Wait for shutdown signal - while not shutdown_event.is_set(): - await asyncio.sleep(1) + # Run orchestrator until shutdown signal + run_task = asyncio.create_task(run_orchestrator_tasks()) + shutdown_task = asyncio.create_task(shutdown_event.wait()) + + done, pending = await asyncio.wait( + [run_task, shutdown_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Cancel remaining tasks + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + finally: - # Container cleanup stops everything logger.info("Initiating graceful shutdown...") await container.close() From 66456d0ce404677f3cb3baa66093e656259244ea Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Tue, 20 Jan 2026 20:56:51 +0100 Subject: [PATCH 08/21] Di usage instead of stateful services --- backend/app/core/container.py | 18 - backend/app/core/dishka_lifespan.py | 5 +- backend/app/core/providers.py | 253 +---------- backend/app/domain/enums/kafka.py | 11 - backend/app/domain/events/typed.py | 40 +- backend/app/domain/execution/__init__.py | 2 + backend/app/domain/execution/models.py | 6 +- backend/app/events/schema/schema_registry.py | 54 ++- backend/app/services/coordinator/__init__.py | 11 - .../services/coordinator/coordinator_logic.py | 404 ------------------ .../app/services/coordinator/queue_manager.py | 217 ---------- .../services/coordinator/resource_manager.py | 325 -------------- backend/app/services/idempotency/__init__.py | 5 - .../idempotency/faststream_middleware.py | 81 ++++ .../idempotency/idempotency_manager.py | 49 +-- .../app/services/idempotency/middleware.py | 144 ------- .../app/services/pod_monitor/event_mapper.py | 6 +- .../result_processor/processor_logic.py | 10 +- backend/app/services/saga/execution_saga.py | 54 ++- backend/app/services/saga/saga_logic.py | 31 +- backend/pyproject.toml | 7 +- backend/tests/integration/conftest.py | 8 +- .../idempotency/test_consumer_idempotent.py | 90 ---- .../idempotency/test_idempotency.py | 182 +------- .../idempotency/test_idempotent_handler.py | 62 --- .../result_processor/test_result_processor.py | 24 +- .../coordinator/test_execution_coordinator.py | 18 - .../coordinator/test_queue_manager.py | 37 -- 
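# --- Illustrative sketch (not part of the patch) ----------------------------
# The saga runner in this patch resolves `IdempotentConsumerWrapper | None`
# from the container and exits early when no saga declares trigger topics.
# A minimal sketch of that "optional dependency" pattern, assuming dishka's
# Provider/Scope/provide/make_async_container API as used elsewhere in this
# repo; SagaConsumer and the `triggers` list are hypothetical stand-ins.
import asyncio

from dishka import Provider, Scope, make_async_container, provide


class SagaConsumer:
    def __init__(self, topics: list[str]) -> None:
        self.topics = topics


class SagaProvider(Provider):
    scope = Scope.APP

    @provide
    def get_consumer(self) -> SagaConsumer | None:
        triggers: list[str] = []  # pretend no saga declared trigger topics
        return SagaConsumer(triggers) if triggers else None


async def main() -> None:
    container = make_async_container(SagaProvider())
    consumer = await container.get(SagaConsumer | None)
    if consumer is None:
        # Mirrors the runner's early return when there is nothing to consume.
        print("no saga triggers configured, exiting")
        await container.close()
        return
    print(f"consuming from: {consumer.topics}")
    await container.close()


asyncio.run(main())
# -----------------------------------------------------------------------------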
.../coordinator/test_resource_manager.py | 70 --- .../idempotency/test_idempotency_manager.py | 6 +- .../services/idempotency/test_middleware.py | 121 ------ .../unit/services/pod_monitor/test_monitor.py | 11 +- backend/uv.lock | 151 +++++-- backend/workers/run_coordinator.py | 95 ---- backend/workers/run_k8s_worker.py | 207 ++++++--- backend/workers/run_pod_monitor.py | 62 ++- backend/workers/run_result_processor.py | 208 ++++++--- backend/workers/run_saga_orchestrator.py | 219 +++++++--- 38 files changed, 884 insertions(+), 2420 deletions(-) delete mode 100644 backend/app/services/coordinator/__init__.py delete mode 100644 backend/app/services/coordinator/coordinator_logic.py delete mode 100644 backend/app/services/coordinator/queue_manager.py delete mode 100644 backend/app/services/coordinator/resource_manager.py create mode 100644 backend/app/services/idempotency/faststream_middleware.py delete mode 100644 backend/app/services/idempotency/middleware.py delete mode 100644 backend/tests/integration/idempotency/test_consumer_idempotent.py delete mode 100644 backend/tests/integration/idempotency/test_idempotent_handler.py delete mode 100644 backend/tests/integration/services/coordinator/test_execution_coordinator.py delete mode 100644 backend/tests/unit/services/coordinator/test_queue_manager.py delete mode 100644 backend/tests/unit/services/coordinator/test_resource_manager.py delete mode 100644 backend/tests/unit/services/idempotency/test_middleware.py delete mode 100644 backend/workers/run_coordinator.py diff --git a/backend/app/core/container.py b/backend/app/core/container.py index f45c2033..0d62da6c 100644 --- a/backend/app/core/container.py +++ b/backend/app/core/container.py @@ -5,7 +5,6 @@ AdminServicesProvider, AuthProvider, BusinessServicesProvider, - CoordinatorProvider, CoreServicesProvider, DatabaseProvider, EventProvider, @@ -79,23 +78,6 @@ def create_result_processor_container(settings: Settings) -> AsyncContainer: ) -def create_coordinator_container(settings: Settings) -> AsyncContainer: - """Create DI container for the ExecutionCoordinator worker.""" - return make_async_container( - SettingsProvider(), - LoggingProvider(), - DatabaseProvider(), - RedisProvider(), - CoreServicesProvider(), - MetricsProvider(), - RepositoryProvider(), - MessagingProvider(), - EventProvider(), - CoordinatorProvider(), - context={Settings: settings}, - ) - - def create_k8s_worker_container(settings: Settings) -> AsyncContainer: """Create DI container for the KubernetesWorker.""" return make_async_container( diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index 23dbcdf5..7b00c6a1 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -113,7 +113,10 @@ async def run_sse_consumers() -> None: asyncio.create_task(event_store_consumer.run(), name="event_store_consumer"), asyncio.create_task(notification_service.run(), name="notification_service"), ] - logger.info(f"Background services started ({len(sse_consumers)} SSE consumers)") + logger.info( + "Background services started", + extra={"sse_consumer_count": len(sse_consumers)}, + ) try: yield diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index d18ea80a..f4cc6d9c 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -11,7 +11,6 @@ from app.core.logging import setup_logger from app.core.metrics import ( ConnectionMetrics, - CoordinatorMetrics, DatabaseMetrics, DLQMetrics, EventMetrics, @@ -52,15 +51,12 @@ from 
app.infrastructure.kafka.topics import get_all_topics from app.services.admin import AdminEventsService, AdminSettingsService, AdminUserService from app.services.auth_service import AuthService -from app.services.coordinator.coordinator_logic import CoordinatorLogic from app.services.event_bus import EventBus, EventBusEvent from app.services.event_replay.replay_service import EventReplayService from app.services.event_service import EventService from app.services.execution_service import ExecutionService from app.services.grafana_alert_processor import GrafanaAlertProcessor from app.services.idempotency import IdempotencyConfig, IdempotencyManager -from app.services.idempotency.idempotency_manager import create_idempotency_manager -from app.services.idempotency.middleware import IdempotentConsumerWrapper from app.services.idempotency.redis_repository import RedisIdempotencyRepository from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.worker_logic import K8sWorkerLogic @@ -211,22 +207,28 @@ async def get_dlq_manager( async with manager: yield manager + @provide + def get_idempotency_config(self) -> IdempotencyConfig: + return IdempotencyConfig() + @provide def get_idempotency_repository(self, redis_client: redis.Redis) -> RedisIdempotencyRepository: return RedisIdempotencyRepository(redis_client, key_prefix="idempotency") @provide - async def get_idempotency_manager( - self, repo: RedisIdempotencyRepository, logger: logging.Logger, database_metrics: DatabaseMetrics - ) -> AsyncIterator[IdempotencyManager]: - manager = create_idempotency_manager( - repository=repo, config=IdempotencyConfig(), logger=logger, database_metrics=database_metrics + def get_idempotency_manager( + self, + repository: RedisIdempotencyRepository, + logger: logging.Logger, + metrics: DatabaseMetrics, + config: IdempotencyConfig, + ) -> IdempotencyManager: + return IdempotencyManager( + repository=repository, + logger=logger, + metrics=metrics, + config=config, ) - await manager.initialize() - try: - yield manager - finally: - await manager.close() class EventProvider(Provider): @@ -316,10 +318,6 @@ def get_health_metrics(self, settings: Settings) -> HealthMetrics: def get_kubernetes_metrics(self, settings: Settings) -> KubernetesMetrics: return KubernetesMetrics(settings) - @provide - def get_coordinator_metrics(self, settings: Settings) -> CoordinatorMetrics: - return CoordinatorMetrics(settings) - @provide def get_dlq_metrics(self, settings: Settings) -> DLQMetrics: return DLQMetrics(settings) @@ -703,75 +701,6 @@ def get_admin_user_service( ) -class CoordinatorProvider(Provider): - scope = Scope.APP - - @provide - def get_coordinator_logic( - self, - kafka_producer: UnifiedProducer, - execution_repository: ExecutionRepository, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - ) -> CoordinatorLogic: - return CoordinatorLogic( - producer=kafka_producer, - execution_repository=execution_repository, - logger=logger, - coordinator_metrics=coordinator_metrics, - ) - - @provide - def get_coordinator_consumer( - self, - logic: CoordinatorLogic, - schema_registry: SchemaRegistryManager, - settings: Settings, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - event_metrics: EventMetrics, - ) -> IdempotentConsumerWrapper: - """Create consumer with handlers wired to CoordinatorLogic.""" - # Create dispatcher and register handlers from logic - dispatcher = EventDispatcher(logger=logger) - logic.register_handlers(dispatcher) - - # Build consumer - 
consumer_config = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{GroupId.EXECUTION_COORDINATOR}.{settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - max_poll_records=100, - fetch_max_wait_ms=500, - fetch_min_bytes=1, - ) - - topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.EXECUTION_COORDINATOR]) - consumer = UnifiedConsumer( - consumer_config, - dispatcher=dispatcher, - schema_registry=schema_registry, - settings=settings, - logger=logger, - event_metrics=event_metrics, - topics=topics, - ) - - return IdempotentConsumerWrapper( - consumer=consumer, - dispatcher=dispatcher, - idempotency_manager=idempotency_manager, - logger=logger, - default_key_strategy="event_based", - default_ttl_seconds=7200, - enable_for_all_handlers=True, - ) - - class K8sWorkerProvider(Provider): scope = Scope.APP @@ -795,53 +724,6 @@ def get_k8s_worker_logic( logic.initialize() return logic - @provide - def get_k8s_worker_consumer( - self, - logic: K8sWorkerLogic, - schema_registry: SchemaRegistryManager, - settings: Settings, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - event_metrics: EventMetrics, - ) -> IdempotentConsumerWrapper: - """Create consumer with handlers wired to K8sWorkerLogic.""" - # Create dispatcher and register handlers from logic - dispatcher = EventDispatcher(logger=logger) - logic.register_handlers(dispatcher) - - # Build consumer - consumer_config = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{logic.config.consumer_group}.{settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.K8S_WORKER]) - consumer = UnifiedConsumer( - consumer_config, - dispatcher=dispatcher, - schema_registry=schema_registry, - settings=settings, - logger=logger, - event_metrics=event_metrics, - topics=topics, - ) - - return IdempotentConsumerWrapper( - consumer=consumer, - dispatcher=dispatcher, - idempotency_manager=idempotency_manager, - logger=logger, - default_key_strategy="content_hash", - default_ttl_seconds=3600, - enable_for_all_handlers=True, - ) - class PodMonitorProvider(Provider): scope = Scope.APP @@ -898,60 +780,6 @@ def get_saga_logic( logic.register_default_sagas() return logic - @provide - def get_saga_consumer( - self, - logic: SagaLogic, - schema_registry: SchemaRegistryManager, - settings: Settings, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - event_metrics: EventMetrics, - ) -> IdempotentConsumerWrapper | None: - """Create consumer with handlers wired to SagaLogic.""" - # Get topics from registered sagas - topics = logic.get_trigger_topics() - if not topics: - logger.warning("No trigger events found in registered sagas") - return None - - # Create dispatcher and register handlers from logic - dispatcher = EventDispatcher(logger=logger) - logic.register_handlers(dispatcher) - - # Build consumer - consumer_config = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - 
group_id=f"saga-{logic.config.name}.{settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - consumer = UnifiedConsumer( - config=consumer_config, - dispatcher=dispatcher, - schema_registry=schema_registry, - settings=settings, - logger=logger, - event_metrics=event_metrics, - topics=list(topics), - ) - - logger.info(f"Saga consumer configured for topics: {topics}") - - return IdempotentConsumerWrapper( - consumer=consumer, - dispatcher=dispatcher, - idempotency_manager=idempotency_manager, - logger=logger, - default_key_strategy="event_based", - default_ttl_seconds=7200, - enable_for_all_handlers=False, - ) - class EventReplayProvider(Provider): scope = Scope.APP @@ -993,52 +821,3 @@ def get_processor_logic( logger=logger, execution_metrics=execution_metrics, ) - - @provide - def get_processor_consumer( - self, - logic: ProcessorLogic, - schema_registry: SchemaRegistryManager, - settings: Settings, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - event_metrics: EventMetrics, - ) -> IdempotentConsumerWrapper: - """Create consumer with handlers wired to ProcessorLogic.""" - # Create dispatcher and register handlers from logic - dispatcher = EventDispatcher(logger=logger) - logic.register_handlers(dispatcher) - - # Build consumer - consumer_config = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{GroupId.RESULT_PROCESSOR}.{settings.KAFKA_GROUP_SUFFIX}", - max_poll_records=1, - enable_auto_commit=True, - auto_offset_reset="earliest", - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.RESULT_PROCESSOR]) - consumer = UnifiedConsumer( - consumer_config, - dispatcher=dispatcher, - schema_registry=schema_registry, - settings=settings, - logger=logger, - event_metrics=event_metrics, - topics=topics, - ) - - return IdempotentConsumerWrapper( - consumer=consumer, - dispatcher=dispatcher, - idempotency_manager=idempotency_manager, - logger=logger, - default_key_strategy="content_hash", - default_ttl_seconds=7200, - enable_for_all_handlers=True, - ) diff --git a/backend/app/domain/enums/kafka.py b/backend/app/domain/enums/kafka.py index 81d78e51..97b5a5a8 100644 --- a/backend/app/domain/enums/kafka.py +++ b/backend/app/domain/enums/kafka.py @@ -57,7 +57,6 @@ class KafkaTopic(StringEnum): class GroupId(StringEnum): """Kafka consumer group IDs.""" - EXECUTION_COORDINATOR = "execution-coordinator" K8S_WORKER = "k8s-worker" POD_MONITOR = "pod-monitor" RESULT_PROCESSOR = "result-processor" @@ -71,10 +70,6 @@ class GroupId(StringEnum): # Consumer group topic subscriptions CONSUMER_GROUP_SUBSCRIPTIONS: Dict[GroupId, Set[KafkaTopic]] = { - GroupId.EXECUTION_COORDINATOR: { - KafkaTopic.EXECUTION_EVENTS, - KafkaTopic.EXECUTION_RESULTS, - }, GroupId.K8S_WORKER: { KafkaTopic.SAGA_COMMANDS, # Receives CreatePodCommand/DeletePodCommand from coordinator }, @@ -108,12 +103,6 @@ class GroupId(StringEnum): # Consumer group event filters CONSUMER_GROUP_EVENTS: Dict[GroupId, Set[EventType]] = { - GroupId.EXECUTION_COORDINATOR: { - EventType.EXECUTION_REQUESTED, - 
EventType.EXECUTION_COMPLETED, - EventType.EXECUTION_FAILED, - EventType.EXECUTION_CANCELLED, - }, GroupId.K8S_WORKER: { EventType.EXECUTION_STARTED, }, diff --git a/backend/app/domain/events/typed.py b/backend/app/domain/events/typed.py index 5157be88..fbc85b39 100644 --- a/backend/app/domain/events/typed.py +++ b/backend/app/domain/events/typed.py @@ -2,18 +2,33 @@ from typing import Annotated, Literal from uuid import uuid4 +from dataclasses_avroschema.pydantic import AvroBaseModel from pydantic import ConfigDict, Discriminator, Field, TypeAdapter -from pydantic_avro.to_avro.base import AvroBase from app.domain.enums.auth import LoginMethod from app.domain.enums.common import Environment from app.domain.enums.events import EventType from app.domain.enums.notification import NotificationChannel, NotificationSeverity from app.domain.enums.storage import ExecutionErrorType, StorageType -from app.domain.execution import ResourceUsageDomain +# --- Avro-compatible nested models --- -class EventMetadata(AvroBase): + +class ResourceUsageAvro(AvroBaseModel): + """Resource usage data - Avro-compatible version for events.""" + + model_config = ConfigDict(from_attributes=True) + + execution_time_wall_seconds: float = 0.0 + cpu_time_jiffies: int = 0 + clk_tck_hertz: int = 0 + peak_memory_kb: int = 0 + + class Meta: + namespace = "com.integr8scode.events" + + +class EventMetadata(AvroBaseModel): """Event metadata - embedded in all events.""" model_config = ConfigDict(from_attributes=True, use_enum_values=True) @@ -26,8 +41,11 @@ class EventMetadata(AvroBase): user_agent: str | None = None environment: Environment = Environment.PRODUCTION + class Meta: + namespace = "com.integr8scode.events" + -class BaseEvent(AvroBase): +class BaseEvent(AvroBaseModel): """Base fields for all domain events.""" model_config = ConfigDict(from_attributes=True) @@ -39,6 +57,9 @@ class BaseEvent(AvroBase): aggregate_id: str | None = None metadata: EventMetadata + class Meta: + namespace = "com.integr8scode.events" + # --- Execution Events --- @@ -94,7 +115,7 @@ class ExecutionCompletedEvent(BaseEvent): event_type: Literal[EventType.EXECUTION_COMPLETED] = EventType.EXECUTION_COMPLETED execution_id: str exit_code: int - resource_usage: ResourceUsageDomain | None = None + resource_usage: ResourceUsageAvro | None = None stdout: str = "" stderr: str = "" @@ -105,7 +126,7 @@ class ExecutionFailedEvent(BaseEvent): exit_code: int error_type: ExecutionErrorType | None = None error_message: str = "" - resource_usage: ResourceUsageDomain | None = None + resource_usage: ResourceUsageAvro | None = None stdout: str = "" stderr: str = "" @@ -114,7 +135,7 @@ class ExecutionTimeoutEvent(BaseEvent): event_type: Literal[EventType.EXECUTION_TIMEOUT] = EventType.EXECUTION_TIMEOUT execution_id: str timeout_seconds: int - resource_usage: ResourceUsageDomain | None = None + resource_usage: ResourceUsageAvro | None = None stdout: str = "" stderr: str = "" @@ -555,7 +576,7 @@ class DLQMessageDiscardedEvent(BaseEvent): # --- Archived Event (for deleted events) --- -class ArchivedEvent(AvroBase): +class ArchivedEvent(AvroBaseModel): """Archived event with deletion metadata. 
Wraps the original event data.""" model_config = ConfigDict(from_attributes=True) @@ -573,6 +594,9 @@ class ArchivedEvent(AvroBase): deleted_by: str | None = None deletion_reason: str | None = None + class Meta: + namespace = "com.integr8scode.events" + # --- Discriminated Union: TYPE SYSTEM handles dispatch --- diff --git a/backend/app/domain/execution/__init__.py b/backend/app/domain/execution/__init__.py index d4275c9b..8b836982 100644 --- a/backend/app/domain/execution/__init__.py +++ b/backend/app/domain/execution/__init__.py @@ -11,6 +11,7 @@ LanguageInfoDomain, ResourceLimitsDomain, ResourceUsageDomain, + ResourceUsageDomainAdapter, ) __all__ = [ @@ -21,6 +22,7 @@ "LanguageInfoDomain", "ResourceLimitsDomain", "ResourceUsageDomain", + "ResourceUsageDomainAdapter", "RuntimeNotSupportedError", "EventPublishError", "ExecutionNotFoundError", diff --git a/backend/app/domain/execution/models.py b/backend/app/domain/execution/models.py index 2a46c8ea..9005f6cc 100644 --- a/backend/app/domain/execution/models.py +++ b/backend/app/domain/execution/models.py @@ -4,7 +4,7 @@ from typing import Any, Optional from uuid import uuid4 -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from app.domain.enums.execution import ExecutionStatus from app.domain.enums.storage import ExecutionErrorType @@ -19,6 +19,10 @@ class ResourceUsageDomain(BaseModel): peak_memory_kb: int = 0 +# TypeAdapter for Avro -> Domain conversion (handles None) +ResourceUsageDomainAdapter: TypeAdapter[ResourceUsageDomain | None] = TypeAdapter(ResourceUsageDomain | None) + + class DomainExecution(BaseModel): model_config = ConfigDict(from_attributes=True) diff --git a/backend/app/events/schema/schema_registry.py b/backend/app/events/schema/schema_registry.py index 6e4337a4..a53306b6 100644 --- a/backend/app/events/schema/schema_registry.py +++ b/backend/app/events/schema/schema_registry.py @@ -1,3 +1,9 @@ +""" +Schema Registry Manager using dataclasses-avroschema. + +Handles Avro serialization with Confluent wire format for Kafka messaging. 
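# --- Illustrative sketch (not part of the patch) ----------------------------
# The "Confluent wire format" this module handles is:
#   byte 0     : magic byte 0x00
#   bytes 1..4 : schema id, big-endian unsigned int
#   bytes 5..  : Avro-encoded payload
# A minimal header round-trip using only the stdlib; the payload bytes below
# are a placeholder, not real Avro.
import struct

MAGIC_BYTE = b"\x00"


def frame(schema_id: int, avro_payload: bytes) -> bytes:
    """Prefix an Avro payload with the Confluent wire-format header."""
    return MAGIC_BYTE + struct.pack(">I", schema_id) + avro_payload


def unframe(data: bytes) -> tuple[int, bytes]:
    """Split a wire-format message into (schema_id, avro_payload)."""
    if len(data) < 5 or data[0:1] != MAGIC_BYTE:
        raise ValueError("not a Confluent wire-format message")
    schema_id = struct.unpack(">I", data[1:5])[0]
    return schema_id, data[5:]


msg = frame(42, b"<avro-binary>")
assert unframe(msg) == (42, b"<avro-binary>")
# -----------------------------------------------------------------------------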
+""" + import logging import struct from functools import lru_cache @@ -37,7 +43,6 @@ class SchemaRegistryManager: def __init__(self, settings: Settings, logger: logging.Logger): self.logger = logger - self.namespace = "com.integr8scode.events" self.subject_prefix = settings.SCHEMA_SUBJECT_PREFIX parts = settings.SCHEMA_REGISTRY_AUTH.split(":", 1) auth: tuple[str, str] | None = (parts[0], parts[1]) if len(parts) == 2 else None @@ -46,19 +51,26 @@ def __init__(self, settings: Settings, logger: logging.Logger): self._schema_id_cache: dict[type[DomainEvent], int] = {} self._id_to_class_cache: dict[int, type[DomainEvent]] = {} - async def register_schema(self, subject: str, event_class: type[DomainEvent]) -> int: - """Register schema and return schema ID.""" - avro_schema = schema.AvroSchema(event_class.avro_schema(namespace=self.namespace)) - schema_id: int = await self._client.register(subject, avro_schema) + async def _ensure_schema_registered(self, event_class: type[DomainEvent]) -> int: + """Lazily register schema and return schema ID.""" + if event_class in self._schema_id_cache: + return self._schema_id_cache[event_class] + + subject = f"{self.subject_prefix}{event_class.__name__}-value" + avro_schema_dict = event_class.avro_schema_to_python() + avro_schema_obj = schema.AvroSchema(avro_schema_dict) + + schema_id: int = await self._client.register(subject, avro_schema_obj) self._schema_id_cache[event_class] = schema_id self._id_to_class_cache[schema_id] = event_class - self.logger.info(f"Registered schema for {event_class.__name__}: ID {schema_id}") + self.logger.debug(f"Registered schema {event_class.__name__}: ID {schema_id}") return schema_id async def _get_event_class_by_id(self, schema_id: int) -> type[DomainEvent] | None: """Get event class by schema ID.""" if schema_id in self._id_to_class_cache: return self._id_to_class_cache[schema_id] + schema_obj = await self._client.get_by_id(schema_id) if schema_obj and (class_name := schema_obj.raw_schema.get("name")): if cls := _get_event_class_mapping().get(class_name): @@ -69,13 +81,25 @@ async def _get_event_class_by_id(self, schema_id: int) -> type[DomainEvent] | No async def serialize_event(self, event: DomainEvent) -> bytes: """Serialize event to Confluent wire format: [0x00][4-byte schema id][Avro binary].""" - subject = f"{self.subject_prefix}{event.__class__.__name__}-value" - avro_schema = schema.AvroSchema(event.__class__.avro_schema(namespace=self.namespace)) + event_class = event.__class__ + subject = f"{self.subject_prefix}{event_class.__name__}-value" + + # Ensure schema is registered (lazy registration) + await self._ensure_schema_registered(event_class) + + # Get schema and serialize + avro_schema_dict = event_class.avro_schema_to_python() + avro_schema_obj = schema.AvroSchema(avro_schema_dict) + + # Prepare payload - use model_dump for dict representation payload: dict[str, Any] = event.model_dump(mode="python", by_alias=False, exclude_unset=False) payload.pop("event_type", None) + + # Convert datetime to microseconds for Avro logical type if "timestamp" in payload and payload["timestamp"] is not None: payload["timestamp"] = int(payload["timestamp"].timestamp() * 1_000_000) - return await self._serializer.encode_record_with_schema(subject, avro_schema, payload) + + return await self._serializer.encode_record_with_schema(subject, avro_schema_obj, payload) async def deserialize_event(self, data: bytes, topic: str) -> DomainEvent: """Deserialize from Confluent wire format to DomainEvent.""" @@ -83,15 +107,20 @@ async def 
deserialize_event(self, data: bytes, topic: str) -> DomainEvent: raise ValueError("Invalid message: too short for wire format") if data[0:1] != MAGIC_BYTE: raise ValueError(f"Unknown magic byte: {data[0]:#x}") + schema_id = struct.unpack(">I", data[1:5])[0] event_class = await self._get_event_class_by_id(schema_id) if not event_class: raise ValueError(f"Unknown schema ID: {schema_id}") + obj = await self._serializer.decode_message(data) if not isinstance(obj, dict): raise ValueError(f"Deserialization returned {type(obj)}, expected dict") + + # Restore event_type if missing (it's a discriminator field with default) if (f := event_class.model_fields.get("event_type")) and f.default and "event_type" not in obj: obj["event_type"] = f.default + return event_class.model_validate(obj) def deserialize_json(self, data: dict[str, Any]) -> DomainEvent: @@ -108,16 +137,17 @@ async def set_compatibility(self, subject: str, mode: str) -> None: if mode not in valid: raise ValueError(f"Invalid compatibility mode: {mode}") await self._client.update_compatibility(level=mode, subject=subject) - self.logger.info(f"Set {subject} compatibility to {mode}") + self.logger.debug(f"Set {subject} compatibility to {mode}") async def initialize_schemas(self) -> None: - """Initialize all event schemas in the registry.""" + """Initialize all event schemas in the registry with FORWARD compatibility.""" for event_class in _get_all_event_classes(): subject = f"{self.subject_prefix}{event_class.__name__}-value" await self.set_compatibility(subject, "FORWARD") - await self.register_schema(subject, event_class) + await self._ensure_schema_registered(event_class) self.logger.info(f"Initialized {len(_get_all_event_classes())} event schemas") async def initialize_event_schemas(registry: SchemaRegistryManager) -> None: + """Initialize all event schemas in the registry.""" await registry.initialize_schemas() diff --git a/backend/app/services/coordinator/__init__.py b/backend/app/services/coordinator/__init__.py deleted file mode 100644 index c3fd1ffb..00000000 --- a/backend/app/services/coordinator/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from app.services.coordinator.coordinator_logic import CoordinatorLogic -from app.services.coordinator.queue_manager import QueueManager, QueuePriority -from app.services.coordinator.resource_manager import ResourceAllocation, ResourceManager - -__all__ = [ - "CoordinatorLogic", - "QueueManager", - "QueuePriority", - "ResourceManager", - "ResourceAllocation", -] diff --git a/backend/app/services/coordinator/coordinator_logic.py b/backend/app/services/coordinator/coordinator_logic.py deleted file mode 100644 index 528983d6..00000000 --- a/backend/app/services/coordinator/coordinator_logic.py +++ /dev/null @@ -1,404 +0,0 @@ -import asyncio -import logging -import time -from typing import Any, TypeAlias -from uuid import uuid4 - -from app.core.metrics import CoordinatorMetrics -from app.db.repositories.execution_repository import ExecutionRepository -from app.domain.enums.events import EventType -from app.domain.enums.storage import ExecutionErrorType -from app.domain.events.typed import ( - CreatePodCommandEvent, - DomainEvent, - EventMetadata, - ExecutionAcceptedEvent, - ExecutionCancelledEvent, - ExecutionCompletedEvent, - ExecutionFailedEvent, - ExecutionRequestedEvent, -) -from app.events.core import EventDispatcher, UnifiedProducer -from app.services.coordinator.queue_manager import QueueManager, QueuePriority -from app.services.coordinator.resource_manager import ResourceAllocation, 
ResourceManager - -ExecutionMap: TypeAlias = dict[str, ResourceAllocation] - - -class CoordinatorLogic: - """ - Business logic for execution coordination. - - Handles: - - Execution request queuing and validation - - Resource allocation and management - - Scheduling loop for processing queued executions - - Event publishing (ExecutionAccepted, CreatePodCommand, ExecutionFailed) - - This class is stateful and must be instantiated once per coordinator instance. - """ - - def __init__( - self, - producer: UnifiedProducer, - execution_repository: ExecutionRepository, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - max_concurrent_scheduling: int = 10, - scheduling_interval_seconds: float = 0.5, - ): - self.logger = logger - self.metrics = coordinator_metrics - - # Components - self.queue_manager = QueueManager( - logger=self.logger, - coordinator_metrics=coordinator_metrics, - max_queue_size=10000, - max_executions_per_user=100, - stale_timeout_seconds=3600, - ) - - self.resource_manager = ResourceManager( - logger=self.logger, - coordinator_metrics=coordinator_metrics, - total_cpu_cores=32.0, - total_memory_mb=65536, - total_gpu_count=0, - ) - - # Kafka producer (injected, lifecycle managed by DI) - self.producer = producer - self.execution_repository = execution_repository - - # Scheduling - self.max_concurrent_scheduling = max_concurrent_scheduling - self.scheduling_interval = scheduling_interval_seconds - self._scheduling_semaphore = asyncio.Semaphore(max_concurrent_scheduling) - - # State tracking - self._active_executions: set[str] = set() - self._execution_resources: ExecutionMap = {} - - def register_handlers(self, dispatcher: EventDispatcher) -> None: - """Register event handlers with the dispatcher.""" - - @dispatcher.register(EventType.EXECUTION_REQUESTED) - async def handle_requested(event: ExecutionRequestedEvent) -> None: - await self._route_execution_event(event) - - @dispatcher.register(EventType.EXECUTION_COMPLETED) - async def handle_completed(event: ExecutionCompletedEvent) -> None: - await self._route_execution_result(event) - - @dispatcher.register(EventType.EXECUTION_FAILED) - async def handle_failed(event: ExecutionFailedEvent) -> None: - await self._route_execution_result(event) - - @dispatcher.register(EventType.EXECUTION_CANCELLED) - async def handle_cancelled(event: ExecutionCancelledEvent) -> None: - await self._route_execution_event(event) - - async def _route_execution_event(self, event: DomainEvent) -> None: - """Route execution events to appropriate handlers based on event type.""" - self.logger.info( - f"COORDINATOR: Routing execution event - type: {event.event_type}, " - f"id: {event.event_id}, " - f"actual class: {type(event).__name__}" - ) - - if event.event_type == EventType.EXECUTION_REQUESTED: - await self._handle_execution_requested(event) - elif event.event_type == EventType.EXECUTION_CANCELLED: - await self._handle_execution_cancelled(event) - else: - self.logger.debug(f"Ignoring execution event type: {event.event_type}") - - async def _route_execution_result(self, event: DomainEvent) -> None: - """Route execution result events to appropriate handlers based on event type.""" - if event.event_type == EventType.EXECUTION_COMPLETED: - await self._handle_execution_completed(event) - elif event.event_type == EventType.EXECUTION_FAILED: - await self._handle_execution_failed(event) - else: - self.logger.debug(f"Ignoring execution result event type: {event.event_type}") - - async def _handle_execution_requested(self, event: 
ExecutionRequestedEvent) -> None: - """Handle execution requested event - add to queue for processing.""" - self.logger.info(f"HANDLER CALLED: _handle_execution_requested for event {event.event_id}") - start_time = time.time() - - try: - # Add to queue with priority - success, position, error = await self.queue_manager.add_execution( - event, - priority=QueuePriority(event.priority), - ) - - if not success: - # Publish queue full event - await self._publish_queue_full(event, error or "Queue is full") - self.metrics.record_coordinator_execution_scheduled("queue_full") - return - - # Publish ExecutionAcceptedEvent - if position is None: - position = 0 - await self._publish_execution_accepted(event, position, event.priority) - - # Track metrics - duration = time.time() - start_time - self.metrics.record_coordinator_scheduling_duration(duration) - self.metrics.record_coordinator_execution_scheduled("queued") - - self.logger.info(f"Execution {event.execution_id} added to queue at position {position}") - - # Schedule immediately if at front of queue (position 0) - if position == 0: - await self._schedule_execution(event) - - except Exception as e: - self.logger.error(f"Failed to handle execution request {event.execution_id}: {e}", exc_info=True) - self.metrics.record_coordinator_execution_scheduled("error") - - async def _handle_execution_cancelled(self, event: ExecutionCancelledEvent) -> None: - """Handle execution cancelled event.""" - execution_id = event.execution_id - - removed = await self.queue_manager.remove_execution(execution_id) - - if execution_id in self._execution_resources: - await self.resource_manager.release_allocation(execution_id) - del self._execution_resources[execution_id] - - self._active_executions.discard(execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - - if removed: - self.logger.info(f"Execution {execution_id} cancelled and removed from queue") - - async def _handle_execution_completed(self, event: ExecutionCompletedEvent) -> None: - """Handle execution completed event.""" - execution_id = event.execution_id - - if execution_id in self._execution_resources: - await self.resource_manager.release_allocation(execution_id) - del self._execution_resources[execution_id] - - # Remove from active set - self._active_executions.discard(execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - - self.logger.info(f"Execution {execution_id} completed, resources released") - - async def _handle_execution_failed(self, event: ExecutionFailedEvent) -> None: - """Handle execution failed event.""" - execution_id = event.execution_id - - # Release resources - if execution_id in self._execution_resources: - await self.resource_manager.release_allocation(execution_id) - del self._execution_resources[execution_id] - - # Remove from active set - self._active_executions.discard(execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - - async def scheduling_loop(self) -> None: - """Main scheduling loop - processes queued executions.""" - self.logger.info("Scheduling loop started") - try: - while True: - try: - # Get next execution from queue - execution = await self.queue_manager.get_next_execution() - - if execution: - # Schedule execution - asyncio.create_task(self._schedule_execution(execution)) - else: - # No executions in queue, wait - await asyncio.sleep(self.scheduling_interval) - - except asyncio.CancelledError: - raise - except Exception as e: - 
self.logger.error(f"Error in scheduling loop: {e}", exc_info=True) - await asyncio.sleep(5) # Wait before retrying - except asyncio.CancelledError: - self.logger.info("Scheduling loop cancelled") - - async def _schedule_execution(self, event: ExecutionRequestedEvent) -> None: - """Schedule a single execution.""" - async with self._scheduling_semaphore: - start_time = time.time() - execution_id = event.execution_id - - # Atomic check-and-claim: no await between check and add prevents TOCTOU race - # when both eager scheduling (position=0) and scheduling_loop try to schedule - if execution_id in self._active_executions: - self.logger.debug(f"Execution {execution_id} already claimed, skipping") - return - self._active_executions.add(execution_id) - - try: - # Request resource allocation - allocation = await self.resource_manager.request_allocation( - execution_id, - event.language, - requested_cpu=None, # Use defaults for now - requested_memory_mb=None, - requested_gpu=0, - ) - - if not allocation: - # No resources available, release claim and requeue - self._active_executions.discard(execution_id) - await self.queue_manager.requeue_execution(event, increment_retry=False) - self.logger.info(f"No resources available for {execution_id}, requeued") - return - - # Track allocation (already in _active_executions from claim above) - self._execution_resources[execution_id] = allocation - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - - # Publish execution started event for workers - self.logger.info(f"About to publish ExecutionStartedEvent for {event.execution_id}") - try: - await self._publish_execution_started(event) - self.logger.info(f"Successfully published ExecutionStartedEvent for {event.execution_id}") - except Exception as publish_error: - self.logger.error( - f"Failed to publish ExecutionStartedEvent for {event.execution_id}: {publish_error}", - exc_info=True, - ) - raise - - # Track metrics - queue_time = start_time - event.timestamp.timestamp() - priority = getattr(event, "priority", QueuePriority.NORMAL) - self.metrics.record_coordinator_queue_time(queue_time, priority.name) - - scheduling_duration = time.time() - start_time - self.metrics.record_coordinator_scheduling_duration(scheduling_duration) - self.metrics.record_coordinator_execution_scheduled("scheduled") - - self.logger.info( - f"Scheduled execution {event.execution_id}. 
" - f"Queue time: {queue_time:.2f}s, " - f"Resources: {allocation.cpu_cores} CPU, " - f"{allocation.memory_mb}MB RAM" - ) - - except Exception as e: - self.logger.error(f"Failed to schedule execution {event.execution_id}: {e}", exc_info=True) - - # Release any allocated resources - if event.execution_id in self._execution_resources: - await self.resource_manager.release_allocation(event.execution_id) - del self._execution_resources[event.execution_id] - - self._active_executions.discard(event.execution_id) - self.metrics.update_coordinator_active_executions(len(self._active_executions)) - self.metrics.record_coordinator_execution_scheduled("error") - - # Publish failure event - await self._publish_scheduling_failed(event, str(e)) - - async def _build_command_metadata(self, request: ExecutionRequestedEvent) -> EventMetadata: - """Build metadata for CreatePodCommandEvent with guaranteed user_id.""" - # Prefer execution record user_id to avoid missing attribution - exec_rec = await self.execution_repository.get_execution(request.execution_id) - user_id: str = exec_rec.user_id if exec_rec and exec_rec.user_id else "system" - - return EventMetadata( - service_name="execution-coordinator", - service_version="1.0.0", - user_id=user_id, - correlation_id=request.metadata.correlation_id, - ) - - async def _publish_execution_started(self, request: ExecutionRequestedEvent) -> None: - """Send CreatePodCommandEvent to k8s-worker via SAGA_COMMANDS topic.""" - metadata = await self._build_command_metadata(request) - - create_pod_cmd = CreatePodCommandEvent( - saga_id=str(uuid4()), - execution_id=request.execution_id, - script=request.script, - language=request.language, - language_version=request.language_version, - runtime_image=request.runtime_image, - runtime_command=request.runtime_command, - runtime_filename=request.runtime_filename, - timeout_seconds=request.timeout_seconds, - cpu_limit=request.cpu_limit, - memory_limit=request.memory_limit, - cpu_request=request.cpu_request, - memory_request=request.memory_request, - priority=request.priority, - metadata=metadata, - ) - - await self.producer.produce(event_to_produce=create_pod_cmd, key=request.execution_id) - - async def _publish_execution_accepted( - self, request: ExecutionRequestedEvent, position: int, priority: int - ) -> None: - """Publish execution accepted event to notify that request was valid and queued.""" - self.logger.info(f"Publishing ExecutionAcceptedEvent for execution {request.execution_id}") - - event = ExecutionAcceptedEvent( - execution_id=request.execution_id, - queue_position=position, - estimated_wait_seconds=None, # Could calculate based on queue analysis - priority=priority, - metadata=request.metadata, - ) - - await self.producer.produce(event_to_produce=event) - self.logger.info(f"ExecutionAcceptedEvent published for {request.execution_id}") - - async def _publish_queue_full(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish queue full event.""" - # Get queue stats for context - queue_stats = await self.queue_manager.get_queue_stats() - - event = ExecutionFailedEvent( - execution_id=request.execution_id, - error_type=ExecutionErrorType.RESOURCE_LIMIT, - exit_code=-1, - stderr=f"Queue full: {error}. 
Queue size: {queue_stats.get('total_size', 'unknown')}", - resource_usage=None, - metadata=request.metadata, - error_message=error, - ) - - await self.producer.produce(event_to_produce=event, key=request.execution_id) - - async def _publish_scheduling_failed(self, request: ExecutionRequestedEvent, error: str) -> None: - """Publish scheduling failed event.""" - # Get resource stats for context - resource_stats = await self.resource_manager.get_resource_stats() - - event = ExecutionFailedEvent( - execution_id=request.execution_id, - error_type=ExecutionErrorType.SYSTEM_ERROR, - exit_code=-1, - stderr=f"Failed to schedule execution: {error}. " - f"Available resources: CPU={resource_stats.available.cpu_cores}, " - f"Memory={resource_stats.available.memory_mb}MB", - resource_usage=None, - metadata=request.metadata, - error_message=error, - ) - - await self.producer.produce(event_to_produce=event, key=request.execution_id) - - async def get_status(self) -> dict[str, Any]: - """Get coordinator status.""" - return { - "active_executions": len(self._active_executions), - "queue_stats": await self.queue_manager.get_queue_stats(), - "resource_stats": await self.resource_manager.get_resource_stats(), - } diff --git a/backend/app/services/coordinator/queue_manager.py b/backend/app/services/coordinator/queue_manager.py deleted file mode 100644 index 76b15c3c..00000000 --- a/backend/app/services/coordinator/queue_manager.py +++ /dev/null @@ -1,217 +0,0 @@ -import asyncio -import heapq -import logging -import time -from collections import defaultdict -from dataclasses import dataclass, field -from enum import IntEnum -from typing import Any, Dict, List, Tuple - -from app.core.metrics import CoordinatorMetrics -from app.domain.events.typed import ExecutionRequestedEvent - - -class QueuePriority(IntEnum): - CRITICAL = 0 - HIGH = 1 - NORMAL = 5 - LOW = 8 - BACKGROUND = 10 - - -@dataclass(order=True) -class QueuedExecution: - priority: int - timestamp: float = field(compare=False) - event: ExecutionRequestedEvent = field(compare=False) - retry_count: int = field(default=0, compare=False) - - @property - def execution_id(self) -> str: - return self.event.execution_id - - @property - def user_id(self) -> str: - return self.event.metadata.user_id or "anonymous" - - @property - def age_seconds(self) -> float: - return time.time() - self.timestamp - - -class QueueManager: - def __init__( - self, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - max_queue_size: int = 10000, - max_executions_per_user: int = 100, - stale_timeout_seconds: int = 3600, - ) -> None: - self.logger = logger - self.metrics = coordinator_metrics - self.max_queue_size = max_queue_size - self.max_executions_per_user = max_executions_per_user - self.stale_timeout_seconds = stale_timeout_seconds - - self._queue: List[QueuedExecution] = [] - self._queue_lock = asyncio.Lock() - self._user_execution_count: Dict[str, int] = defaultdict(int) - self._execution_users: Dict[str, str] = {} - - async def add_execution( - self, event: ExecutionRequestedEvent, priority: QueuePriority | None = None - ) -> Tuple[bool, int | None, str | None]: - async with self._queue_lock: - if len(self._queue) >= self.max_queue_size: - return False, None, "Queue is full" - - user_id = event.metadata.user_id or "anonymous" - - if self._user_execution_count[user_id] >= self.max_executions_per_user: - return False, None, f"User execution limit exceeded ({self.max_executions_per_user})" - - if priority is None: - priority = QueuePriority(event.priority) - - 
queued = QueuedExecution(priority=priority.value, timestamp=time.time(), event=event) - - heapq.heappush(self._queue, queued) - self._track_execution(event.execution_id, user_id) - position = self._get_queue_position(event.execution_id) - - # Update single authoritative metric for execution request queue depth - self.metrics.update_execution_request_queue_size(len(self._queue)) - - self.logger.info( - f"Added execution {event.execution_id} to queue. " - f"Priority: {priority.name}, Position: {position}, " - f"Queue size: {len(self._queue)}" - ) - - return True, position, None - - async def get_next_execution(self) -> ExecutionRequestedEvent | None: - async with self._queue_lock: - while self._queue: - queued = heapq.heappop(self._queue) - - if self._is_stale(queued): - self._untrack_execution(queued.execution_id) - self._record_removal("stale") - continue - - self._untrack_execution(queued.execution_id) - self._record_wait_time(queued) - # Update metric after removal from the queue - self.metrics.update_execution_request_queue_size(len(self._queue)) - - self.logger.info( - f"Retrieved execution {queued.execution_id} from queue. " - f"Wait time: {queued.age_seconds:.2f}s, Queue size: {len(self._queue)}" - ) - - return queued.event - - return None - - async def remove_execution(self, execution_id: str) -> bool: - async with self._queue_lock: - initial_size = len(self._queue) - self._queue = [q for q in self._queue if q.execution_id != execution_id] - - if len(self._queue) < initial_size: - heapq.heapify(self._queue) - self._untrack_execution(execution_id) - # Update metric after explicit removal - self.metrics.update_execution_request_queue_size(len(self._queue)) - self.logger.info(f"Removed execution {execution_id} from queue") - return True - - return False - - async def get_queue_position(self, execution_id: str) -> int | None: - async with self._queue_lock: - return self._get_queue_position(execution_id) - - async def get_queue_stats(self) -> Dict[str, Any]: - async with self._queue_lock: - priority_counts: Dict[str, int] = defaultdict(int) - user_counts: Dict[str, int] = defaultdict(int) - - for queued in self._queue: - priority_name = QueuePriority(queued.priority).name - priority_counts[priority_name] += 1 - user_counts[queued.user_id] += 1 - - top_users = dict(sorted(user_counts.items(), key=lambda x: x[1], reverse=True)[:10]) - - return { - "total_size": len(self._queue), - "priority_distribution": dict(priority_counts), - "top_users": top_users, - "max_queue_size": self.max_queue_size, - "utilization_percent": (len(self._queue) / self.max_queue_size) * 100, - } - - async def requeue_execution( - self, event: ExecutionRequestedEvent, increment_retry: bool = True - ) -> Tuple[bool, int | None, str | None]: - def _next_lower(p: QueuePriority) -> QueuePriority: - order = [ - QueuePriority.CRITICAL, - QueuePriority.HIGH, - QueuePriority.NORMAL, - QueuePriority.LOW, - QueuePriority.BACKGROUND, - ] - try: - idx = order.index(p) - except ValueError: - # Fallback: treat unknown numeric as NORMAL - idx = order.index(QueuePriority.NORMAL) - return order[min(idx + 1, len(order) - 1)] - - if increment_retry: - original_priority = QueuePriority(event.priority) - new_priority = _next_lower(original_priority) - else: - new_priority = QueuePriority(event.priority) - - return await self.add_execution(event, priority=new_priority) - - def _get_queue_position(self, execution_id: str) -> int | None: - for position, queued in enumerate(self._queue, 1): - if queued.execution_id == execution_id: - return 
position - return None - - def _is_stale(self, queued: QueuedExecution) -> bool: - return queued.age_seconds > self.stale_timeout_seconds - - def _track_execution(self, execution_id: str, user_id: str) -> None: - self._user_execution_count[user_id] += 1 - self._execution_users[execution_id] = user_id - - def _untrack_execution(self, execution_id: str) -> None: - if execution_id in self._execution_users: - user_id = self._execution_users.pop(execution_id) - self._user_execution_count[user_id] -= 1 - if self._user_execution_count[user_id] <= 0: - del self._user_execution_count[user_id] - - def _record_removal(self, reason: str) -> None: - # No-op: we keep a single queue depth metric and avoid operation counters - return - - def _record_wait_time(self, queued: QueuedExecution) -> None: - self.metrics.record_queue_wait_time_by_priority( - queued.age_seconds, QueuePriority(queued.priority).name, "default" - ) - - def _update_add_metrics(self, priority: QueuePriority) -> None: - # Deprecated in favor of single execution queue depth metric - self.metrics.update_execution_request_queue_size(len(self._queue)) - - def _update_queue_size(self) -> None: - self.metrics.update_execution_request_queue_size(len(self._queue)) diff --git a/backend/app/services/coordinator/resource_manager.py b/backend/app/services/coordinator/resource_manager.py deleted file mode 100644 index de2cbba6..00000000 --- a/backend/app/services/coordinator/resource_manager.py +++ /dev/null @@ -1,325 +0,0 @@ -import asyncio -import logging -from dataclasses import dataclass -from typing import Dict, List - -from app.core.metrics import CoordinatorMetrics - - -@dataclass -class ResourceAllocation: - """Resource allocation for an execution""" - - cpu_cores: float - memory_mb: int - gpu_count: int = 0 - - @property - def cpu_millicores(self) -> int: - """Get CPU in millicores for Kubernetes""" - return int(self.cpu_cores * 1000) - - @property - def memory_bytes(self) -> int: - """Get memory in bytes""" - return self.memory_mb * 1024 * 1024 - - -@dataclass -class ResourcePool: - """Available resource pool""" - - total_cpu_cores: float - total_memory_mb: int - total_gpu_count: int - - available_cpu_cores: float - available_memory_mb: int - available_gpu_count: int - - # Resource limits per execution - max_cpu_per_execution: float = 4.0 - max_memory_per_execution_mb: int = 8192 - max_gpu_per_execution: int = 1 - - # Minimum resources to keep available - min_available_cpu_cores: float = 2.0 - min_available_memory_mb: int = 4096 - - -@dataclass -class ResourceGroup: - """Resource group with usage information""" - - cpu_cores: float - memory_mb: int - gpu_count: int - - -@dataclass -class ResourceStats: - """Resource statistics""" - - total: ResourceGroup - available: ResourceGroup - allocated: ResourceGroup - utilization: Dict[str, float] - allocation_count: int - limits: Dict[str, int | float] - - -@dataclass -class ResourceAllocationInfo: - """Information about a resource allocation""" - - execution_id: str - cpu_cores: float - memory_mb: int - gpu_count: int - cpu_percentage: float - memory_percentage: float - - -class ResourceManager: - """Manages resource allocation for executions""" - - def __init__( - self, - logger: logging.Logger, - coordinator_metrics: CoordinatorMetrics, - total_cpu_cores: float = 32.0, - total_memory_mb: int = 65536, # 64GB - total_gpu_count: int = 0, - overcommit_factor: float = 1.2, # Allow 20% overcommit - ): - self.logger = logger - self.metrics = coordinator_metrics - self.pool = ResourcePool( - 
total_cpu_cores=total_cpu_cores * overcommit_factor, - total_memory_mb=int(total_memory_mb * overcommit_factor), - total_gpu_count=total_gpu_count, - available_cpu_cores=total_cpu_cores * overcommit_factor, - available_memory_mb=int(total_memory_mb * overcommit_factor), - available_gpu_count=total_gpu_count, - ) - - # Adjust minimum reserve thresholds proportionally for small pools. - # Keep at most 10% of total as reserve (but not higher than defaults). - # This avoids refusing small, reasonable allocations on modest clusters. - self.pool.min_available_cpu_cores = min( - self.pool.min_available_cpu_cores, - max(0.1 * self.pool.total_cpu_cores, 0.0), - ) - self.pool.min_available_memory_mb = min( - self.pool.min_available_memory_mb, - max(int(0.1 * self.pool.total_memory_mb), 0), - ) - - # Track allocations - self._allocations: Dict[str, ResourceAllocation] = {} - self._allocation_lock = asyncio.Lock() - - # Default allocations by language - self.default_allocations = { - "python": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "javascript": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "go": ResourceAllocation(cpu_cores=0.25, memory_mb=256), - "rust": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "java": ResourceAllocation(cpu_cores=1.0, memory_mb=1024), - "cpp": ResourceAllocation(cpu_cores=0.5, memory_mb=512), - "r": ResourceAllocation(cpu_cores=1.0, memory_mb=2048), - } - - # Update initial metrics - self._update_metrics() - - async def request_allocation( - self, - execution_id: str, - language: str, - requested_cpu: float | None = None, - requested_memory_mb: int | None = None, - requested_gpu: int = 0, - ) -> ResourceAllocation | None: - """ - Request resource allocation for execution - - Returns: - ResourceAllocation if successful, None if resources unavailable - """ - async with self._allocation_lock: - # Check if already allocated - if execution_id in self._allocations: - self.logger.warning(f"Execution {execution_id} already has allocation") - return self._allocations[execution_id] - - # Determine requested resources - if requested_cpu is None or requested_memory_mb is None: - # Use defaults based on language - default = self.default_allocations.get(language, ResourceAllocation(cpu_cores=0.5, memory_mb=512)) - requested_cpu = requested_cpu or default.cpu_cores - requested_memory_mb = requested_memory_mb or default.memory_mb - - # Apply limits - requested_cpu = min(requested_cpu, self.pool.max_cpu_per_execution) - requested_memory_mb = min(requested_memory_mb, self.pool.max_memory_per_execution_mb) - requested_gpu = min(requested_gpu, self.pool.max_gpu_per_execution) - - # Check availability (considering minimum reserves) - cpu_after = self.pool.available_cpu_cores - requested_cpu - memory_after = self.pool.available_memory_mb - requested_memory_mb - gpu_after = self.pool.available_gpu_count - requested_gpu - - if ( - cpu_after < self.pool.min_available_cpu_cores - or memory_after < self.pool.min_available_memory_mb - or gpu_after < 0 - ): - self.logger.warning( - f"Insufficient resources for execution {execution_id}. " - f"Requested: {requested_cpu} CPU, {requested_memory_mb}MB RAM, " - f"{requested_gpu} GPU. 
Available: {self.pool.available_cpu_cores} CPU, " - f"{self.pool.available_memory_mb}MB RAM, {self.pool.available_gpu_count} GPU" - ) - return None - - # Create allocation - allocation = ResourceAllocation( - cpu_cores=requested_cpu, memory_mb=requested_memory_mb, gpu_count=requested_gpu - ) - - # Update pool - self.pool.available_cpu_cores = cpu_after - self.pool.available_memory_mb = memory_after - self.pool.available_gpu_count = gpu_after - - # Track allocation - self._allocations[execution_id] = allocation - - # Update metrics - self._update_metrics() - - self.logger.info( - f"Allocated resources for execution {execution_id}: " - f"{allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM, " - f"{allocation.gpu_count} GPU" - ) - - return allocation - - async def release_allocation(self, execution_id: str) -> bool: - """Release resource allocation""" - async with self._allocation_lock: - if execution_id not in self._allocations: - self.logger.warning(f"No allocation found for execution {execution_id}") - return False - - allocation = self._allocations[execution_id] - - # Return resources to pool - self.pool.available_cpu_cores += allocation.cpu_cores - self.pool.available_memory_mb += allocation.memory_mb - self.pool.available_gpu_count += allocation.gpu_count - - # Remove allocation - del self._allocations[execution_id] - - # Update metrics - self._update_metrics() - - self.logger.info( - f"Released resources for execution {execution_id}: " - f"{allocation.cpu_cores} CPU, {allocation.memory_mb}MB RAM, " - f"{allocation.gpu_count} GPU" - ) - - return True - - async def get_allocation(self, execution_id: str) -> ResourceAllocation | None: - """Get current allocation for execution""" - async with self._allocation_lock: - return self._allocations.get(execution_id) - - async def can_allocate(self, cpu_cores: float, memory_mb: int, gpu_count: int = 0) -> bool: - """Check if resources can be allocated""" - async with self._allocation_lock: - cpu_after = self.pool.available_cpu_cores - cpu_cores - memory_after = self.pool.available_memory_mb - memory_mb - gpu_after = self.pool.available_gpu_count - gpu_count - - return ( - cpu_after >= self.pool.min_available_cpu_cores - and memory_after >= self.pool.min_available_memory_mb - and gpu_after >= 0 - ) - - async def get_resource_stats(self) -> ResourceStats: - """Get resource statistics""" - async with self._allocation_lock: - allocated_cpu = self.pool.total_cpu_cores - self.pool.available_cpu_cores - allocated_memory = self.pool.total_memory_mb - self.pool.available_memory_mb - allocated_gpu = self.pool.total_gpu_count - self.pool.available_gpu_count - - gpu_percent = (allocated_gpu / self.pool.total_gpu_count * 100) if self.pool.total_gpu_count > 0 else 0 - - return ResourceStats( - total=ResourceGroup( - cpu_cores=self.pool.total_cpu_cores, - memory_mb=self.pool.total_memory_mb, - gpu_count=self.pool.total_gpu_count, - ), - available=ResourceGroup( - cpu_cores=self.pool.available_cpu_cores, - memory_mb=self.pool.available_memory_mb, - gpu_count=self.pool.available_gpu_count, - ), - allocated=ResourceGroup(cpu_cores=allocated_cpu, memory_mb=allocated_memory, gpu_count=allocated_gpu), - utilization={ - "cpu_percent": (allocated_cpu / self.pool.total_cpu_cores * 100), - "memory_percent": (allocated_memory / self.pool.total_memory_mb * 100), - "gpu_percent": gpu_percent, - }, - allocation_count=len(self._allocations), - limits={ - "max_cpu_per_execution": self.pool.max_cpu_per_execution, - "max_memory_per_execution_mb": 
self.pool.max_memory_per_execution_mb, - "max_gpu_per_execution": self.pool.max_gpu_per_execution, - }, - ) - - async def get_allocations_by_resource_usage(self) -> List[ResourceAllocationInfo]: - """Get allocations sorted by resource usage""" - async with self._allocation_lock: - allocations = [] - for exec_id, allocation in self._allocations.items(): - allocations.append( - ResourceAllocationInfo( - execution_id=str(exec_id), - cpu_cores=allocation.cpu_cores, - memory_mb=allocation.memory_mb, - gpu_count=allocation.gpu_count, - cpu_percentage=(allocation.cpu_cores / self.pool.total_cpu_cores * 100), - memory_percentage=(allocation.memory_mb / self.pool.total_memory_mb * 100), - ) - ) - - # Sort by total resource usage - allocations.sort(key=lambda x: x.cpu_percentage + x.memory_percentage, reverse=True) - - return allocations - - def _update_metrics(self) -> None: - """Update metrics""" - cpu_usage = self.pool.total_cpu_cores - self.pool.available_cpu_cores - cpu_percent = cpu_usage / self.pool.total_cpu_cores * 100 - self.metrics.update_resource_usage("cpu", cpu_percent) - - memory_usage = self.pool.total_memory_mb - self.pool.available_memory_mb - memory_percent = memory_usage / self.pool.total_memory_mb * 100 - self.metrics.update_resource_usage("memory", memory_percent) - - gpu_usage = self.pool.total_gpu_count - self.pool.available_gpu_count - gpu_percent = gpu_usage / max(1, self.pool.total_gpu_count) * 100 - self.metrics.update_resource_usage("gpu", gpu_percent) - - self.metrics.update_coordinator_active_executions(len(self._allocations)) diff --git a/backend/app/services/idempotency/__init__.py b/backend/app/services/idempotency/__init__.py index fc82d7d3..3351c82d 100644 --- a/backend/app/services/idempotency/__init__.py +++ b/backend/app/services/idempotency/__init__.py @@ -4,9 +4,7 @@ IdempotencyKeyStrategy, IdempotencyManager, IdempotencyResult, - create_idempotency_manager, ) -from app.services.idempotency.middleware import IdempotentConsumerWrapper, IdempotentEventHandler __all__ = [ "IdempotencyManager", @@ -14,7 +12,4 @@ "IdempotencyResult", "IdempotencyStatus", "IdempotencyKeyStrategy", - "create_idempotency_manager", - "IdempotentEventHandler", - "IdempotentConsumerWrapper", ] diff --git a/backend/app/services/idempotency/faststream_middleware.py b/backend/app/services/idempotency/faststream_middleware.py new file mode 100644 index 00000000..caf620d5 --- /dev/null +++ b/backend/app/services/idempotency/faststream_middleware.py @@ -0,0 +1,81 @@ +"""FastStream middleware for idempotent event processing. + +Uses Dishka's request-scoped container to resolve dependencies per-message. +Must be added to broker AFTER setup_dishka() is called. +""" + +from collections.abc import Awaitable, Callable +from typing import Any + +from faststream import BaseMiddleware +from faststream.message import StreamMessage + +from app.domain.events.typed import DomainEvent +from app.events.schema.schema_registry import SchemaRegistryManager +from app.services.idempotency.idempotency_manager import IdempotencyManager + + +class IdempotencyMiddleware(BaseMiddleware): + """ + FastStream middleware providing idempotent message processing. + + Resolves IdempotencyManager and SchemaRegistryManager from Dishka's + request-scoped container (available via context after DishkaMiddleware runs). + + Flow: + 1. DishkaMiddleware.consume_scope creates request container in context + 2. This middleware's consume_scope resolves dependencies from container + 3. Decodes Avro message and checks idempotency + 4. 
Skips handler if duplicate, otherwise proceeds and marks result + """ + + async def consume_scope( + self, + call_next: Callable[[Any], Awaitable[Any]], + msg: StreamMessage[Any], + ) -> Any: + """Check idempotency before processing, mark completed/failed after.""" + # Get Dishka request container from context (set by DishkaMiddleware) + container = self.context.get_local("dishka") + if container is None: + # Dishka not set up or middleware order wrong - skip idempotency + return await call_next(msg) + + # Resolve dependencies from request-scoped container + try: + idempotency = await container.get(IdempotencyManager) + schema_registry = await container.get(SchemaRegistryManager) + except Exception: + # Dependencies not available - skip idempotency + return await call_next(msg) + + # Decode message to get event for idempotency check + body = msg.body + if not isinstance(body, bytes): + # Not Avro bytes - skip idempotency + return await call_next(msg) + + try: + event: DomainEvent = await schema_registry.deserialize_event(body, "idempotency") + except Exception: + # Failed to decode - let handler deal with it + return await call_next(msg) + + # Check idempotency + result = await idempotency.check_and_reserve( + event=event, + key_strategy="event_based", + ) + + if result.is_duplicate: + # Skip handler for duplicates + return None + + # Not a duplicate - proceed with processing + try: + handler_result = await call_next(msg) + await idempotency.mark_completed(event=event, key_strategy="event_based") + return handler_result + except Exception as e: + await idempotency.mark_failed(event=event, error=str(e), key_strategy="event_based") + raise diff --git a/backend/app/services/idempotency/idempotency_manager.py b/backend/app/services/idempotency/idempotency_manager.py index e30b6efe..00b3b71c 100644 --- a/backend/app/services/idempotency/idempotency_manager.py +++ b/backend/app/services/idempotency/idempotency_manager.py @@ -1,4 +1,3 @@ -import asyncio import hashlib import json import logging @@ -30,7 +29,6 @@ class IdempotencyConfig(BaseModel): processing_timeout_seconds: int = 300 enable_result_caching: bool = True max_result_size_bytes: int = 1048576 - enable_metrics: bool = True collection_name: str = "idempotency_keys" @@ -69,30 +67,15 @@ async def health_check(self) -> None: ... 
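A minimal wiring sketch for the IdempotencyMiddleware introduced above (illustrative, not part of this patch): its docstring requires Dishka to be set up on the app before the middleware is attached to the broker. The container construction, the setup_dishka call, and broker.add_middleware are assumptions about the pinned dishka/faststream versions, not APIs confirmed by this patch; only IdempotencyMiddleware itself comes from the diff.

# Illustrative only. Everything except IdempotencyMiddleware is an assumption:
# the Dishka FastStream integration entry point, the add_middleware hook, and
# the empty Provider standing in for this project's real providers.
from dishka import Provider, make_async_container
from dishka.integrations.faststream import setup_dishka  # assumed integration module
from faststream import FastStream
from faststream.kafka import KafkaBroker

from app.services.idempotency.faststream_middleware import IdempotencyMiddleware

broker = KafkaBroker("localhost:9092")
app = FastStream(broker)
container = make_async_container(Provider())  # real providers omitted in this sketch

# DishkaMiddleware must be installed first so each consumed message gets a
# request-scoped container in context...
setup_dishka(container, app)  # assumed signature of the FastStream integration

# ...and only then is the idempotency middleware added, per the docstring above.
broker.add_middleware(IdempotencyMiddleware)  # assumed broker hook for late registration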
class IdempotencyManager: def __init__( self, - config: IdempotencyConfig, repository: IdempotencyRepoProtocol, logger: logging.Logger, - database_metrics: DatabaseMetrics, + metrics: DatabaseMetrics, + config: IdempotencyConfig, ) -> None: - self.config = config - self.metrics = database_metrics - self._repo: IdempotencyRepoProtocol = repository - self._stats_update_task: asyncio.Task[None] | None = None + self._repo = repository self.logger = logger - - async def initialize(self) -> None: - if self.config.enable_metrics and self._stats_update_task is None: - self._stats_update_task = asyncio.create_task(self._update_stats_loop()) - self.logger.info("Idempotency manager ready") - - async def close(self) -> None: - if self._stats_update_task: - self._stats_update_task.cancel() - try: - await self._stats_update_task - except asyncio.CancelledError: - pass - self.logger.info("Closed idempotency manager") + self.metrics = metrics + self.config = config def _generate_key( self, event: BaseEvent, key_strategy: str, custom_key: str | None = None, fields: set[str] | None = None @@ -307,25 +290,3 @@ async def get_stats(self) -> IdempotencyStats: } total = sum(status_counts.values()) return IdempotencyStats(total_keys=total, status_counts=status_counts, prefix=self.config.key_prefix) - - async def _update_stats_loop(self) -> None: - while True: - try: - stats = await self.get_stats() - self.metrics.update_idempotency_keys_active(stats.total_keys, self.config.key_prefix) - await asyncio.sleep(60) - except asyncio.CancelledError: - break - except Exception as e: - self.logger.error(f"Failed to update idempotency stats: {e}") - await asyncio.sleep(300) - - -def create_idempotency_manager( - *, - repository: IdempotencyRepoProtocol, - config: IdempotencyConfig | None = None, - logger: logging.Logger, - database_metrics: DatabaseMetrics, -) -> IdempotencyManager: - return IdempotencyManager(config or IdempotencyConfig(), repository, logger, database_metrics) diff --git a/backend/app/services/idempotency/middleware.py b/backend/app/services/idempotency/middleware.py deleted file mode 100644 index a6fc772f..00000000 --- a/backend/app/services/idempotency/middleware.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Idempotent event processing middleware""" - -import asyncio -import logging -from collections.abc import Awaitable, Callable -from typing import Any - -from app.domain.enums.events import EventType -from app.domain.events.typed import DomainEvent -from app.events.core import EventDispatcher, UnifiedConsumer -from app.services.idempotency.idempotency_manager import IdempotencyManager - - -class IdempotentEventHandler: - """Wrapper for event handlers with idempotency support.""" - - def __init__( - self, - handler: Callable[[DomainEvent], Awaitable[None]], - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - key_strategy: str = "event_based", - custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: set[str] | None = None, - ttl_seconds: int | None = None, - on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, - ): - self.handler = handler - self.idempotency_manager = idempotency_manager - self.logger = logger - self.key_strategy = key_strategy - self.custom_key_func = custom_key_func - self.fields = fields - self.ttl_seconds = ttl_seconds - self.on_duplicate = on_duplicate - - async def __call__(self, event: DomainEvent) -> None: - custom_key = self.custom_key_func(event) if self.key_strategy == "custom" and self.custom_key_func else None - - result = await 
self.idempotency_manager.check_and_reserve( - event=event, - key_strategy=self.key_strategy, - custom_key=custom_key, - ttl_seconds=self.ttl_seconds, - fields=self.fields, - ) - - if result.is_duplicate: - self.logger.info(f"Duplicate event: {event.event_type} ({event.event_id})") - if self.on_duplicate: - if asyncio.iscoroutinefunction(self.on_duplicate): - await self.on_duplicate(event, result) - else: - await asyncio.to_thread(self.on_duplicate, event, result) - return - - try: - await self.handler(event) - await self.idempotency_manager.mark_completed( - event=event, key_strategy=self.key_strategy, custom_key=custom_key, fields=self.fields - ) - except Exception as e: - await self.idempotency_manager.mark_failed( - event=event, error=str(e), key_strategy=self.key_strategy, custom_key=custom_key, fields=self.fields - ) - raise - - -class IdempotentConsumerWrapper: - """Wrapper for UnifiedConsumer with automatic idempotency. - - Usage: - dispatcher = EventDispatcher() - dispatcher.register(EventType.FOO, handle_foo) - - consumer = UnifiedConsumer(..., dispatcher=dispatcher) - wrapper = IdempotentConsumerWrapper(consumer, dispatcher, idempotency_manager, ...) - await wrapper.run() # Handlers are wrapped with idempotency, then consumer runs - """ - - def __init__( - self, - consumer: UnifiedConsumer, - dispatcher: EventDispatcher, - idempotency_manager: IdempotencyManager, - logger: logging.Logger, - default_key_strategy: str = "event_based", - default_ttl_seconds: int = 3600, - enable_for_all_handlers: bool = True, - ): - self._consumer = consumer - self._dispatcher = dispatcher - self._idempotency_manager = idempotency_manager - self._logger = logger - self._default_key_strategy = default_key_strategy - self._default_ttl_seconds = default_ttl_seconds - self._enable_for_all_handlers = enable_for_all_handlers - - async def run(self) -> None: - """Wrap handlers with idempotency, then run consumer.""" - if self._enable_for_all_handlers: - self._wrap_handlers() - self._logger.info("IdempotentConsumerWrapper running") - await self._consumer.run() - - def _wrap_handlers(self) -> None: - """Wrap all registered handlers with idempotency.""" - original_handlers = self._dispatcher.get_all_handlers() - - for event_type, handlers in original_handlers.items(): - wrapped: list[Callable[[DomainEvent], Awaitable[None]]] = [ - IdempotentEventHandler( - handler=h, - idempotency_manager=self._idempotency_manager, - logger=self._logger, - key_strategy=self._default_key_strategy, - ttl_seconds=self._default_ttl_seconds, - ) - for h in handlers - ] - self._dispatcher.replace_handlers(event_type, wrapped) - - def register_idempotent_handler( - self, - event_type: EventType, - handler: Callable[[DomainEvent], Awaitable[None]], - key_strategy: str | None = None, - custom_key_func: Callable[[DomainEvent], str] | None = None, - fields: set[str] | None = None, - ttl_seconds: int | None = None, - on_duplicate: Callable[[DomainEvent, Any], Any] | None = None, - ) -> None: - """Register an idempotent handler for an event type.""" - wrapped = IdempotentEventHandler( - handler=handler, - idempotency_manager=self._idempotency_manager, - logger=self._logger, - key_strategy=key_strategy or self._default_key_strategy, - custom_key_func=custom_key_func, - fields=fields, - ttl_seconds=ttl_seconds or self._default_ttl_seconds, - on_duplicate=on_duplicate, - ) - self._dispatcher.register(event_type)(wrapped) diff --git a/backend/app/services/pod_monitor/event_mapper.py b/backend/app/services/pod_monitor/event_mapper.py index 
c34b530f..34a451e7 100644 --- a/backend/app/services/pod_monitor/event_mapper.py +++ b/backend/app/services/pod_monitor/event_mapper.py @@ -17,8 +17,8 @@ PodRunningEvent, PodScheduledEvent, PodTerminatedEvent, + ResourceUsageAvro, ) -from app.domain.execution import ResourceUsageDomain # Python 3.12 type aliases type PodPhase = str @@ -43,7 +43,7 @@ class PodLogs: stdout: str stderr: str exit_code: int - resource_usage: ResourceUsageDomain + resource_usage: ResourceUsageAvro class EventMapper(Protocol): @@ -498,7 +498,7 @@ def _try_parse_json(self, text: str) -> PodLogs | None: stdout=data.get("stdout", ""), stderr=data.get("stderr", ""), exit_code=data.get("exit_code", 0), - resource_usage=ResourceUsageDomain(**data.get("resource_usage", {})), + resource_usage=ResourceUsageAvro(**data.get("resource_usage", {})), ) def _log_extraction_error(self, pod_name: str, error: str) -> None: diff --git a/backend/app/services/result_processor/processor_logic.py b/backend/app/services/result_processor/processor_logic.py index 3bae92cb..1a61164e 100644 --- a/backend/app/services/result_processor/processor_logic.py +++ b/backend/app/services/result_processor/processor_logic.py @@ -15,7 +15,7 @@ ResultFailedEvent, ResultStoredEvent, ) -from app.domain.execution import ExecutionNotFoundError, ExecutionResultDomain +from app.domain.execution import ExecutionNotFoundError, ExecutionResultDomain, ResourceUsageDomainAdapter from app.events.core import EventDispatcher, UnifiedProducer from app.settings import Settings @@ -29,8 +29,6 @@ class ProcessorLogic: - Storing results in database - Publishing ResultStored/ResultFailed events - Recording metrics - - This class is stateful and must be instantiated once per processor instance. """ def __init__( @@ -99,7 +97,7 @@ async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: exit_code=event.exit_code, stdout=event.stdout, stderr=event.stderr, - resource_usage=event.resource_usage, + resource_usage=ResourceUsageDomainAdapter.validate_python(event.resource_usage), metadata=event.metadata.model_dump(), ) @@ -127,7 +125,7 @@ async def _handle_failed(self, event: ExecutionFailedEvent) -> None: exit_code=event.exit_code or -1, stdout=event.stdout, stderr=event.stderr, - resource_usage=event.resource_usage, + resource_usage=ResourceUsageDomainAdapter.validate_python(event.resource_usage), metadata=event.metadata.model_dump(), error_type=event.error_type, ) @@ -157,7 +155,7 @@ async def _handle_timeout(self, event: ExecutionTimeoutEvent) -> None: exit_code=-1, stdout=event.stdout, stderr=event.stderr, - resource_usage=event.resource_usage, + resource_usage=ResourceUsageDomainAdapter.validate_python(event.resource_usage), metadata=event.metadata.model_dump(), error_type=ExecutionErrorType.TIMEOUT, ) diff --git a/backend/app/services/saga/execution_saga.py b/backend/app/services/saga/execution_saga.py index 5cc430e2..2d229bde 100644 --- a/backend/app/services/saga/execution_saga.py +++ b/backend/app/services/saga/execution_saga.py @@ -3,7 +3,13 @@ from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository from app.domain.enums.events import EventType -from app.domain.events.typed import CreatePodCommandEvent, DeletePodCommandEvent, EventMetadata, ExecutionRequestedEvent +from app.domain.events.typed import ( + CreatePodCommandEvent, + DeletePodCommandEvent, + EventMetadata, + ExecutionAcceptedEvent, + ExecutionRequestedEvent, +) from app.domain.saga import DomainResourceAllocationCreate from app.events.core import 
UnifiedProducer @@ -53,6 +59,51 @@ def get_compensation(self) -> CompensationStep | None: return None +class AcceptExecutionStep(SagaStep[ExecutionRequestedEvent]): + """Publish ExecutionAcceptedEvent to confirm request is being processed.""" + + def __init__(self, producer: Optional[UnifiedProducer] = None) -> None: + super().__init__("accept_execution") + self.producer: UnifiedProducer | None = producer + + async def execute(self, context: SagaContext, event: ExecutionRequestedEvent) -> bool: + """Publish acceptance event.""" + try: + execution_id = context.get("execution_id") + logger.info(f"Publishing ExecutionAcceptedEvent for {execution_id}") + + if not self.producer: + raise RuntimeError("Producer dependency not injected") + + accepted_event = ExecutionAcceptedEvent( + execution_id=execution_id, + queue_position=0, + estimated_wait_seconds=None, + priority=event.priority, + metadata=EventMetadata( + service_name="saga-orchestrator", + service_version="1.0.0", + user_id=event.metadata.user_id, + correlation_id=event.metadata.correlation_id, + ), + ) + + await self.producer.produce(event_to_produce=accepted_event) + context.set("accepted", True) + logger.info(f"ExecutionAcceptedEvent published for {execution_id}") + + return True + + except Exception as e: + logger.error(f"Failed to publish acceptance: {e}") + context.set_error(e) + return False + + def get_compensation(self) -> CompensationStep | None: + """No compensation needed - acceptance is just a notification.""" + return None + + class AllocateResourcesStep(SagaStep[ExecutionRequestedEvent]): """Allocate resources for execution""" @@ -347,6 +398,7 @@ def get_steps(self) -> list[SagaStep[Any]]: publish_commands = bool(getattr(self, "_publish_commands", False)) return [ ValidateExecutionStep(), + AcceptExecutionStep(producer=producer), AllocateResourcesStep(alloc_repo=alloc_repo), QueueExecutionStep(), CreatePodStep(producer=producer, publish_commands=publish_commands), diff --git a/backend/app/services/saga/saga_logic.py b/backend/app/services/saga/saga_logic.py index caed3457..ed1476ce 100644 --- a/backend/app/services/saga/saga_logic.py +++ b/backend/app/services/saga/saga_logic.py @@ -296,26 +296,29 @@ async def _fail_saga(self, instance: Saga, error_message: str) -> None: self.logger.error(f"Saga {instance.saga_id} failed: {error_message}") + async def check_timeouts_once(self) -> None: + """Check for saga timeouts (single check).""" + cutoff_time = datetime.now(UTC) - timedelta(seconds=self.config.timeout_seconds) + + timed_out = await self._repo.find_timed_out_sagas(cutoff_time) + + for instance in timed_out: + self.logger.warning(f"Saga {instance.saga_id} timed out") + + instance.state = SagaState.TIMEOUT + instance.error_message = f"Saga timed out after {self.config.timeout_seconds} seconds" + instance.completed_at = datetime.now(UTC) + + await self._save_saga(instance) + self._running_instances.pop(instance.saga_id, None) + async def check_timeouts_loop(self) -> None: """Check for saga timeouts (runs until cancelled).""" try: while True: # Check every 30 seconds await asyncio.sleep(30) - - cutoff_time = datetime.now(UTC) - timedelta(seconds=self.config.timeout_seconds) - - timed_out = await self._repo.find_timed_out_sagas(cutoff_time) - - for instance in timed_out: - self.logger.warning(f"Saga {instance.saga_id} timed out") - - instance.state = SagaState.TIMEOUT - instance.error_message = f"Saga timed out after {self.config.timeout_seconds} seconds" - instance.completed_at = datetime.now(UTC) - - await 
self._save_saga(instance) - self._running_instances.pop(instance.saga_id, None) + await self.check_timeouts_once() except asyncio.CancelledError: self.logger.info("Timeout checker cancelled") diff --git a/backend/pyproject.toml b/backend/pyproject.toml index febd8c01..7377b600 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -23,12 +23,12 @@ dependencies = [ "charset-normalizer==3.4.0", "click==8.1.7", "ConfigArgParse==1.7.1", - "aiokafka==0.13.0", + "aiokafka>=0.12.0,<0.13", "python-schema-registry-client==2.6.1", "contourpy==1.3.3", "cycler==0.12.1", "Deprecated==1.2.14", - "dishka==1.6.0", + "dishka==1.7.2", "dnspython==2.7.0", "durationpy==0.9", "email-validator==2.3.0", @@ -89,7 +89,7 @@ dependencies = [ "pyasn1==0.6.1", "pyasn1_modules==0.4.2", "pydantic==2.9.2", - "pydantic-avro==0.9.1", + "dataclasses-avroschema[pydantic]>=0.65.0", "pydantic-settings==2.5.2", "pydantic_core==2.23.4", "Pygments==2.19.2", @@ -124,6 +124,7 @@ dependencies = [ "yarl==1.20.1", "zipp==3.20.2", "monggregate==0.22.1", + "faststream[kafka]>=0.6.0", ] [build-system] diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index e188a3b7..329ba48c 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -95,12 +95,6 @@ async def idempotency_manager( processing_timeout_seconds=5, enable_result_caching=True, max_result_size_bytes=1024, - enable_metrics=False, ) repo = RedisIdempotencyRepository(redis_client, key_prefix=prefix) - mgr = IdempotencyManager(cfg, repo, _test_logger, database_metrics=database_metrics) - await mgr.initialize() - try: - yield mgr - finally: - await mgr.close() + yield IdempotencyManager(repository=repo, logger=_test_logger, metrics=database_metrics, config=cfg) diff --git a/backend/tests/integration/idempotency/test_consumer_idempotent.py b/backend/tests/integration/idempotency/test_consumer_idempotent.py deleted file mode 100644 index fe8dfc5a..00000000 --- a/backend/tests/integration/idempotency/test_consumer_idempotent.py +++ /dev/null @@ -1,90 +0,0 @@ -import asyncio -import logging - -import pytest -from app.core.metrics import EventMetrics -from app.domain.enums.events import EventType -from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer -from app.events.core.dispatcher import EventDispatcher as Disp -from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency.idempotency_manager import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper -from app.settings import Settings -from dishka import AsyncContainer - -from tests.helpers import make_execution_requested_event - -# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers -# instantiate Consumer() objects simultaneously. Serial execution prevents this. 
-pytestmark = [ - pytest.mark.integration, - pytest.mark.kafka, - pytest.mark.redis, - pytest.mark.xdist_group("kafka_consumers"), -] - -_test_logger = logging.getLogger("test.idempotency.consumer_idempotent") - - -@pytest.mark.asyncio -async def test_consumer_idempotent_wrapper_blocks_duplicates( - scope: AsyncContainer, - schema_registry: SchemaRegistryManager, - event_metrics: EventMetrics, - consumer_config: ConsumerConfig, - test_settings: Settings, -) -> None: - producer: UnifiedProducer = await scope.get(UnifiedProducer) - idm: IdempotencyManager = await scope.get(IdempotencyManager) - - # Future resolves when handler processes an event - no polling needed - handled_future: asyncio.Future[None] = asyncio.get_running_loop().create_future() - seen = {"n": 0} - - # Build a dispatcher that signals completion via future - disp: Disp = EventDispatcher(logger=_test_logger) - - @disp.register(EventType.EXECUTION_REQUESTED) - async def handle(_ev: DomainEvent) -> None: - seen["n"] += 1 - if not handled_future.done(): - handled_future.set_result(None) - - # Produce messages BEFORE starting consumer (auto_offset_reset="earliest" will read them) - execution_id = f"e-{consumer_config.group_id}" - ev = make_execution_requested_event(execution_id=execution_id) - await producer.produce(ev, key=execution_id) - await producer.produce(ev, key=execution_id) - - # Real consumer with idempotent wrapper - base = UnifiedConsumer( - consumer_config, - dispatcher=disp, - schema_registry=schema_registry, - settings=test_settings, - logger=_test_logger, - event_metrics=event_metrics, - topics=[KafkaTopic.EXECUTION_EVENTS], - ) - wrapper = IdempotentConsumerWrapper( - consumer=base, - idempotency_manager=idm, - dispatcher=disp, - default_key_strategy="event_based", - enable_for_all_handlers=True, - logger=_test_logger, - ) - - # Start wrapper as background task - wrapper_task = asyncio.create_task(wrapper.run()) - - try: - # Await the future directly - true async, no polling - await asyncio.wait_for(handled_future, timeout=10.0) - assert seen["n"] >= 1 - finally: - wrapper_task.cancel() - with pytest.raises(asyncio.CancelledError): - await wrapper_task diff --git a/backend/tests/integration/idempotency/test_idempotency.py b/backend/tests/integration/idempotency/test_idempotency.py index 25f60111..ddee8562 100644 --- a/backend/tests/integration/idempotency/test_idempotency.py +++ b/backend/tests/integration/idempotency/test_idempotency.py @@ -1,18 +1,11 @@ import asyncio import json import logging -import uuid from datetime import datetime, timedelta, timezone -from typing import Any import pytest -import redis.asyncio as redis -from app.core.metrics import DatabaseMetrics -from app.domain.events.typed import DomainEvent from app.domain.idempotency import IdempotencyRecord, IdempotencyStatus -from app.services.idempotency.idempotency_manager import IdempotencyConfig, IdempotencyManager -from app.services.idempotency.middleware import IdempotentEventHandler -from app.services.idempotency.redis_repository import RedisIdempotencyRepository +from app.services.idempotency.idempotency_manager import IdempotencyManager from tests.helpers import make_execution_requested_event @@ -231,163 +224,8 @@ async def test_remove_key(self, idempotency_manager: IdempotencyManager) -> None assert result2.is_duplicate is False -class TestIdempotentEventHandlerIntegration: - """Test IdempotentEventHandler with real components""" - - @pytest.mark.asyncio - async def test_handler_processes_new_event(self, idempotency_manager: 
IdempotencyManager) -> None: - """Test that handler processes new events""" - processed_events: list[DomainEvent] = [] - - async def actual_handler(event: DomainEvent) -> None: - processed_events.append(event) - - # Create idempotent handler - handler = IdempotentEventHandler( - handler=actual_handler, - idempotency_manager=idempotency_manager, - key_strategy="event_based", - logger=_test_logger, - ) - - # Process event - real_event = make_execution_requested_event(execution_id="handler-test-123") - await handler(real_event) - - # Verify event was processed - assert len(processed_events) == 1 - assert processed_events[0] == real_event - - @pytest.mark.asyncio - async def test_handler_blocks_duplicate(self, idempotency_manager: IdempotencyManager) -> None: - """Test that handler blocks duplicate events""" - processed_events: list[DomainEvent] = [] - - async def actual_handler(event: DomainEvent) -> None: - processed_events.append(event) - - # Create idempotent handler - handler = IdempotentEventHandler( - handler=actual_handler, - idempotency_manager=idempotency_manager, - key_strategy="event_based", - logger=_test_logger, - ) - - # Process event twice - real_event = make_execution_requested_event(execution_id="handler-dup-123") - await handler(real_event) - await handler(real_event) - - # Verify event was processed only once - assert len(processed_events) == 1 - - @pytest.mark.asyncio - async def test_handler_with_failure(self, idempotency_manager: IdempotencyManager) -> None: - """Test handler marks failure on exception""" - - async def failing_handler(event: DomainEvent) -> None: # noqa: ARG001 - raise ValueError("Processing failed") - - handler = IdempotentEventHandler( - handler=failing_handler, - idempotency_manager=idempotency_manager, - key_strategy="event_based", - logger=_test_logger, - ) - - # Process event (should raise) - real_event = make_execution_requested_event(execution_id="handler-fail-1") - with pytest.raises(ValueError, match="Processing failed"): - await handler(real_event) - - # Verify marked as failed - key = f"{idempotency_manager.config.key_prefix}:{real_event.event_type}:{real_event.event_id}" - record = await idempotency_manager._repo.find_by_key(key) - assert record is not None - assert record.status == IdempotencyStatus.FAILED - assert record.error is not None - assert "Processing failed" in record.error - - @pytest.mark.asyncio - async def test_handler_duplicate_callback(self, idempotency_manager: IdempotencyManager) -> None: - """Test duplicate callback is invoked""" - duplicate_events: list[tuple[DomainEvent, Any]] = [] - - async def actual_handler(event: DomainEvent) -> None: # noqa: ARG001 - pass # Do nothing - - async def on_duplicate(event: DomainEvent, result: Any) -> None: - duplicate_events.append((event, result)) - - handler = IdempotentEventHandler( - handler=actual_handler, - idempotency_manager=idempotency_manager, - key_strategy="event_based", - on_duplicate=on_duplicate, - logger=_test_logger, - ) - - # Process twice - real_event = make_execution_requested_event(execution_id="handler-dup-cb-1") - await handler(real_event) - await handler(real_event) - - # Verify duplicate callback was called - assert len(duplicate_events) == 1 - assert duplicate_events[0][0] == real_event - assert duplicate_events[0][1].is_duplicate is True - - @pytest.mark.asyncio - async def test_custom_key_function(self, idempotency_manager: IdempotencyManager) -> None: - """Test handler with custom key function""" - processed_scripts: list[str] = [] - - async def 
process_script(event: DomainEvent) -> None: - script: str = getattr(event, "script", "") - processed_scripts.append(script) - - def extract_script_key(event: DomainEvent) -> str: - # Custom key based on script content only - script: str = getattr(event, "script", "") - return f"script:{hash(script)}" - - handler = IdempotentEventHandler( - handler=process_script, - idempotency_manager=idempotency_manager, - key_strategy="custom", - custom_key_func=extract_script_key, - logger=_test_logger, - ) - - # Events with same script - event1 = make_execution_requested_event( - execution_id="id1", - script="print('hello')", - service_name="test-service", - ) - - event2 = make_execution_requested_event( - execution_id="id2", - language="python", - language_version="3.9", # Different version - runtime_image="python:3.9-slim", - runtime_command=("python",), - runtime_filename="main.py", - timeout_seconds=60, # Different timeout - cpu_limit="200m", - memory_limit="256Mi", - cpu_request="100m", - memory_request="128Mi", - service_name="test-service", - ) - - await handler(event1) - await handler(event2) - - # Should only process once (same script) - assert len(processed_scripts) == 1 - assert processed_scripts[0] == "print('hello')" +class TestIdempotencyManagerValidation: + """Test IdempotencyManager validation and edge cases""" @pytest.mark.asyncio async def test_invalid_key_strategy(self, idempotency_manager: IdempotencyManager) -> None: @@ -445,20 +283,6 @@ async def test_cleanup_expired_keys(self, idempotency_manager: IdempotencyManage record = await idempotency_manager._repo.find_by_key(expired_key) assert record is not None # Still exists until explicit cleanup - @pytest.mark.asyncio - async def test_metrics_enabled(self, redis_client: redis.Redis, database_metrics: DatabaseMetrics) -> None: - """Test manager with metrics enabled""" - config = IdempotencyConfig(key_prefix=f"metrics:{uuid.uuid4().hex[:6]}", enable_metrics=True) - repository = RedisIdempotencyRepository(redis_client, key_prefix=config.key_prefix) - manager = IdempotencyManager(config, repository, _test_logger, database_metrics=database_metrics) - - # Initialize with metrics - await manager.initialize() - assert manager._stats_update_task is not None - - # Cleanup - await manager.close() - @pytest.mark.asyncio async def test_content_hash_with_fields(self, idempotency_manager: IdempotencyManager) -> None: """Test content hash with specific fields""" diff --git a/backend/tests/integration/idempotency/test_idempotent_handler.py b/backend/tests/integration/idempotency/test_idempotent_handler.py deleted file mode 100644 index c7ef5730..00000000 --- a/backend/tests/integration/idempotency/test_idempotent_handler.py +++ /dev/null @@ -1,62 +0,0 @@ -import logging - -import pytest -from app.domain.events.typed import DomainEvent -from app.services.idempotency.idempotency_manager import IdempotencyManager -from app.services.idempotency.middleware import IdempotentEventHandler -from dishka import AsyncContainer - -from tests.helpers import make_execution_requested_event - -pytestmark = [pytest.mark.integration] - -_test_logger = logging.getLogger("test.idempotency.idempotent_handler") - - -@pytest.mark.asyncio -async def test_idempotent_handler_blocks_duplicates(scope: AsyncContainer) -> None: - manager: IdempotencyManager = await scope.get(IdempotencyManager) - - processed: list[str | None] = [] - - async def _handler(ev: DomainEvent) -> None: - processed.append(ev.event_id) - - handler = IdempotentEventHandler( - handler=_handler, - 
idempotency_manager=manager, - key_strategy="event_based", - logger=_test_logger, - ) - - ev = make_execution_requested_event(execution_id="exec-dup-1") - - await handler(ev) - await handler(ev) # duplicate - - assert processed == [ev.event_id] - - -@pytest.mark.asyncio -async def test_idempotent_handler_content_hash_blocks_same_content(scope: AsyncContainer) -> None: - manager: IdempotencyManager = await scope.get(IdempotencyManager) - - processed: list[str] = [] - - async def _handler(ev: DomainEvent) -> None: - processed.append(getattr(ev, "execution_id", "")) - - handler = IdempotentEventHandler( - handler=_handler, - idempotency_manager=manager, - key_strategy="content_hash", - logger=_test_logger, - ) - - e1 = make_execution_requested_event(execution_id="exec-dup-2") - e2 = make_execution_requested_event(execution_id="exec-dup-2") - - await handler(e1) - await handler(e2) - - assert processed == [e1.execution_id] diff --git a/backend/tests/integration/result_processor/test_result_processor.py b/backend/tests/integration/result_processor/test_result_processor.py index df31989a..445c1588 100644 --- a/backend/tests/integration/result_processor/test_result_processor.py +++ b/backend/tests/integration/result_processor/test_result_processor.py @@ -9,14 +9,11 @@ from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus from app.domain.enums.kafka import GroupId, KafkaTopic -from app.domain.events.typed import EventMetadata, ExecutionCompletedEvent, ResultStoredEvent +from app.domain.events.typed import EventMetadata, ExecutionCompletedEvent, ResourceUsageAvro, ResultStoredEvent from app.domain.execution import DomainExecutionCreate -from app.domain.execution.models import ResourceUsageDomain from app.events.core import ConsumerConfig, UnifiedConsumer, UnifiedProducer from app.events.core.dispatcher import EventDispatcher from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.idempotency import IdempotencyManager -from app.services.idempotency.middleware import IdempotentConsumerWrapper from app.services.result_processor import ProcessorLogic from app.settings import Settings from dishka import AsyncContainer @@ -38,7 +35,6 @@ async def test_result_processor_persists_and_emits( scope: AsyncContainer, schema_registry: SchemaRegistryManager, event_metrics: EventMetrics, - consumer_config: ConsumerConfig, test_settings: Settings, ) -> None: # Ensure schemas @@ -49,7 +45,6 @@ async def test_result_processor_persists_and_emits( db: Database = await scope.get(Database) repo: ExecutionRepository = await scope.get(ExecutionRepository) producer: UnifiedProducer = await scope.get(UnifiedProducer) - idem: IdempotencyManager = await scope.get(IdempotencyManager) # Create a base execution to satisfy ProcessorLogic lookup created = await repo.create_execution(DomainExecutionCreate( @@ -87,7 +82,7 @@ async def test_result_processor_persists_and_emits( request_timeout_ms=test_settings.KAFKA_REQUEST_TIMEOUT_MS, ) - # Create processor consumer + # Create processor consumer (idempotency is now handled by FastStream middleware in production) processor_consumer = UnifiedConsumer( processor_consumer_config, dispatcher=processor_dispatcher, @@ -98,17 +93,6 @@ async def test_result_processor_persists_and_emits( topics=[KafkaTopic.EXECUTION_COMPLETED, KafkaTopic.EXECUTION_FAILED, KafkaTopic.EXECUTION_TIMEOUT], ) - # Wrap with idempotency - processor_wrapper = IdempotentConsumerWrapper( - consumer=processor_consumer, 
- dispatcher=processor_dispatcher, - idempotency_manager=idem, - logger=_test_logger, - default_key_strategy="content_hash", - default_ttl_seconds=7200, - enable_for_all_handlers=True, - ) - # Setup a separate consumer to capture ResultStoredEvent stored_dispatcher = EventDispatcher(logger=_test_logger) stored_received = asyncio.Event() @@ -137,7 +121,7 @@ async def _stored(event: ResultStoredEvent) -> None: ) # Produce the event BEFORE starting consumers (auto_offset_reset="earliest" will read it) - usage = ResourceUsageDomain( + usage = ResourceUsageAvro( execution_time_wall_seconds=0.5, cpu_time_jiffies=100, clk_tck_hertz=100, @@ -154,7 +138,7 @@ async def _stored(event: ResultStoredEvent) -> None: await producer.produce(evt, key=execution_id) # Start consumers as background tasks - processor_task = asyncio.create_task(processor_wrapper.run()) + processor_task = asyncio.create_task(processor_consumer.run()) stored_task = asyncio.create_task(stored_consumer.run()) try: diff --git a/backend/tests/integration/services/coordinator/test_execution_coordinator.py b/backend/tests/integration/services/coordinator/test_execution_coordinator.py deleted file mode 100644 index c45e300f..00000000 --- a/backend/tests/integration/services/coordinator/test_execution_coordinator.py +++ /dev/null @@ -1,18 +0,0 @@ -import pytest -from app.services.coordinator.coordinator_logic import CoordinatorLogic -from dishka import AsyncContainer -from tests.helpers import make_execution_requested_event - -pytestmark = pytest.mark.integration - - -@pytest.mark.asyncio -async def test_handle_requested_and_schedule(scope: AsyncContainer) -> None: - logic: CoordinatorLogic = await scope.get(CoordinatorLogic) - ev = make_execution_requested_event(execution_id="e-real-1") - - # Handler now schedules immediately - no polling needed - await logic._handle_execution_requested(ev) # noqa: SLF001 - - # Execution should be active immediately after handler returns - assert "e-real-1" in logic._active_executions # noqa: SLF001 diff --git a/backend/tests/unit/services/coordinator/test_queue_manager.py b/backend/tests/unit/services/coordinator/test_queue_manager.py deleted file mode 100644 index ebec3a6b..00000000 --- a/backend/tests/unit/services/coordinator/test_queue_manager.py +++ /dev/null @@ -1,37 +0,0 @@ -import logging - -import pytest -from app.core.metrics import CoordinatorMetrics -from app.domain.events.typed import ExecutionRequestedEvent -from app.services.coordinator.queue_manager import QueueManager, QueuePriority - -from tests.helpers import make_execution_requested_event - -_test_logger = logging.getLogger("test.services.coordinator.queue_manager") - -pytestmark = pytest.mark.unit - - -def ev(execution_id: str, priority: int = QueuePriority.NORMAL.value) -> ExecutionRequestedEvent: - return make_execution_requested_event(execution_id=execution_id, priority=priority) - - -@pytest.mark.asyncio -async def test_requeue_execution_increments_priority(coordinator_metrics: CoordinatorMetrics) -> None: - qm = QueueManager(max_queue_size=10, logger=_test_logger, coordinator_metrics=coordinator_metrics) - # Use NORMAL priority which can be incremented to LOW - e = ev("x", priority=QueuePriority.NORMAL.value) - await qm.add_execution(e) - await qm.requeue_execution(e, increment_retry=True) - nxt = await qm.get_next_execution() - assert nxt is not None - - -@pytest.mark.asyncio -async def test_queue_stats_empty_and_after_add(coordinator_metrics: CoordinatorMetrics) -> None: - qm = QueueManager(max_queue_size=5, 
logger=_test_logger, coordinator_metrics=coordinator_metrics) - stats0 = await qm.get_queue_stats() - assert stats0["total_size"] == 0 - await qm.add_execution(ev("a")) - st = await qm.get_queue_stats() - assert st["total_size"] == 1 diff --git a/backend/tests/unit/services/coordinator/test_resource_manager.py b/backend/tests/unit/services/coordinator/test_resource_manager.py deleted file mode 100644 index 4f579e45..00000000 --- a/backend/tests/unit/services/coordinator/test_resource_manager.py +++ /dev/null @@ -1,70 +0,0 @@ -import logging - -import pytest -from app.core.metrics import CoordinatorMetrics -from app.services.coordinator.resource_manager import ResourceManager - -_test_logger = logging.getLogger("test.services.coordinator.resource_manager") - - -@pytest.mark.asyncio -async def test_request_allocation_defaults_and_limits(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager( - total_cpu_cores=8.0, total_memory_mb=16384, total_gpu_count=0, - logger=_test_logger, coordinator_metrics=coordinator_metrics - ) - - # Default for python - alloc = await rm.request_allocation("e1", "python") - assert alloc is not None - - assert alloc.cpu_cores > 0 - assert alloc.memory_mb > 0 - - # Respect per-exec max cap - alloc2 = await rm.request_allocation("e2", "python", requested_cpu=100.0, requested_memory_mb=999999) - assert alloc2 is not None - assert alloc2.cpu_cores <= rm.pool.max_cpu_per_execution - assert alloc2.memory_mb <= rm.pool.max_memory_per_execution_mb - - -@pytest.mark.asyncio -async def test_release_and_can_allocate(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager( - total_cpu_cores=4.0, total_memory_mb=8192, total_gpu_count=0, - logger=_test_logger, coordinator_metrics=coordinator_metrics - ) - - a = await rm.request_allocation("e1", "python", requested_cpu=1.0, requested_memory_mb=512) - assert a is not None - - ok = await rm.release_allocation("e1") - assert ok is True - - # After release, can allocate near limits while preserving headroom. - # Use a tiny epsilon to avoid edge rounding issues in >= comparisons. 
- epsilon_cpu = 1e-6 - epsilon_mem = 1 - can = await rm.can_allocate(cpu_cores=rm.pool.total_cpu_cores - rm.pool.min_available_cpu_cores - epsilon_cpu, - memory_mb=rm.pool.total_memory_mb - rm.pool.min_available_memory_mb - epsilon_mem, - gpu_count=0) - assert can is True - - -@pytest.mark.asyncio -async def test_resource_stats(coordinator_metrics: CoordinatorMetrics) -> None: - rm = ResourceManager( - total_cpu_cores=2.0, total_memory_mb=4096, total_gpu_count=0, - logger=_test_logger, coordinator_metrics=coordinator_metrics - ) - # Make sure the allocation succeeds - alloc = await rm.request_allocation("e1", "python", requested_cpu=0.5, requested_memory_mb=256) - assert alloc is not None, "Allocation should have succeeded" - - stats = await rm.get_resource_stats() - - assert stats.total.cpu_cores > 0 - assert stats.available.cpu_cores >= 0 - assert stats.allocated.cpu_cores > 0 # Should be > 0 since we allocated - assert stats.utilization["cpu_percent"] >= 0 - assert stats.allocation_count >= 1 # Should be at least 1 (may have system allocations) diff --git a/backend/tests/unit/services/idempotency/test_idempotency_manager.py b/backend/tests/unit/services/idempotency/test_idempotency_manager.py index ef4676fb..1a7c7f53 100644 --- a/backend/tests/unit/services/idempotency/test_idempotency_manager.py +++ b/backend/tests/unit/services/idempotency/test_idempotency_manager.py @@ -64,7 +64,6 @@ def test_default_config(self) -> None: assert config.processing_timeout_seconds == 300 assert config.enable_result_caching is True assert config.max_result_size_bytes == 1048576 - assert config.enable_metrics is True assert config.collection_name == "idempotency_keys" def test_custom_config(self) -> None: @@ -74,7 +73,6 @@ def test_custom_config(self) -> None: processing_timeout_seconds=600, enable_result_caching=False, max_result_size_bytes=2048, - enable_metrics=False, collection_name="custom_keys", ) assert config.key_prefix == "custom" @@ -82,13 +80,13 @@ def test_custom_config(self) -> None: assert config.processing_timeout_seconds == 600 assert config.enable_result_caching is False assert config.max_result_size_bytes == 2048 - assert config.enable_metrics is False assert config.collection_name == "custom_keys" def test_manager_generate_key_variants(database_metrics: DatabaseMetrics) -> None: repo = MagicMock() - mgr = IdempotencyManager(IdempotencyConfig(), repo, _test_logger, database_metrics=database_metrics) + config = IdempotencyConfig() + mgr = IdempotencyManager(repository=repo, logger=_test_logger, metrics=database_metrics, config=config) ev = MagicMock(spec=BaseEvent) ev.event_type = "t" ev.event_id = "e" diff --git a/backend/tests/unit/services/idempotency/test_middleware.py b/backend/tests/unit/services/idempotency/test_middleware.py deleted file mode 100644 index 4d0e6b2f..00000000 --- a/backend/tests/unit/services/idempotency/test_middleware.py +++ /dev/null @@ -1,121 +0,0 @@ -import logging -from unittest.mock import AsyncMock, MagicMock - -import pytest -from app.domain.events.typed import DomainEvent -from app.domain.idempotency import IdempotencyStatus -from app.services.idempotency.idempotency_manager import IdempotencyManager, IdempotencyResult -from app.services.idempotency.middleware import ( - IdempotentEventHandler, -) - -_test_logger = logging.getLogger("test.services.idempotency.middleware") - - -pytestmark = pytest.mark.unit - - -class TestIdempotentEventHandler: - @pytest.fixture - def mock_idempotency_manager(self) -> AsyncMock: - return AsyncMock(spec=IdempotencyManager) 
- - @pytest.fixture - def mock_handler(self) -> AsyncMock: - handler = AsyncMock() - handler.__name__ = "test_handler" - return handler - - @pytest.fixture - def event(self) -> MagicMock: - event = MagicMock(spec=DomainEvent) - event.event_type = "test.event" - event.event_id = "event-123" - return event - - @pytest.fixture - def idempotent_event_handler( - self, mock_handler: AsyncMock, mock_idempotency_manager: AsyncMock - ) -> IdempotentEventHandler: - return IdempotentEventHandler( - handler=mock_handler, - idempotency_manager=mock_idempotency_manager, - key_strategy="event_based", - ttl_seconds=3600, - logger=_test_logger - ) - - @pytest.mark.asyncio - async def test_call_with_fields( - self, mock_handler: AsyncMock, mock_idempotency_manager: AsyncMock, event: MagicMock - ) -> None: - # Setup with specific fields - fields = {"field1", "field2"} - - handler = IdempotentEventHandler( - handler=mock_handler, - idempotency_manager=mock_idempotency_manager, - key_strategy="content_hash", - fields=fields, - logger=_test_logger - ) - - idempotency_result = IdempotencyResult( - is_duplicate=False, - status=IdempotencyStatus.PROCESSING, - created_at=MagicMock(), - key="test-key" - ) - mock_idempotency_manager.check_and_reserve.return_value = idempotency_result - - # Execute - await handler(event) - - # Verify - mock_idempotency_manager.check_and_reserve.assert_called_once_with( - event=event, - key_strategy="content_hash", - custom_key=None, - ttl_seconds=None, - fields=fields - ) - - @pytest.mark.asyncio - async def test_call_handler_exception( - self, - idempotent_event_handler: IdempotentEventHandler, - mock_idempotency_manager: AsyncMock, - mock_handler: AsyncMock, - event: MagicMock, - ) -> None: - # Setup: Handler raises exception - idempotency_result = IdempotencyResult( - is_duplicate=False, - status=IdempotencyStatus.PROCESSING, - created_at=MagicMock(), - key="test-key" - ) - mock_idempotency_manager.check_and_reserve.return_value = idempotency_result - mock_handler.side_effect = Exception("Handler error") - - # Execute and verify exception is raised - with pytest.raises(Exception, match="Handler error"): - await idempotent_event_handler(event) - - # Verify failure is marked - mock_idempotency_manager.mark_failed.assert_called_once_with( - event=event, - error="Handler error", - key_strategy="event_based", - custom_key=None, - fields=None - ) - - # Duplicate handler and custom key behavior covered by integration tests - - -class TestIdempotentHandlerDecorator: - pass - -class TestIdempotentConsumerWrapper: - pass diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index bfbf5e50..ab063129 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -8,8 +8,13 @@ from app.core.k8s_clients import K8sClients from app.core.metrics import EventMetrics, KubernetesMetrics from app.db.repositories.event_repository import EventRepository -from app.domain.events.typed import DomainEvent, EventMetadata, ExecutionCompletedEvent, ExecutionStartedEvent -from app.domain.execution.models import ResourceUsageDomain +from app.domain.events.typed import ( + DomainEvent, + EventMetadata, + ExecutionCompletedEvent, + ExecutionStartedEvent, + ResourceUsageAvro, +) from app.events.core import UnifiedProducer from app.services.kafka_event_service import KafkaEventService from app.services.pod_monitor.config import PodMonitorConfig @@ -400,7 +405,7 @@ async def 
test_publish_event_full_flow( execution_id="exec1", aggregate_id="exec1", exit_code=0, - resource_usage=ResourceUsageDomain(), + resource_usage=ResourceUsageAvro(), metadata=EventMetadata(service_name="test", service_version="1.0"), ) diff --git a/backend/uv.lock b/backend/uv.lock index 29f26cae..b058c135 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -102,33 +102,27 @@ wheels = [ [[package]] name = "aiokafka" -version = "0.13.0" +version = "0.12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "async-timeout" }, { name = "packaging" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/87/18/d3a4f8f9ad099fc59217b8cdf66eeecde3a9ef3bb31fe676e431a3b0010f/aiokafka-0.13.0.tar.gz", hash = "sha256:7d634af3c8d694a37a6c8535c54f01a740e74cccf7cc189ecc4a3d64e31ce122", size = 598580, upload-time = "2026-01-02T13:55:18.911Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/60/17/715ac23b4f8df3ff8d7c0a6f1c5fd3a179a8a675205be62d1d1bb27dffa2/aiokafka-0.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:231ecc0038c2736118f1c95149550dbbdf7b7a12069f70c005764fa1824c35d4", size = 346168, upload-time = "2026-01-02T13:54:49.128Z" }, - { url = "https://files.pythonhosted.org/packages/00/26/71c6f4cce2c710c6ffa18b9e294384157f46b0491d5b020de300802d167e/aiokafka-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e2817593cab4c71c1d3b265b2446da91121a467ff7477c65f0f39a80047bc28", size = 349037, upload-time = "2026-01-02T13:54:50.48Z" }, - { url = "https://files.pythonhosted.org/packages/82/18/7b86418a4d3dc1303e89c0391942258ead31c02309e90eb631f3081eec1d/aiokafka-0.13.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b80e0aa1c811a9a12edb0b94445a0638d61a345932f785d47901d28b8aad86c8", size = 1140066, upload-time = "2026-01-02T13:54:52.33Z" }, - { url = "https://files.pythonhosted.org/packages/f9/51/45e46b4407d39b950c8493e19498aeeb5af4fc461fb54fa0247da16bfd75/aiokafka-0.13.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:79672c456bd1642769e74fc2db1c34f23b15500e978fd38411662e8ca07590ad", size = 1130088, upload-time = "2026-01-02T13:54:53.786Z" }, - { url = "https://files.pythonhosted.org/packages/49/7f/6a66f6fd6fb73e15bd34f574e38703ba36d3f9256c80e7aba007bd8a9256/aiokafka-0.13.0-cp312-cp312-win32.whl", hash = "sha256:00bb4e3d5a237b8618883eb1dd8c08d671db91d3e8e33ac98b04edf64225658c", size = 309581, upload-time = "2026-01-02T13:54:55.444Z" }, - { url = "https://files.pythonhosted.org/packages/d3/e0/a2d5a8912699dd0fee28e6fb780358c63c7a4727517fffc110cb7e43f874/aiokafka-0.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:0f0cccdf2fd16927fbe077279524950676fbffa7b102d6b117041b3461b5d927", size = 329327, upload-time = "2026-01-02T13:54:56.981Z" }, - { url = "https://files.pythonhosted.org/packages/e3/f6/a74c49759233e98b61182ba3d49d5ac9c8de0643651892acba2704fba1cc/aiokafka-0.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:39d71c40cff733221a6b2afff4beeac5dacbd119fb99eec5198af59115264a1a", size = 343733, upload-time = "2026-01-02T13:54:58.536Z" }, - { url = "https://files.pythonhosted.org/packages/cf/52/4f7e80eee2c69cd8b047c18145469bf0dc27542a5dca3f96ff81ade575b0/aiokafka-0.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:faa2f5f3d0d2283a0c1a149748cc7e3a3862ef327fa5762e2461088eedde230a", size = 346258, upload-time = "2026-01-02T13:55:00.947Z" }, - { url = 
"https://files.pythonhosted.org/packages/81/9b/d2766bb3b0bad53eb25a88e51a884be4b77a1706053ad717b893b4daea4b/aiokafka-0.13.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b890d535e55f5073f939585bef5301634df669e97832fda77aa743498f008662", size = 1114744, upload-time = "2026-01-02T13:55:02.475Z" }, - { url = "https://files.pythonhosted.org/packages/8f/00/12e0a39cd4809149a09b4a52b629abc9bf80e7b8bad9950040b1adae99fc/aiokafka-0.13.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e22eb8a1475b9c0f45b553b6e2dcaf4ec3c0014bf4e389e00a0a0ec85d0e3bdc", size = 1105676, upload-time = "2026-01-02T13:55:04.036Z" }, - { url = "https://files.pythonhosted.org/packages/38/4a/0bc91e90faf55533fe6468461c2dd31c22b0e1d274b9386f341cca3f7eb7/aiokafka-0.13.0-cp313-cp313-win32.whl", hash = "sha256:ae507c7b09e882484f709f2e7172b3a4f75afffcd896d00517feb35c619495bb", size = 308257, upload-time = "2026-01-02T13:55:05.873Z" }, - { url = "https://files.pythonhosted.org/packages/23/63/5433d1aa10c4fb4cf85bd73013263c36d7da4604b0c77ed4d1ad42fae70c/aiokafka-0.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:fec1a7e3458365a72809edaa2b990f65ca39b01a2a579f879ac4da6c9b2dbc5c", size = 326968, upload-time = "2026-01-02T13:55:07.351Z" }, - { url = "https://files.pythonhosted.org/packages/3c/cc/45b04c3a5fd3d2d5f444889ecceb80b2f78d6d66aa45e3042767e55579e2/aiokafka-0.13.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9a403785f7092c72906c37f7618f7b16a4219eba8ed0bdda90fba410a7dd50b5", size = 344503, upload-time = "2026-01-02T13:55:08.723Z" }, - { url = "https://files.pythonhosted.org/packages/76/df/0b76fe3b93558ae71b856940e384909c4c2c7a1c330423003191e4ba7782/aiokafka-0.13.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:256807326831b7eee253ea1017bd2b19ab1c2298ce6b20a87fde97c253c572bc", size = 347621, upload-time = "2026-01-02T13:55:10.147Z" }, - { url = "https://files.pythonhosted.org/packages/34/1a/d59932f98fd3c106e2a7c8d4d5ebd8df25403436dfc27b3031918a37385e/aiokafka-0.13.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:64d90f91291da265d7f25296ba68fc6275684eebd6d1cf05a1b2abe6c2ba3543", size = 1111410, upload-time = "2026-01-02T13:55:11.763Z" }, - { url = "https://files.pythonhosted.org/packages/7e/04/fbf3e34ab3bc21e6e760c3fcd089375052fccc04eb8745459a82a58a647b/aiokafka-0.13.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b5a33cc043c8d199bcf101359d86f2d31fd54f4b157ac12028bdc34e3e1cf74a", size = 1094799, upload-time = "2026-01-02T13:55:13.795Z" }, - { url = "https://files.pythonhosted.org/packages/85/10/509f709fd3b7c3e568a5b8044be0e80a1504f8da6ddc72c128b21e270913/aiokafka-0.13.0-cp314-cp314-win32.whl", hash = "sha256:538950384b539ba2333d35a853f09214c0409e818e5d5f366ef759eea50bae9c", size = 311553, upload-time = "2026-01-02T13:55:15.928Z" }, - { url = "https://files.pythonhosted.org/packages/2b/18/424d6a4eb6f4835a371c1e2cfafce800540b33d957c6638795d911f98973/aiokafka-0.13.0-cp314-cp314-win_amd64.whl", hash = "sha256:c906dd42daadd14b4506a2e6c62dfef3d4919b5953d32ae5e5f0d99efd103c89", size = 330648, upload-time = "2026-01-02T13:55:17.421Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/65/ca/42a962033e6a7926dcb789168bce81d0181ef4ddabce454d830b7e62370e/aiokafka-0.12.0.tar.gz", hash = "sha256:62423895b866f95b5ed8d88335295a37cc5403af64cb7cb0e234f88adc2dff94", size = 564955, upload-time = "2024-10-26T20:53:11.227Z" } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/53/d4/baf1b2389995c6c312834792329a1993a303ff703ac023250ff977c5923b/aiokafka-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b01947553ff1120fa1cb1a05f2c3e5aa47a5378c720bafd09e6630ba18af02aa", size = 375031, upload-time = "2024-10-26T20:52:40.104Z" }, + { url = "https://files.pythonhosted.org/packages/54/ac/653070a4add8beea7aa8209ab396de87c7b4f9628fff15efcdbaea40e973/aiokafka-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e3c8ec1c0606fa645462c7353dc3e4119cade20c4656efa2031682ffaad361c0", size = 370619, upload-time = "2024-10-26T20:52:41.877Z" }, + { url = "https://files.pythonhosted.org/packages/80/f2/0ddaaa11876ab78e0f3b30f272c62eea70870e1a52a5afe985c7c1d098e1/aiokafka-0.12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:577c1c48b240e9eba57b3d2d806fb3d023a575334fc3953f063179170cc8964f", size = 1192363, upload-time = "2024-10-26T20:52:44.028Z" }, + { url = "https://files.pythonhosted.org/packages/ae/48/541ccece0e593e24ee371dec0c33c23718bc010b04e998693e4c19091258/aiokafka-0.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7b815b2e5fed9912f1231be6196547a367b9eb3380b487ff5942f0c73a3fb5c", size = 1213231, upload-time = "2024-10-26T20:52:46.028Z" }, + { url = "https://files.pythonhosted.org/packages/99/3f/75bd0faa77dfecce34dd1c0edd317b608518b096809736f9987dd61f4cec/aiokafka-0.12.0-cp312-cp312-win32.whl", hash = "sha256:5a907abcdf02430df0829ac80f25b8bb849630300fa01365c76e0ae49306f512", size = 347752, upload-time = "2024-10-26T20:52:47.327Z" }, + { url = "https://files.pythonhosted.org/packages/ef/97/e2513a0c10585e51d4d9b42c9dd5f5ab15dfe150620a4893a2c6c20f0f4a/aiokafka-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:fdbd69ec70eea4a8dfaa5c35ff4852e90e1277fcc426b9380f0b499b77f13b16", size = 366068, upload-time = "2024-10-26T20:52:49.132Z" }, + { url = "https://files.pythonhosted.org/packages/30/84/f1f7e603cd07e877520b5a1e48e006cbc1fe448806cabbaa98aa732f530d/aiokafka-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f9e8ab97b935ca681a5f28cf22cf2b5112be86728876b3ec07e4ed5fc6c21f2d", size = 370960, upload-time = "2024-10-26T20:52:51.235Z" }, + { url = "https://files.pythonhosted.org/packages/d7/c7/5237b3687198c2129c0bafa4a96cf8ae3883e20cc860125bafe16af3778e/aiokafka-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ed991c120fe19fd9439f564201dd746c4839700ef270dd4c3ee6d4895f64fe83", size = 366597, upload-time = "2024-10-26T20:52:52.539Z" }, + { url = "https://files.pythonhosted.org/packages/6b/67/0154551292ec1c977e5def178ae5c947773e921aefb6877971e7fdf1942e/aiokafka-0.12.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c01abf9787b1c3f3af779ad8e76d5b74903f590593bc26f33ed48750503e7f7", size = 1152905, upload-time = "2024-10-26T20:52:54.089Z" }, + { url = "https://files.pythonhosted.org/packages/d9/20/69f913a76916e94c4e783dc7d0d05a25c384b25faec33e121062c62411fe/aiokafka-0.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08c84b3894d97fd02fcc8886f394000d0f5ce771fab5c498ea2b0dd2f6b46d5b", size = 1171893, upload-time = "2024-10-26T20:52:56.14Z" }, + { url = "https://files.pythonhosted.org/packages/16/65/41cc1b19e7dea623ef58f3bf1e2720377c5757a76d9799d53a1b5fc39255/aiokafka-0.12.0-cp313-cp313-win32.whl", hash = "sha256:63875fed922c8c7cf470d9b2a82e1b76b4a1baf2ae62e07486cf516fd09ff8f2", size = 345933, upload-time = "2024-10-26T20:52:57.518Z" }, + { url = 
"https://files.pythonhosted.org/packages/bf/0d/4cb57231ff650a01123a09075bf098d8fdaf94b15a1a58465066b2251e8b/aiokafka-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:bdc0a83eb386d2384325d6571f8ef65b4cfa205f8d1c16d7863e8d10cacd995a", size = 363194, upload-time = "2024-10-26T20:52:59.434Z" }, ] [[package]] @@ -299,6 +293,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6c/56/3124f61d37a7a4e7cc96afc5492c78ba0cb551151e530b54669ddd1436ef/cachetools-6.2.0-py3-none-any.whl", hash = "sha256:1c76a8960c0041fcc21097e357f882197c79da0dbff766e7317890a65d7d8ba6", size = 11276, upload-time = "2025-08-25T18:57:29.684Z" }, ] +[[package]] +name = "casefy" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/50/f5991618899c42d0c6339bd83fed5f694f56b204dfb3f2a052f0d586d4c5/casefy-1.0.0.tar.gz", hash = "sha256:bc99428475c2089c5f6a21297b4cfe4e83dff132cf3bb06655ddcb90632af1ed", size = 123432, upload-time = "2024-11-30T10:00:59.015Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/b6/797c17fe4804896836ef33b447f0a2641d59a8d1d63e63834fb3fbc87cd8/casefy-1.0.0-py3-none-any.whl", hash = "sha256:c89f96fb0fbd13691073b7a65c1e668e81453247d647479a3db105e86d7b0df9", size = 6299, upload-time = "2024-11-30T10:00:33.134Z" }, +] + [[package]] name = "certifi" version = "2024.8.30" @@ -583,6 +586,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321, upload-time = "2023-10-07T05:32:16.783Z" }, ] +[[package]] +name = "dacite" +version = "1.9.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/55/a0/7ca79796e799a3e782045d29bf052b5cde7439a2bbb17f15ff44f7aacc63/dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09", size = 22420, upload-time = "2025-02-05T09:27:29.757Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/35/386550fd60316d1e37eccdda609b074113298f23cef5bddb2049823fe666/dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0", size = 16600, upload-time = "2025-02-05T09:27:24.345Z" }, +] + +[[package]] +name = "dataclasses-avroschema" +version = "0.66.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "casefy" }, + { name = "dacite" }, + { name = "fastavro" }, + { name = "inflection" }, + { name = "python-dateutil" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/09/6ae7907e80b1488bf1f13c21ce47a74b43c7aecfd6ab21265815ba0249bd/dataclasses_avroschema-0.66.2.tar.gz", hash = "sha256:dd2d2360d4a47f14799293e8a462424d3d872617501ee2a8a2100ecbffb27577", size = 45957, upload-time = "2025-12-03T13:54:02.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/a2/f0de317c74a53c558533d001e886d6273c6b8497c15c58b124fffb9063a1/dataclasses_avroschema-0.66.2-py3-none-any.whl", hash = "sha256:f839b21ac3a1f23b03e49c111db1393913e59d3cee37410a9cb80c9925951610", size = 59268, upload-time = "2025-12-03T13:54:00.42Z" }, +] + +[package.optional-dependencies] +pydantic = [ + { name = "pydantic", extra = ["email"] }, +] + [[package]] name = "deprecated" version = "1.2.14" @@ -597,11 +631,11 @@ wheels = [ [[package]] name = "dishka" -version = "1.6.0" 
+version = "1.7.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/11/04/f3add05678a3ac1ab7736faae45b18b5365d84b1cd3cf3af64b09a1d6a5f/dishka-1.6.0.tar.gz", hash = "sha256:f1fa5ec7e980d4f618d0c425d1bb81d8e9414894d8ec6553b197d2298774e12f", size = 65971, upload-time = "2025-05-18T21:40:53.259Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/d7/1be31f5ef32387059190353f9fa493ff4d07a1c75fa856c7566ca45e0800/dishka-1.7.2.tar.gz", hash = "sha256:47d4cb5162b28c61bf5541860e605ed5eaf5c667122299c7ef657c86fc8d5a49", size = 68132, upload-time = "2025-09-24T21:23:05.135Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/76/6b/f9cd08543c4f55bf129a0ebce5c09e43528235dd6e7cb906761ca094979a/dishka-1.6.0-py3-none-any.whl", hash = "sha256:ab1aedee152ce7bb11cfd2673d7ce4001fe2b330d14e84535d7525a68430b2c2", size = 90789, upload-time = "2025-05-18T21:40:51.352Z" }, + { url = "https://files.pythonhosted.org/packages/b7/b9/89381173b4f336e986d72471198614806cd313e0f85c143ccb677c310223/dishka-1.7.2-py3-none-any.whl", hash = "sha256:f6faa6ab321903926b825b3337d77172ee693450279b314434864978d01fbad3", size = 94774, upload-time = "2025-09-24T21:23:03.246Z" }, ] [[package]] @@ -653,6 +687,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, ] +[[package]] +name = "fast-depends" +version = "3.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/f3/41e955f5f0811de6ef9f00f8462f2ade7bc4a99b93714c9b134646baa831/fast_depends-3.0.5.tar.gz", hash = "sha256:c915a54d6e0d0f0393686d37c14d54d9ec7c43d7b9def3f3fc4f7b4d52f67f2a", size = 18235, upload-time = "2025-11-30T20:26:12.92Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/dd/76697228ae63dcbaf0a0a1b20fc996433a33f184ac4f578382b681dcf5ea/fast_depends-3.0.5-py3-none-any.whl", hash = "sha256:38a3d7044d3d6d0b1bed703691275c870316426e8a9bfa6b1c89e979b15659e2", size = 25362, upload-time = "2025-11-30T20:26:10.96Z" }, +] + +[package.optional-dependencies] +pydantic = [ + { name = "pydantic" }, +] + [[package]] name = "fastapi" version = "0.128.0" @@ -703,6 +755,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/93/b44f67589e4d439913dab6720f7e3507b0fa8b8e56d06f6fc875ced26afb/fastavro-1.12.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:43ded16b3f4a9f1a42f5970c2aa618acb23ea59c4fcaa06680bdf470b255e5a8", size = 3386636, upload-time = "2025-10-10T15:42:18.974Z" }, ] +[[package]] +name = "faststream" +version = "0.6.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "fast-depends", extra = ["pydantic"] }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/55/fc2a34405ac63aaf973a33884c5f77f87564bbb9a4343c1d103cbf8c87f5/faststream-0.6.5.tar.gz", hash = "sha256:ddef9e85631edf1aba87e81c8886067bd94ee752f41a09f1d1cd6f75f7e4fade", size = 302206, upload-time = "2025-12-29T16:44:04.85Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/1c/f60c0b15a8bce42ed7256d2eca6ea713bac259d6684ff392a907f95ca345/faststream-0.6.5-py3-none-any.whl", hash = 
"sha256:714b13b84cdbe2bdcf0b2b8a5e2b04648cb7683784a5297d445adfee9f2b4f7e", size = 507108, upload-time = "2025-12-29T16:44:03.297Z" }, +] + +[package.optional-dependencies] +kafka = [ + { name = "aiokafka" }, +] + [[package]] name = "fonttools" version = "4.61.1" @@ -1002,6 +1073,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" }, ] +[[package]] +name = "inflection" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/7e/691d061b7329bc8d54edbf0ec22fbfb2afe61facb681f9aaa9bff7a27d04/inflection-0.5.1.tar.gz", hash = "sha256:1a29730d366e996aaacffb2f1f1cb9593dc38e2ddd30c91250c6dde09ea9b417", size = 15091, upload-time = "2020-08-22T08:16:29.139Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/91/aa6bde563e0085a02a435aa99b49ef75b0a4b062635e606dab23ce18d720/inflection-0.5.1-py2.py3-none-any.whl", hash = "sha256:f38b2b640938a4f35ade69ac3d053042959b62a0f1076a5bbaa1b9526605a8a2", size = 9454, upload-time = "2020-08-22T08:16:27.816Z" }, +] + [[package]] name = "iniconfig" version = "2.0.0" @@ -1039,6 +1119,7 @@ dependencies = [ { name = "configargparse" }, { name = "contourpy" }, { name = "cycler" }, + { name = "dataclasses-avroschema", extra = ["pydantic"] }, { name = "deprecated" }, { name = "dishka" }, { name = "dnspython" }, @@ -1047,6 +1128,7 @@ dependencies = [ { name = "exceptiongroup" }, { name = "fastapi" }, { name = "fastavro" }, + { name = "faststream", extra = ["kafka"] }, { name = "fonttools" }, { name = "frozenlist" }, { name = "google-auth" }, @@ -1101,7 +1183,6 @@ dependencies = [ { name = "pyasn1" }, { name = "pyasn1-modules" }, { name = "pydantic" }, - { name = "pydantic-avro" }, { name = "pydantic-core" }, { name = "pydantic-settings" }, { name = "pygments" }, @@ -1161,7 +1242,7 @@ dev = [ requires-dist = [ { name = "aiohappyeyeballs", specifier = "==2.6.1" }, { name = "aiohttp", specifier = "==3.13.3" }, - { name = "aiokafka", specifier = "==0.13.0" }, + { name = "aiokafka", specifier = ">=0.12.0,<0.13" }, { name = "aiosignal", specifier = "==1.4.0" }, { name = "aiosmtplib", specifier = "==3.0.2" }, { name = "annotated-doc", specifier = "==0.0.4" }, @@ -1182,14 +1263,16 @@ requires-dist = [ { name = "configargparse", specifier = "==1.7.1" }, { name = "contourpy", specifier = "==1.3.3" }, { name = "cycler", specifier = "==0.12.1" }, + { name = "dataclasses-avroschema", extras = ["pydantic"], specifier = ">=0.65.0" }, { name = "deprecated", specifier = "==1.2.14" }, - { name = "dishka", specifier = "==1.6.0" }, + { name = "dishka", specifier = "==1.7.2" }, { name = "dnspython", specifier = "==2.7.0" }, { name = "durationpy", specifier = "==0.9" }, { name = "email-validator", specifier = "==2.3.0" }, { name = "exceptiongroup", specifier = "==1.2.2" }, { name = "fastapi", specifier = "==0.128.0" }, { name = "fastavro", specifier = "==1.12.1" }, + { name = "faststream", extras = ["kafka"], specifier = ">=0.6.0" }, { name = "fonttools", specifier = "==4.61.1" }, { name = "frozenlist", specifier = "==1.7.0" }, { name = "google-auth", specifier = "==2.45.0" }, @@ -1244,7 +1327,6 @@ requires-dist = [ { name = "pyasn1", specifier = "==0.6.1" }, { name = "pyasn1-modules", specifier = "==0.4.2" }, { name = "pydantic", specifier = 
"==2.9.2" }, - { name = "pydantic-avro", specifier = "==0.9.1" }, { name = "pydantic-core", specifier = "==2.23.4" }, { name = "pydantic-settings", specifier = "==2.5.2" }, { name = "pygments", specifier = "==2.19.2" }, @@ -2365,16 +2447,9 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/e4/ba44652d562cbf0bf320e0f3810206149c8a4e99cdbf66da82e97ab53a15/pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12", size = 434928, upload-time = "2024-09-17T15:59:51.827Z" }, ] -[[package]] -name = "pydantic-avro" -version = "0.9.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pydantic" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e9/8b/47ea4be231ba90984228486fe9a332cb6f18db6963d04207a1a9f310c45b/pydantic_avro-0.9.1.tar.gz", hash = "sha256:22f728340fad3353b232ec2b138496c26efb2ede5b74a2f18ab491d4ea37ec5b", size = 10015, upload-time = "2025-10-16T12:00:29.536Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/25/69/6bb45c70da28c3aa82772c136e9c87ff0498c4fd4875594ebe3f7a4cd47c/pydantic_avro-0.9.1-py3-none-any.whl", hash = "sha256:dcbec25c6f2021db594f3116dd94e029a4cb96ab63eec3dcb3ad4405b434c23a", size = 11510, upload-time = "2025-10-16T12:00:28.718Z" }, +[package.optional-dependencies] +email = [ + { name = "email-validator" }, ] [[package]] diff --git a/backend/workers/run_coordinator.py b/backend/workers/run_coordinator.py deleted file mode 100644 index 0f71f42b..00000000 --- a/backend/workers/run_coordinator.py +++ /dev/null @@ -1,95 +0,0 @@ -import asyncio -import logging -import signal - -from app.core.container import create_coordinator_container -from app.core.database_context import Database -from app.core.logging import setup_logger -from app.core.tracing import init_tracing -from app.db.docs import ALL_DOCUMENTS -from app.domain.enums.kafka import GroupId -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.coordinator.coordinator_logic import CoordinatorLogic -from app.services.idempotency.middleware import IdempotentConsumerWrapper -from app.settings import Settings -from beanie import init_beanie - - -async def run_coordinator(settings: Settings) -> None: - """Run the execution coordinator service.""" - - container = create_coordinator_container(settings) - logger = await container.get(logging.Logger) - logger.info("Starting ExecutionCoordinator with DI container...") - - db = await container.get(Database) - await init_beanie(database=db, document_models=ALL_DOCUMENTS) - - schema_registry = await container.get(SchemaRegistryManager) - await initialize_event_schemas(schema_registry) - - consumer = await container.get(IdempotentConsumerWrapper) - logic = await container.get(CoordinatorLogic) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - logger.info("ExecutionCoordinator initialized, starting run...") - - async def run_coordinator_tasks() -> None: - """Run consumer and scheduling loop using TaskGroup.""" - async with asyncio.TaskGroup() as tg: - tg.create_task(consumer.run()) - tg.create_task(logic.scheduling_loop()) - - try: - # Run coordinator until shutdown signal - run_task = asyncio.create_task(run_coordinator_tasks()) - shutdown_task = asyncio.create_task(shutdown_event.wait()) - - done, pending = await 
asyncio.wait( - [run_task, shutdown_task], - return_when=asyncio.FIRST_COMPLETED, - ) - - # Cancel remaining tasks - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - finally: - logger.info("Initiating graceful shutdown...") - await container.close() - - -def main() -> None: - """Main entry point for coordinator worker""" - settings = Settings() - - logger = setup_logger(settings.LOG_LEVEL) - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - - logger.info("Starting ExecutionCoordinator worker...") - - if settings.ENABLE_TRACING: - init_tracing( - service_name=GroupId.EXECUTION_COORDINATOR, - settings=settings, - logger=logger, - service_version=settings.TRACING_SERVICE_VERSION, - enable_console_exporter=False, - sampling_rate=settings.TRACING_SAMPLING_RATE, - ) - logger.info("Tracing initialized for ExecutionCoordinator") - - asyncio.run(run_coordinator(settings)) - - -if __name__ == "__main__": - main() diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index ea16a46a..9037ec42 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -1,84 +1,65 @@ +""" +Kubernetes Worker using FastStream. + +This is the clean version: +- No asyncio.get_running_loop() +- No signal.SIGINT/SIGTERM handlers +- No create_task() at worker level +- No manual consumer loops +- No TaskGroup management + +Everything is handled by FastStream + Dishka DI. +""" + import asyncio import logging -import signal +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager -from app.core.container import create_k8s_worker_container -from app.core.database_context import Database from app.core.logging import setup_logger +from app.core.providers import ( + EventProvider, + K8sWorkerProvider, + LoggingProvider, + MessagingProvider, + MetricsProvider, + RedisProvider, + SettingsProvider, +) from app.core.tracing import init_tracing -from app.db.docs import ALL_DOCUMENTS -from app.domain.enums.kafka import GroupId +from app.domain.enums.events import EventType +from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId +from app.domain.events.typed import ( + CreatePodCommandEvent, + DeletePodCommandEvent, + DomainEvent, +) from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.idempotency.middleware import IdempotentConsumerWrapper +from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.k8s_worker.worker_logic import K8sWorkerLogic from app.settings import Settings -from beanie import init_beanie - - -async def run_kubernetes_worker(settings: Settings) -> None: - """Run the Kubernetes worker service.""" - - container = create_k8s_worker_container(settings) - logger = await container.get(logging.Logger) - logger.info("Starting KubernetesWorker with DI container...") - - db = await container.get(Database) - await init_beanie(database=db, document_models=ALL_DOCUMENTS) - - schema_registry = await container.get(SchemaRegistryManager) - await initialize_event_schemas(schema_registry) - - consumer = await container.get(IdempotentConsumerWrapper) - logic = await container.get(K8sWorkerLogic) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - 
logger.info("KubernetesWorker initialized, starting run...") - - async def run_worker_tasks() -> None: - """Run consumer and daemonset setup using TaskGroup.""" - async with asyncio.TaskGroup() as tg: - tg.create_task(consumer.run()) - tg.create_task(logic.ensure_daemonset_task()) - - try: - # Run worker until shutdown signal - run_task = asyncio.create_task(run_worker_tasks()) - shutdown_task = asyncio.create_task(shutdown_event.wait()) - - done, pending = await asyncio.wait( - [run_task, shutdown_task], - return_when=asyncio.FIRST_COMPLETED, - ) - - # Cancel remaining tasks - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - finally: - logger.info("Initiating graceful shutdown...") - # Wait for active pod creations to complete - await logic.wait_for_active_creations() - await container.close() +from dishka import make_async_container +from dishka.integrations.faststream import FromDishka, setup_dishka +from faststream import FastStream +from faststream.kafka import KafkaBroker def main() -> None: - """Main entry point for Kubernetes worker""" + """ + Entry point - minimal boilerplate. + + FastStream handles: + - Signal handling (SIGINT/SIGTERM) + - Consumer loop + - Graceful shutdown + """ settings = Settings() + # Setup logging logger = setup_logger(settings.LOG_LEVEL) - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - - logger.info("Starting KubernetesWorker...") + logger.info("Starting KubernetesWorker (FastStream)...") + # Setup tracing if settings.ENABLE_TRACING: init_tracing( service_name=GroupId.K8S_WORKER, @@ -88,9 +69,99 @@ def main() -> None: enable_console_exporter=False, sampling_rate=settings.TRACING_SAMPLING_RATE, ) - logger.info("Tracing initialized for KubernetesWorker") - asyncio.run(run_kubernetes_worker(settings)) + # Create DI container (no DatabaseProvider/RepositoryProvider - K8s worker doesn't use MongoDB) + container = make_async_container( + SettingsProvider(), + LoggingProvider(), + RedisProvider(), + MetricsProvider(), + EventProvider(), + MessagingProvider(), + K8sWorkerProvider(), + context={Settings: settings}, + ) + + # Build topic list and group ID from config + topics = [ + f"{settings.KAFKA_TOPIC_PREFIX}{t}" + for t in CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.K8S_WORKER] + ] + group_id = f"{GroupId.K8S_WORKER}.{settings.KAFKA_GROUP_SUFFIX}" + + # Create broker + broker = KafkaBroker( + settings.KAFKA_BOOTSTRAP_SERVERS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + + # Create lifespan for infrastructure initialization + @asynccontextmanager + async def lifespan(app: FastStream) -> AsyncIterator[None]: + """Initialize infrastructure before app starts.""" + app_logger = await container.get(logging.Logger) + app_logger.info("KubernetesWorker starting...") + + # Initialize schema registry + schema_registry = await container.get(SchemaRegistryManager) + await initialize_event_schemas(schema_registry) + + # Get worker logic and ensure daemonset (one-time initialization) + logic = await container.get(K8sWorkerLogic) + await logic.ensure_image_pre_puller_daemonset() + + # Decoder: Avro bytes → typed DomainEvent + async def decode_avro(body: bytes) -> DomainEvent: + return await schema_registry.deserialize_event(body, "k8s_worker") + + # Create subscriber with Avro decoder + subscriber = broker.subscriber( + *topics, + group_id=group_id, + auto_commit=False, + decoder=decode_avro, + ) + + # Route by event_type header (producer sets this, Kafka 
stores as bytes) + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.CREATE_POD_COMMAND.encode()) + async def handle_create_pod_command( + event: CreatePodCommandEvent, + worker_logic: FromDishka[K8sWorkerLogic], + ) -> None: + await worker_logic._handle_create_pod_command(event) + + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.DELETE_POD_COMMAND.encode()) + async def handle_delete_pod_command( + event: DeletePodCommandEvent, + worker_logic: FromDishka[K8sWorkerLogic], + ) -> None: + await worker_logic._handle_delete_pod_command(event) + + # Default handler for unmatched events (prevents message loss) + @subscriber() + async def handle_other(event: DomainEvent) -> None: + pass + + app_logger.info("Infrastructure initialized, starting event processing...") + + yield + + # Graceful shutdown: wait for active pod creations + app_logger.info("KubernetesWorker shutting down...") + await logic.wait_for_active_creations() + await container.close() + + # Create FastStream app + app = FastStream(broker, lifespan=lifespan) + + # Setup Dishka integration for automatic DI in handlers + setup_dishka(container=container, app=app, auto_inject=True) + + # Add idempotency middleware (appends to end = most inner, runs after Dishka) + broker.add_middleware(IdempotencyMiddleware) + + # Run! FastStream handles signal handling, consumer loops, graceful shutdown + asyncio.run(app.run()) if __name__ == "__main__": diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index 4549148f..9675755b 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -1,10 +1,31 @@ +""" +Pod Monitor Worker (Simplified). + +Note: Unlike other workers, PodMonitor watches Kubernetes pods directly +(not consuming Kafka messages), so FastStream's subscriber pattern doesn't apply. + +This version uses a minimal signal handling approach. 
+""" + import asyncio import logging import signal +from contextlib import suppress -from app.core.container import create_pod_monitor_container from app.core.database_context import Database from app.core.logging import setup_logger +from app.core.providers import ( + DatabaseProvider, + EventProvider, + KubernetesProvider, + LoggingProvider, + MessagingProvider, + MetricsProvider, + PodMonitorProvider, + RedisProvider, + RepositoryProvider, + SettingsProvider, +) from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId @@ -12,12 +33,25 @@ from app.services.pod_monitor.monitor import PodMonitor from app.settings import Settings from beanie import init_beanie +from dishka import make_async_container async def run_pod_monitor(settings: Settings) -> None: """Run the pod monitor service.""" + container = make_async_container( + SettingsProvider(), + LoggingProvider(), + RedisProvider(), + DatabaseProvider(), + MetricsProvider(), + EventProvider(), + MessagingProvider(), + RepositoryProvider(), + KubernetesProvider(), + PodMonitorProvider(), + context={Settings: settings}, + ) - container = create_pod_monitor_container(settings) logger = await container.get(logging.Logger) logger.info("Starting PodMonitor with DI container...") @@ -29,31 +63,28 @@ async def run_pod_monitor(settings: Settings) -> None: monitor = await container.get(PodMonitor) - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() + # Signal handling with minimal boilerplate + shutdown = asyncio.Event() loop = asyncio.get_running_loop() for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) + loop.add_signal_handler(sig, shutdown.set) logger.info("PodMonitor initialized, starting run...") try: - # Run monitor until shutdown signal - run_task = asyncio.create_task(monitor.run()) - shutdown_task = asyncio.create_task(shutdown_event.wait()) + # Run monitor until shutdown + monitor_task = asyncio.create_task(monitor.run()) + shutdown_task = asyncio.create_task(shutdown.wait()) done, pending = await asyncio.wait( - [run_task, shutdown_task], + [monitor_task, shutdown_task], return_when=asyncio.FIRST_COMPLETED, ) - # Cancel remaining tasks for task in pending: task.cancel() - try: + with suppress(asyncio.CancelledError): await task - except asyncio.CancelledError: - pass finally: logger.info("Initiating graceful shutdown...") @@ -61,12 +92,10 @@ async def run_pod_monitor(settings: Settings) -> None: def main() -> None: - """Main entry point for pod monitor worker""" + """Main entry point for pod monitor worker.""" settings = Settings() logger = setup_logger(settings.LOG_LEVEL) - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - logger.info("Starting PodMonitor worker...") if settings.ENABLE_TRACING: @@ -78,7 +107,6 @@ def main() -> None: enable_console_exporter=False, sampling_rate=settings.TRACING_SAMPLING_RATE, ) - logger.info("Tracing initialized for PodMonitor Service") asyncio.run(run_pod_monitor(settings)) diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index 6325fc35..9b483684 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -1,72 +1,71 @@ +""" +Result Processor Worker using FastStream. 
+ +This is the clean version: +- No asyncio.get_running_loop() +- No signal.SIGINT/SIGTERM handlers +- No create_task() +- No manual consumer loops +- No TaskGroup management + +Everything is handled by FastStream + Dishka DI. +""" + import asyncio import logging -import signal +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager -from app.core.container import create_result_processor_container +from app.core.database_context import Database from app.core.logging import setup_logger +from app.core.providers import ( + DatabaseProvider, + EventProvider, + LoggingProvider, + MessagingProvider, + MetricsProvider, + RedisProvider, + RepositoryProvider, + ResultProcessorProvider, + SettingsProvider, +) from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS -from app.domain.enums.kafka import GroupId -from app.services.idempotency.middleware import IdempotentConsumerWrapper +from app.domain.enums.events import EventType +from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId +from app.domain.events.typed import ( + DomainEvent, + ExecutionCompletedEvent, + ExecutionFailedEvent, + ExecutionTimeoutEvent, +) +from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas +from app.services.idempotency.faststream_middleware import IdempotencyMiddleware +from app.services.result_processor.processor_logic import ProcessorLogic from app.settings import Settings from beanie import init_beanie -from pymongo.asynchronous.mongo_client import AsyncMongoClient - - -async def run_result_processor(settings: Settings) -> None: - """Run the result processor service.""" - - db_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( - settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 - ) - await init_beanie(database=db_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) - - container = create_result_processor_container(settings) - logger = await container.get(logging.Logger) - - consumer = await container.get(IdempotentConsumerWrapper) - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - logger.info("ResultProcessor consumer initialized, starting run...") - - try: - # Run consumer until shutdown signal - run_task = asyncio.create_task(consumer.run()) - shutdown_task = asyncio.create_task(shutdown_event.wait()) - - done, pending = await asyncio.wait( - [run_task, shutdown_task], - return_when=asyncio.FIRST_COMPLETED, - ) - - # Cancel remaining tasks - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - finally: - logger.info("Initiating graceful shutdown...") - await container.close() - await db_client.close() +from dishka import make_async_container +from dishka.integrations.faststream import FromDishka, setup_dishka +from faststream import FastStream +from faststream.kafka import KafkaBroker def main() -> None: - """Main entry point for result processor worker""" + """ + Entry point - minimal boilerplate. 
+ + FastStream handles: + - Signal handling (SIGINT/SIGTERM) + - Consumer loop + - Graceful shutdown + """ settings = Settings() + # Setup logging logger = setup_logger(settings.LOG_LEVEL) - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - - logger.info("Starting ResultProcessor worker...") + logger.info("Starting ResultProcessor (FastStream)...") + # Setup tracing if settings.ENABLE_TRACING: init_tracing( service_name=GroupId.RESULT_PROCESSOR, @@ -76,9 +75,106 @@ def main() -> None: enable_console_exporter=False, sampling_rate=settings.TRACING_SAMPLING_RATE, ) - logger.info("Tracing initialized for ResultProcessor Service") - asyncio.run(run_result_processor(settings)) + # Create DI container with all providers + container = make_async_container( + SettingsProvider(), + LoggingProvider(), + RedisProvider(), + DatabaseProvider(), + MetricsProvider(), + EventProvider(), + MessagingProvider(), + RepositoryProvider(), + ResultProcessorProvider(), + context={Settings: settings}, + ) + + # Build topic list and group ID from config + topics = [ + f"{settings.KAFKA_TOPIC_PREFIX}{t}" + for t in CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.RESULT_PROCESSOR] + ] + group_id = f"{GroupId.RESULT_PROCESSOR}.{settings.KAFKA_GROUP_SUFFIX}" + + # Create broker + broker = KafkaBroker( + settings.KAFKA_BOOTSTRAP_SERVERS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + + # Create lifespan for infrastructure initialization + @asynccontextmanager + async def lifespan(app: FastStream) -> AsyncIterator[None]: + """Initialize infrastructure before app starts.""" + app_logger = await container.get(logging.Logger) + app_logger.info("ResultProcessor starting...") + + # Initialize database + db = await container.get(Database) + await init_beanie(database=db, document_models=ALL_DOCUMENTS) + + # Initialize schema registry + schema_registry = await container.get(SchemaRegistryManager) + await initialize_event_schemas(schema_registry) + + # Decoder: Avro bytes → typed DomainEvent + async def decode_avro(body: bytes) -> DomainEvent: + return await schema_registry.deserialize_event(body, "result_processor") + + # Create subscriber with Avro decoder + subscriber = broker.subscriber( + *topics, + group_id=group_id, + auto_commit=False, + decoder=decode_avro, + ) + + # Route by event_type header (producer sets this, Kafka stores as bytes) + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_COMPLETED.encode()) + async def handle_completed( + event: ExecutionCompletedEvent, + logic: FromDishka[ProcessorLogic], + ) -> None: + await logic._handle_completed(event) + + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_FAILED.encode()) + async def handle_failed( + event: ExecutionFailedEvent, + logic: FromDishka[ProcessorLogic], + ) -> None: + await logic._handle_failed(event) + + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_TIMEOUT.encode()) + async def handle_timeout( + event: ExecutionTimeoutEvent, + logic: FromDishka[ProcessorLogic], + ) -> None: + await logic._handle_timeout(event) + + # Default handler for unmatched events (prevents message loss) + @subscriber() + async def handle_other(event: DomainEvent) -> None: + pass + + app_logger.info("Infrastructure initialized, starting event processing...") + + yield + + app_logger.info("ResultProcessor shutting down...") + await container.close() + + # Create FastStream app + app = FastStream(broker, 
lifespan=lifespan) + + # Setup Dishka integration for automatic DI in handlers + setup_dishka(container=container, app=app, auto_inject=True) + + # Add idempotency middleware (appends to end = most inner, runs after Dishka) + broker.add_middleware(IdempotencyMiddleware) + + # Run! FastStream handles signal handling, consumer loops, graceful shutdown + asyncio.run(app.run()) if __name__ == "__main__": diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 87b4754d..80a42240 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -1,90 +1,66 @@ +""" +Saga Orchestrator Worker using FastStream. + +This is the clean version: +- No asyncio.get_running_loop() +- No signal.SIGINT/SIGTERM handlers +- No create_task() +- No manual consumer loops +- No TaskGroup management + +Everything is handled by FastStream + Dishka DI. +""" + import asyncio import logging -import signal +import time +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager -from app.core.container import create_saga_orchestrator_container from app.core.database_context import Database from app.core.logging import setup_logger +from app.core.providers import ( + DatabaseProvider, + EventProvider, + LoggingProvider, + MessagingProvider, + MetricsProvider, + RedisProvider, + RepositoryProvider, + SagaOrchestratorProvider, + SettingsProvider, +) from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.domain.events.typed import DomainEvent from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.idempotency.middleware import IdempotentConsumerWrapper +from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.saga.saga_logic import SagaLogic from app.settings import Settings from beanie import init_beanie - - -async def run_saga_orchestrator(settings: Settings) -> None: - """Run the saga orchestrator.""" - - container = create_saga_orchestrator_container(settings) - logger = await container.get(logging.Logger) - logger.info("Starting SagaOrchestrator with DI container...") - - db = await container.get(Database) - await init_beanie(database=db, document_models=ALL_DOCUMENTS) - - schema_registry = await container.get(SchemaRegistryManager) - await initialize_event_schemas(schema_registry) - - consumer = await container.get(IdempotentConsumerWrapper | None) - logic = await container.get(SagaLogic) - - # Handle case where no sagas have triggers - if consumer is None: - logger.warning("No consumer provided (no saga triggers), exiting") - await container.close() - return - - # Shutdown event - signal handlers just set this - shutdown_event = asyncio.Event() - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown_event.set) - - logger.info(f"SagaOrchestrator initialized for saga: {logic.config.name}, starting run...") - - async def run_orchestrator_tasks() -> None: - """Run consumer and timeout checker using TaskGroup.""" - async with asyncio.TaskGroup() as tg: - tg.create_task(consumer.run()) - tg.create_task(logic.check_timeouts_loop()) - - try: - # Run orchestrator until shutdown signal - run_task = asyncio.create_task(run_orchestrator_tasks()) - shutdown_task = asyncio.create_task(shutdown_event.wait()) - - done, pending = await asyncio.wait( - [run_task, shutdown_task], - 
return_when=asyncio.FIRST_COMPLETED, - ) - - # Cancel remaining tasks - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - finally: - logger.info("Initiating graceful shutdown...") - await container.close() - - logger.warning("Saga orchestrator stopped") +from dishka import make_async_container +from dishka.integrations.faststream import FromDishka, setup_dishka +from faststream import FastStream +from faststream.kafka import KafkaBroker def main() -> None: - """Main entry point for saga orchestrator worker""" + """ + Entry point - minimal boilerplate. + + FastStream handles: + - Signal handling (SIGINT/SIGTERM) + - Consumer loop + - Graceful shutdown + """ settings = Settings() + # Setup logging logger = setup_logger(settings.LOG_LEVEL) - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - - logger.info("Starting Saga Orchestrator worker...") + logger.info("Starting SagaOrchestrator (FastStream)...") + # Setup tracing if settings.ENABLE_TRACING: init_tracing( service_name=GroupId.SAGA_ORCHESTRATOR, @@ -94,9 +70,112 @@ def main() -> None: enable_console_exporter=False, sampling_rate=settings.TRACING_SAMPLING_RATE, ) - logger.info("Tracing initialized for Saga Orchestrator Service") - asyncio.run(run_saga_orchestrator(settings)) + # Create DI container with all providers + container = make_async_container( + SettingsProvider(), + LoggingProvider(), + RedisProvider(), + DatabaseProvider(), + MetricsProvider(), + EventProvider(), + MessagingProvider(), + RepositoryProvider(), + SagaOrchestratorProvider(), + context={Settings: settings}, + ) + + # We need to determine topics dynamically based on registered sagas + # This will be done in lifespan after SagaLogic is initialized + + # Create broker + broker = KafkaBroker( + settings.KAFKA_BOOTSTRAP_SERVERS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + + # Track timeout checking state for opportunistic checking + timeout_check_state = {"last_check": 0.0, "interval": 30.0} + + # Create lifespan for infrastructure initialization + @asynccontextmanager + async def lifespan(app: FastStream) -> AsyncIterator[None]: + """Initialize infrastructure before app starts.""" + app_logger = await container.get(logging.Logger) + app_logger.info("SagaOrchestrator starting...") + + # Initialize database + db = await container.get(Database) + await init_beanie(database=db, document_models=ALL_DOCUMENTS) + + # Initialize schema registry + schema_registry = await container.get(SchemaRegistryManager) + await initialize_event_schemas(schema_registry) + + # Get saga logic to determine topics + logic = await container.get(SagaLogic) + trigger_topics = logic.get_trigger_topics() + + if not trigger_topics: + app_logger.warning("No saga triggers configured, shutting down") + yield + await container.close() + return + + # Build topic list with prefix + topics = [f"{settings.KAFKA_TOPIC_PREFIX}{t}" for t in trigger_topics] + group_id = f"{GroupId.SAGA_ORCHESTRATOR}.{settings.KAFKA_GROUP_SUFFIX}" + + # Decoder: Avro bytes → typed DomainEvent + async def decode_avro(body: bytes) -> DomainEvent: + return await schema_registry.deserialize_event(body, "saga_orchestrator") + + # Register handler dynamically after determining topics + # Saga orchestrator uses single handler - routing is internal to SagaLogic + @broker.subscriber( + *topics, + group_id=group_id, + auto_commit=False, + decoder=decode_avro, + ) + async def handle_saga_event( + event: DomainEvent, + 
saga_logic: FromDishka[SagaLogic], + ) -> None: + """ + Handle saga trigger events. + + Dependencies are automatically injected via Dishka. + Routing is handled internally by SagaLogic based on saga configuration. + """ + # Handle the event through saga logic (internal routing) + await saga_logic.handle_event(event) + + # Opportunistic timeout check (replaces background loop) + now = time.monotonic() + if now - timeout_check_state["last_check"] >= timeout_check_state["interval"]: + await saga_logic.check_timeouts_once() + timeout_check_state["last_check"] = now + + app_logger.info(f"Subscribing to topics: {topics}") + app_logger.info("Infrastructure initialized, starting event processing...") + + yield + + app_logger.info("SagaOrchestrator shutting down...") + await container.close() + + # Create FastStream app + app = FastStream(broker, lifespan=lifespan) + + # Setup Dishka integration for automatic DI in handlers + setup_dishka(container=container, app=app, auto_inject=True) + + # Add idempotency middleware (appends to end = most inner, runs after Dishka) + broker.add_middleware(IdempotencyMiddleware) + + # Run! FastStream handles signal handling, consumer loops, graceful shutdown + asyncio.run(app.run()) if __name__ == "__main__": From a5c9e97f1203182dada0aea4ae78d2d96f733c12 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Wed, 21 Jan 2026 14:50:01 +0100 Subject: [PATCH 09/21] moved lifetime mgmt to DI --- backend/app/core/container.py | 27 +- backend/app/core/dishka_lifespan.py | 67 +- backend/app/core/k8s_clients.py | 47 -- backend/app/core/providers.py | 398 ++++++--- backend/app/dlq/manager.py | 172 ++-- backend/app/events/core/__init__.py | 4 - backend/app/events/core/consumer.py | 126 --- backend/app/events/core/dispatcher.py | 177 ---- backend/app/events/core/producer.py | 91 +-- backend/app/events/event_store_consumer.py | 103 +-- backend/app/services/event_bus.py | 174 ++-- .../app/services/k8s_worker/worker_logic.py | 155 +--- backend/app/services/notification_service.py | 72 +- backend/app/services/pod_monitor/monitor.py | 23 +- .../result_processor/processor_logic.py | 29 +- backend/app/services/saga/saga_logic.py | 9 +- backend/app/services/sse/event_router.py | 8 +- backend/app/services/sse/sse_service.py | 4 +- backend/pyproject.toml | 2 + .../tests/e2e/test_k8s_worker_create_pod.py | 25 +- backend/tests/helpers/fakes/__init__.py | 11 + backend/tests/helpers/fakes/kafka.py | 78 ++ backend/tests/helpers/fakes/kubernetes.py | 54 ++ backend/tests/helpers/fakes/providers.py | 84 ++ .../tests/helpers/fakes/schema_registry.py | 116 +++ .../events/test_consume_roundtrip.py | 70 -- .../events/test_consumer_lifecycle.py | 56 -- .../events/test_event_dispatcher.py | 72 -- .../events/test_producer_roundtrip.py | 35 +- .../result_processor/test_result_processor.py | 165 ---- .../sse/test_partitioned_event_router.py | 41 - backend/tests/unit/conftest.py | 185 ++++- .../unit/events/test_event_dispatcher.py | 61 -- .../unit/services/coordinator/__init__.py | 0 .../unit/services/pod_monitor/test_monitor.py | 771 ++++++++---------- .../result_processor/test_processor.py | 28 - .../saga/test_execution_saga_steps.py | 176 ++-- .../saga/test_saga_orchestrator_unit.py | 100 +-- .../services/sse/test_kafka_redis_bridge.py | 50 -- .../unit/services/sse/test_sse_service.py | 11 +- backend/uv.lock | 74 ++ backend/workers/run_event_replay.py | 5 +- backend/workers/run_k8s_worker.py | 30 +- backend/workers/run_pod_monitor.py | 15 +- backend/workers/run_result_processor.py | 22 +- 
backend/workers/run_saga_orchestrator.py | 22 +- backend/workers/run_sse_bridge.py | 137 ++++ 47 files changed, 1822 insertions(+), 2360 deletions(-) delete mode 100644 backend/app/core/k8s_clients.py delete mode 100644 backend/app/events/core/consumer.py delete mode 100644 backend/app/events/core/dispatcher.py create mode 100644 backend/tests/helpers/fakes/__init__.py create mode 100644 backend/tests/helpers/fakes/kafka.py create mode 100644 backend/tests/helpers/fakes/kubernetes.py create mode 100644 backend/tests/helpers/fakes/providers.py create mode 100644 backend/tests/helpers/fakes/schema_registry.py delete mode 100644 backend/tests/integration/events/test_consume_roundtrip.py delete mode 100644 backend/tests/integration/events/test_consumer_lifecycle.py delete mode 100644 backend/tests/integration/events/test_event_dispatcher.py delete mode 100644 backend/tests/integration/result_processor/test_result_processor.py delete mode 100644 backend/tests/integration/services/sse/test_partitioned_event_router.py delete mode 100644 backend/tests/unit/events/test_event_dispatcher.py delete mode 100644 backend/tests/unit/services/coordinator/__init__.py delete mode 100644 backend/tests/unit/services/result_processor/test_processor.py delete mode 100644 backend/tests/unit/services/sse/test_kafka_redis_bridge.py create mode 100644 backend/workers/run_sse_bridge.py diff --git a/backend/app/core/container.py b/backend/app/core/container.py index 0d62da6c..9b5bddd6 100644 --- a/backend/app/core/container.py +++ b/backend/app/core/container.py @@ -4,6 +4,7 @@ from app.core.providers import ( AdminServicesProvider, AuthProvider, + BoundaryClientProvider, BusinessServicesProvider, CoreServicesProvider, DatabaseProvider, @@ -11,12 +12,11 @@ EventReplayProvider, K8sWorkerProvider, KafkaServicesProvider, - KubernetesProvider, LoggingProvider, MessagingProvider, MetricsProvider, PodMonitorProvider, - RedisProvider, + RedisServicesProvider, RepositoryProvider, ResultProcessorProvider, SagaOrchestratorProvider, @@ -38,7 +38,8 @@ def create_app_container(settings: Settings) -> AsyncContainer: SettingsProvider(), LoggingProvider(), DatabaseProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), CoreServicesProvider(), MetricsProvider(), RepositoryProvider(), @@ -67,7 +68,8 @@ def create_result_processor_container(settings: Settings) -> AsyncContainer: SettingsProvider(), LoggingProvider(), DatabaseProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), CoreServicesProvider(), MetricsProvider(), RepositoryProvider(), @@ -84,13 +86,13 @@ def create_k8s_worker_container(settings: Settings) -> AsyncContainer: SettingsProvider(), LoggingProvider(), DatabaseProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), CoreServicesProvider(), MetricsProvider(), RepositoryProvider(), MessagingProvider(), EventProvider(), - KubernetesProvider(), K8sWorkerProvider(), context={Settings: settings}, ) @@ -102,14 +104,14 @@ def create_pod_monitor_container(settings: Settings) -> AsyncContainer: SettingsProvider(), LoggingProvider(), DatabaseProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), CoreServicesProvider(), MetricsProvider(), RepositoryProvider(), MessagingProvider(), EventProvider(), KafkaServicesProvider(), - KubernetesProvider(), PodMonitorProvider(), context={Settings: settings}, ) @@ -121,7 +123,8 @@ def create_saga_orchestrator_container(settings: Settings) -> AsyncContainer: 
SettingsProvider(), LoggingProvider(), DatabaseProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), CoreServicesProvider(), MetricsProvider(), RepositoryProvider(), @@ -138,7 +141,8 @@ def create_event_replay_container(settings: Settings) -> AsyncContainer: SettingsProvider(), LoggingProvider(), DatabaseProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), CoreServicesProvider(), MetricsProvider(), RepositoryProvider(), @@ -155,7 +159,8 @@ def create_dlq_processor_container(settings: Settings) -> AsyncContainer: SettingsProvider(), LoggingProvider(), DatabaseProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), CoreServicesProvider(), MetricsProvider(), RepositoryProvider(), diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index 7b00c6a1..b186eba7 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -1,7 +1,7 @@ import asyncio import logging +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import AsyncGenerator import redis.asyncio as redis from beanie import init_beanie @@ -13,22 +13,26 @@ from app.core.startup import initialize_rate_limits from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS -from app.events.core import UnifiedConsumer +from app.events.core import UnifiedProducer from app.events.event_store_consumer import EventStoreConsumer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas +from app.services.event_bus import EventBus from app.services.notification_service import NotificationService from app.settings import Settings @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - """ - Application lifespan with dishka dependency injection. + """Application lifespan with dishka dependency injection. + + All service lifecycle (start/stop, background tasks) is managed by DI providers. + Lifespan only: + 1. Resolves dependencies (triggers provider lifecycle setup) + 2. Initializes schemas and beanie + 3. On shutdown, container cleanup handles everything - Services are already initialized by their DI providers (which handle __aenter__/__aexit__). - Lifespan just starts the run() methods as background tasks. + Note: SSE Kafka consumers are now in the separate SSE bridge worker (run_sse_bridge.py). 
""" - # Get settings and logger from DI container (uses test settings in tests) container: AsyncContainer = app.state.dishka_container settings = await container.get(Settings) logger = await container.get(logging.Logger) @@ -41,10 +45,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: }, ) - # Metrics setup moved to app creation to allow middleware registration - logger.info("Lifespan start: tracing and services initialization") - - # Initialize tracing only when enabled (avoid exporter retries in tests) + # Initialize tracing only when enabled if settings.ENABLE_TRACING and not settings.TESTING: instrumentation_report = init_tracing( service_name=settings.TRACING_SERVICE_NAME, @@ -73,59 +74,37 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: ) # Phase 1: Resolve all DI dependencies in parallel + # This triggers async generator providers which start services and background tasks ( schema_registry, database, redis_client, rate_limit_metrics, - sse_consumers, - event_store_consumer, - notification_service, + _event_store_consumer, + _notification_service, + _kafka_producer, + _event_bus, ) = await asyncio.gather( container.get(SchemaRegistryManager), container.get(Database), container.get(redis.Redis), container.get(RateLimitMetrics), - container.get(list[UnifiedConsumer]), container.get(EventStoreConsumer), container.get(NotificationService), + container.get(UnifiedProducer), + container.get(EventBus), ) - # Phase 2: Initialize infrastructure in parallel (independent subsystems) + # Phase 2: Initialize infrastructure await asyncio.gather( initialize_event_schemas(schema_registry), init_beanie(database=database, document_models=ALL_DOCUMENTS), initialize_rate_limits(redis_client, settings, logger, rate_limit_metrics), ) - logger.info("Infrastructure initialized (schemas, beanie, rate limits)") - - # Phase 3: Start run() methods as background tasks - # Note: Services are already initialized by their DI providers (which handle __aenter__/__aexit__) - - async def run_sse_consumers() -> None: - """Run SSE consumers using TaskGroup.""" - async with asyncio.TaskGroup() as tg: - for consumer in sse_consumers: - tg.create_task(consumer.run()) - - tasks = [ - asyncio.create_task(run_sse_consumers(), name="sse_consumers"), - asyncio.create_task(event_store_consumer.run(), name="event_store_consumer"), - asyncio.create_task(notification_service.run(), name="notification_service"), - ] - logger.info( - "Background services started", - extra={"sse_consumer_count": len(sse_consumers)}, - ) + logger.info("Application started - all services running") try: yield finally: - # Cancel all background tasks on shutdown - logger.info("Shutting down background services...") - for task in tasks: - task.cancel() - - # Wait for tasks to finish cancellation - await asyncio.gather(*tasks, return_exceptions=True) - logger.info("Background services stopped") + # Container cleanup handles all service shutdown via async generator cleanup + logger.info("Application shutting down - container cleanup will stop all services") diff --git a/backend/app/core/k8s_clients.py b/backend/app/core/k8s_clients.py deleted file mode 100644 index 0aedd5c7..00000000 --- a/backend/app/core/k8s_clients.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging -from dataclasses import dataclass - -from kubernetes import client as k8s_client -from kubernetes import config as k8s_config -from kubernetes import watch as k8s_watch - - -@dataclass(frozen=True) -class K8sClients: - """Kubernetes API clients bundle for 
dependency injection.""" - - api_client: k8s_client.ApiClient - v1: k8s_client.CoreV1Api - apps_v1: k8s_client.AppsV1Api - networking_v1: k8s_client.NetworkingV1Api - watch: k8s_watch.Watch - - -def create_k8s_clients( - logger: logging.Logger, kubeconfig_path: str | None = None, in_cluster: bool | None = None -) -> K8sClients: - if in_cluster: - k8s_config.load_incluster_config() - elif kubeconfig_path: - k8s_config.load_kube_config(config_file=kubeconfig_path) - else: - k8s_config.load_kube_config() - - configuration = k8s_client.Configuration.get_default_copy() - logger.info(f"Kubernetes API host: {configuration.host}") - logger.info(f"SSL CA configured: {configuration.ssl_ca_cert is not None}") - - api_client = k8s_client.ApiClient(configuration) - return K8sClients( - api_client=api_client, - v1=k8s_client.CoreV1Api(api_client), - apps_v1=k8s_client.AppsV1Api(api_client), - networking_v1=k8s_client.NetworkingV1Api(api_client), - watch=k8s_watch.Watch(), - ) - - -def close_k8s_clients(clients: K8sClients) -> None: - close = getattr(clients.api_client, "close", None) - if callable(close): - close() diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index f4cc6d9c..c02c5e9b 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -1,13 +1,16 @@ +import asyncio import logging -from typing import AsyncIterator +from collections.abc import AsyncIterable import redis.asyncio as redis from aiokafka import AIOKafkaConsumer, AIOKafkaProducer from dishka import Provider, Scope, from_context, provide +from kubernetes import client as k8s_client +from kubernetes import config as k8s_config +from kubernetes import watch as k8s_watch from pymongo.asynchronous.mongo_client import AsyncMongoClient from app.core.database_context import Database -from app.core.k8s_clients import K8sClients, close_k8s_clients, create_k8s_clients from app.core.logging import setup_logger from app.core.metrics import ( ConnectionMetrics, @@ -42,9 +45,9 @@ from app.db.repositories.user_settings_repository import UserSettingsRepository from app.dlq.manager import DLQManager from app.dlq.models import RetryPolicy, RetryStrategy -from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId, KafkaTopic +from app.domain.enums.kafka import GroupId, KafkaTopic from app.domain.saga.models import SagaConfig -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer, UnifiedProducer +from app.events.core import ProducerMetrics, UnifiedProducer from app.events.event_store import EventStore, create_event_store from app.events.event_store_consumer import EventStoreConsumer from app.events.schema.schema_registry import SchemaRegistryManager @@ -71,7 +74,6 @@ from app.services.saga.saga_logic import SagaLogic from app.services.saga.saga_service import SagaService from app.services.saved_script_service import SavedScriptService -from app.services.sse.event_router import SSEEventRouter from app.services.sse.redis_bus import SSERedisBus from app.services.sse.sse_connection_registry import SSEConnectionRegistry from app.services.sse.sse_service import SSEService @@ -93,13 +95,20 @@ def get_logger(self, settings: Settings) -> logging.Logger: return setup_logger(settings.LOG_LEVEL) -class RedisProvider(Provider): +class BoundaryClientProvider(Provider): + """Provides all external boundary clients (Redis, Kafka, K8s). + + These are the ONLY places that create connections to external systems. + Override this provider in tests with fakes to test without external deps. 
+ """ + scope = Scope.APP + # Redis @provide - async def get_redis_client(self, settings: Settings, logger: logging.Logger) -> AsyncIterator[redis.Redis]: - # Create Redis client - it will automatically use the current event loop - client = redis.Redis( + def get_redis_client(self, settings: Settings, logger: logging.Logger) -> redis.Redis: + logger.info(f"Redis configured: {settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}") + return redis.Redis( host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, @@ -110,13 +119,55 @@ async def get_redis_client(self, settings: Settings, logger: logging.Logger) -> socket_connect_timeout=5, socket_timeout=5, ) - # Test connection - await client.ping() # type: ignore[misc] # redis-py dual sync/async return type - logger.info(f"Redis connected: {settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}") - try: - yield client - finally: - await client.aclose() + + # Kafka - one shared producer for all services + @provide + async def get_kafka_producer_client( + self, settings: Settings, logger: logging.Logger + ) -> AsyncIterable[AIOKafkaProducer]: + """Provide AIOKafkaProducer with DI-managed lifecycle.""" + producer = AIOKafkaProducer( + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + client_id=f"{settings.SERVICE_NAME}-producer", + acks="all", + compression_type="gzip", + max_batch_size=16384, + linger_ms=10, + enable_idempotence=True, + ) + await producer.start() + logger.info("Kafka producer started") + + yield producer + + await producer.stop() + logger.info("Kafka producer stopped") + + # Kubernetes + @provide + def get_k8s_api_client(self, settings: Settings, logger: logging.Logger) -> k8s_client.ApiClient: + k8s_config.load_kube_config(config_file=settings.KUBERNETES_CONFIG_PATH) + configuration = k8s_client.Configuration.get_default_copy() + logger.info(f"Kubernetes API host: {configuration.host}") + return k8s_client.ApiClient(configuration) + + @provide + def get_k8s_core_v1_api(self, api_client: k8s_client.ApiClient) -> k8s_client.CoreV1Api: + return k8s_client.CoreV1Api(api_client) + + @provide + def get_k8s_apps_v1_api(self, api_client: k8s_client.ApiClient) -> k8s_client.AppsV1Api: + return k8s_client.AppsV1Api(api_client) + + @provide + def get_k8s_watch(self) -> k8s_watch.Watch: + return k8s_watch.Watch() + + +class RedisServicesProvider(Provider): + """Services that depend on Redis.""" + + scope = Scope.APP @provide def get_rate_limit_service( @@ -129,16 +180,12 @@ class DatabaseProvider(Provider): scope = Scope.APP @provide - async def get_database(self, settings: Settings, logger: logging.Logger) -> AsyncIterator[Database]: + def get_database(self, settings: Settings, logger: logging.Logger) -> Database: client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 ) - database = client[settings.DATABASE_NAME] - logger.info(f"MongoDB connected: {settings.DATABASE_NAME}") - try: - yield database - finally: - await client.close() + logger.info(f"MongoDB configured: {settings.DATABASE_NAME}") + return client[settings.DATABASE_NAME] class CoreServicesProvider(Provider): @@ -157,21 +204,38 @@ class MessagingProvider(Provider): scope = Scope.APP @provide - async def get_kafka_producer( - self, settings: Settings, schema_registry: SchemaRegistryManager, logger: logging.Logger, - event_metrics: EventMetrics - ) -> AsyncIterator[UnifiedProducer]: - async with UnifiedProducer(schema_registry, logger, settings, event_metrics) as 
producer: - yield producer + def get_kafka_producer( + self, + kafka_producer: AIOKafkaProducer, + schema_registry: SchemaRegistryManager, + settings: Settings, + logger: logging.Logger, + event_metrics: EventMetrics, + ) -> UnifiedProducer: + """Provide UnifiedProducer. Kafka producer lifecycle managed by BoundaryClientProvider.""" + return UnifiedProducer( + producer=kafka_producer, + metrics=ProducerMetrics(), + schema_registry=schema_registry, + settings=settings, + logger=logger, + event_metrics=event_metrics, + ) @provide async def get_dlq_manager( self, + kafka_producer: AIOKafkaProducer, settings: Settings, schema_registry: SchemaRegistryManager, logger: logging.Logger, dlq_metrics: DLQMetrics, - ) -> AsyncIterator[DLQManager]: + ) -> AsyncIterable[DLQManager]: + """Provide DLQManager with DI-managed lifecycle. + + Producer lifecycle managed by BoundaryClientProvider. This provider + manages the consumer and background tasks. + """ topic_name = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DEAD_LETTER_QUEUE}" consumer = AIOKafkaConsumer( topic_name, @@ -185,27 +249,44 @@ async def get_dlq_manager( max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, ) - producer = AIOKafkaProducer( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - client_id="dlq-manager-producer", - acks="all", - compression_type="gzip", - max_batch_size=16384, - linger_ms=10, - enable_idempotence=True, - ) - manager = DLQManager( + + await consumer.start() + + dlq_manager = DLQManager( settings=settings, consumer=consumer, - producer=producer, + producer=kafka_producer, schema_registry=schema_registry, logger=logger, dlq_metrics=dlq_metrics, dlq_topic=KafkaTopic.DEAD_LETTER_QUEUE, default_retry_policy=RetryPolicy(topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF), ) - async with manager: - yield manager + + # Background task: process incoming DLQ messages + async def process_messages_loop() -> None: + async for msg in consumer: + await dlq_manager.process_consumer_message(msg) + + # Background task: periodic check for scheduled retries + async def monitor_loop() -> None: + while True: + await dlq_manager.check_scheduled_retries() + await asyncio.sleep(10) + + process_task = asyncio.create_task(process_messages_loop()) + monitor_task = asyncio.create_task(monitor_loop()) + logger.info("DLQ Manager started") + + yield dlq_manager + + # Cleanup + process_task.cancel() + monitor_task.cancel() + await asyncio.gather(process_task, monitor_task, return_exceptions=True) + + await consumer.stop() + logger.info("DLQ Manager stopped") @provide def get_idempotency_config(self) -> IdempotencyConfig: @@ -239,7 +320,7 @@ def get_schema_registry(self, settings: Settings, logger: logging.Logger) -> Sch return SchemaRegistryManager(settings, logger) @provide - async def get_event_store( + def get_event_store( self, schema_registry: SchemaRegistryManager, logger: logging.Logger, event_metrics: EventMetrics ) -> EventStore: return create_event_store( @@ -247,42 +328,115 @@ async def get_event_store( ) @provide - def get_event_store_consumer( + async def get_event_store_consumer( self, event_store: EventStore, schema_registry: SchemaRegistryManager, settings: Settings, logger: logging.Logger, event_metrics: EventMetrics, - ) -> EventStoreConsumer: + ) -> AsyncIterable[EventStoreConsumer]: + """Provide EventStoreConsumer with DI-managed lifecycle.""" topics = get_all_topics() - return EventStoreConsumer( + topic_strings = [f"{settings.KAFKA_TOPIC_PREFIX}{topic}" 
for topic in topics] + + consumer = AIOKafkaConsumer( + *topic_strings, + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"{GroupId.EVENT_STORE_CONSUMER}.{settings.KAFKA_GROUP_SUFFIX}", + enable_auto_commit=False, + max_poll_records=100, + session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, + heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, + max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + fetch_max_wait_ms=5000, + ) + + await consumer.start() + logger.info(f"Event store consumer started for topics: {topic_strings}") + + event_store_consumer = EventStoreConsumer( event_store=event_store, - topics=list(topics), + consumer=consumer, schema_registry_manager=schema_registry, - settings=settings, logger=logger, event_metrics=event_metrics, ) + async def batch_loop() -> None: + while True: + await event_store_consumer.process_batch() + + task = asyncio.create_task(batch_loop()) + + yield event_store_consumer + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + await consumer.stop() + logger.info("Event store consumer stopped") + @provide async def get_event_bus( - self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics - ) -> AsyncIterator[EventBus]: - async with EventBus(settings, logger, connection_metrics) as bus: - yield bus + self, + kafka_producer: AIOKafkaProducer, + settings: Settings, + logger: logging.Logger, + connection_metrics: ConnectionMetrics, + ) -> AsyncIterable[EventBus]: + """Provide EventBus with DI-managed lifecycle. + Producer lifecycle managed by BoundaryClientProvider. This provider + manages the consumer and background listener task. + """ + topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" + consumer = AIOKafkaConsumer( + topic, + bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, + group_id=f"event-bus-{settings.SERVICE_NAME}", + auto_offset_reset="latest", + enable_auto_commit=True, + client_id=f"event-bus-consumer-{settings.SERVICE_NAME}", + session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, + heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, + max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) -class KubernetesProvider(Provider): - scope = Scope.APP + await consumer.start() - @provide - async def get_k8s_clients(self, settings: Settings, logger: logging.Logger) -> AsyncIterator[K8sClients]: - clients = create_k8s_clients(logger, kubeconfig_path=settings.KUBERNETES_CONFIG_PATH) + event_bus = EventBus( + producer=kafka_producer, + consumer=consumer, + settings=settings, + logger=logger, + connection_metrics=connection_metrics, + ) + + # Create background listener task + async def listener_loop() -> None: + while True: + await event_bus.process_kafka_message() + + listener_task = asyncio.create_task(listener_loop()) + logger.info("Event bus started with Kafka backing") + + yield event_bus + + # Cleanup + listener_task.cancel() try: - yield clients - finally: - close_k8s_clients(clients) + await listener_task + except asyncio.CancelledError: + pass + + await consumer.stop() + logger.info("Event bus stopped") class MetricsProvider(Provider): @@ -398,65 +552,17 @@ def get_user_repository(self) -> UserRepository: class SSEProvider(Provider): - """Provides SSE (Server-Sent Events) related services.""" + """Provides SSE (Server-Sent Events) related services. 
- scope = Scope.APP + Note: Kafka consumers for SSE are now in the separate SSE bridge worker + (run_sse_bridge.py). This provider only handles Redis pub/sub and SSE service. + """ - @provide - async def get_sse_redis_bus(self, redis_client: redis.Redis, logger: logging.Logger) -> AsyncIterator[SSERedisBus]: - bus = SSERedisBus(redis_client, logger) - yield bus + scope = Scope.APP @provide - def get_sse_event_router( - self, - sse_redis_bus: SSERedisBus, - logger: logging.Logger, - ) -> SSEEventRouter: - return SSEEventRouter(sse_bus=sse_redis_bus, logger=logger) - - @provide - def get_sse_consumers( - self, - router: SSEEventRouter, - schema_registry: SchemaRegistryManager, - settings: Settings, - event_metrics: EventMetrics, - logger: logging.Logger, - ) -> list[UnifiedConsumer]: - """Create SSE consumer pool with routing handlers wired to SSEEventRouter.""" - topics = list(CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.WEBSOCKET_GATEWAY]) - suffix = settings.KAFKA_GROUP_SUFFIX - consumers: list[UnifiedConsumer] = [] - - for i in range(settings.SSE_CONSUMER_POOL_SIZE): - dispatcher = EventDispatcher(logger=logger) - router.register_handlers(dispatcher) - - config = ConsumerConfig( - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"sse-bridge-pool.{suffix}", - client_id=f"sse-bridge-{i}.{suffix}", - enable_auto_commit=True, - auto_offset_reset="latest", - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - consumer = UnifiedConsumer( - config=config, - dispatcher=dispatcher, - schema_registry=schema_registry, - settings=settings, - logger=logger, - event_metrics=event_metrics, - topics=topics, - ) - consumers.append(consumer) - - return consumers + def get_sse_redis_bus(self, redis_client: redis.Redis, logger: logging.Logger) -> SSERedisBus: + return SSERedisBus(redis_client, logger) @provide(scope=Scope.REQUEST) def get_sse_connection_registry( @@ -473,7 +579,6 @@ def get_sse_connection_registry( def get_sse_service( self, sse_repository: SSERepository, - consumers: list[UnifiedConsumer], sse_redis_bus: SSERedisBus, connection_registry: SSEConnectionRegistry, settings: Settings, @@ -482,7 +587,6 @@ def get_sse_service( ) -> SSEService: return SSEService( repository=sse_repository, - num_consumers=len(consumers), sse_bus=sse_redis_bus, connection_registry=connection_registry, settings=settings, @@ -575,7 +679,7 @@ def get_admin_settings_service( return AdminSettingsService(admin_settings_repository, logger) @provide - def get_notification_service( + async def get_notification_service( self, notification_repository: NotificationRepository, event_bus: EventBus, @@ -583,8 +687,9 @@ def get_notification_service( settings: Settings, logger: logging.Logger, notification_metrics: NotificationMetrics, - ) -> NotificationService: - return NotificationService( + ) -> AsyncIterable[NotificationService]: + """Provide NotificationService with DI-managed background tasks.""" + service = NotificationService( notification_repository=notification_repository, event_bus=event_bus, sse_bus=sse_redis_bus, @@ -593,6 +698,27 @@ def get_notification_service( notification_metrics=notification_metrics, ) + async def pending_loop() -> None: + while True: + await service.process_pending_batch() + await asyncio.sleep(5) + + async def cleanup_loop() -> None: + while True: + await asyncio.sleep(86400) # 24 hours + await 
service.cleanup_old() + + pending_task = asyncio.create_task(pending_loop()) + cleanup_task = asyncio.create_task(cleanup_loop()) + logger.info("NotificationService background tasks started") + + yield service + + pending_task.cancel() + cleanup_task.cancel() + await asyncio.gather(pending_task, cleanup_task, return_exceptions=True) + logger.info("NotificationService background tasks stopped") + @provide def get_grafana_alert_processor( self, @@ -615,11 +741,6 @@ def _create_default_saga_config() -> SagaConfig: ) -# Standalone factory functions for services (no lifecycle - run() handles everything) - - - - class BusinessServicesProvider(Provider): scope = Scope.REQUEST @@ -664,7 +785,7 @@ def get_saved_script_service( return SavedScriptService(saved_script_repository, logger) @provide - async def get_replay_service( + def get_replay_service( self, replay_repository: ReplayRepository, kafka_producer: UnifiedProducer, @@ -711,18 +832,23 @@ def get_k8s_worker_logic( settings: Settings, logger: logging.Logger, event_metrics: EventMetrics, + kubernetes_metrics: KubernetesMetrics, + execution_metrics: ExecutionMetrics, + k8s_v1: k8s_client.CoreV1Api, + k8s_apps_v1: k8s_client.AppsV1Api, ) -> K8sWorkerLogic: config = K8sWorkerConfig() - logic = K8sWorkerLogic( + return K8sWorkerLogic( config=config, producer=kafka_producer, settings=settings, logger=logger, event_metrics=event_metrics, + kubernetes_metrics=kubernetes_metrics, + execution_metrics=execution_metrics, + k8s_v1=k8s_v1, + k8s_apps_v1=k8s_apps_v1, ) - # Initialize K8s clients synchronously (safe during DI setup) - logic.initialize() - return logic class PodMonitorProvider(Provider): @@ -732,15 +858,16 @@ class PodMonitorProvider(Provider): def get_event_mapper( self, logger: logging.Logger, - k8s_clients: K8sClients, + k8s_v1: k8s_client.CoreV1Api, ) -> PodEventMapper: - return PodEventMapper(logger=logger, k8s_api=k8s_clients.v1) + return PodEventMapper(logger=logger, k8s_api=k8s_v1) @provide def get_pod_monitor( self, kafka_event_service: KafkaEventService, - k8s_clients: K8sClients, + k8s_v1: k8s_client.CoreV1Api, + k8s_watch: k8s_watch.Watch, logger: logging.Logger, event_mapper: PodEventMapper, kubernetes_metrics: KubernetesMetrics, @@ -750,7 +877,8 @@ def get_pod_monitor( config=config, kafka_event_service=kafka_event_service, logger=logger, - k8s_clients=k8s_clients, + k8s_v1=k8s_v1, + k8s_watch=k8s_watch, event_mapper=event_mapper, kubernetes_metrics=kubernetes_metrics, ) diff --git a/backend/app/dlq/manager.py b/backend/app/dlq/manager.py index da434964..3e0cd5c9 100644 --- a/backend/app/dlq/manager.py +++ b/backend/app/dlq/manager.py @@ -2,9 +2,9 @@ import json import logging from datetime import datetime, timezone -from typing import Any, Callable +from typing import Callable -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from aiokafka import AIOKafkaConsumer, AIOKafkaProducer, ConsumerRecord from opentelemetry.trace import SpanKind from app.core.metrics import DLQMetrics @@ -32,6 +32,11 @@ class DLQManager: + """Dead Letter Queue manager - pure logic class. + + Lifecycle (start/stop consumer, background tasks) managed by DI provider. 
+ """ + def __init__( self, settings: Settings, @@ -56,9 +61,6 @@ def __init__( self.consumer: AIOKafkaConsumer = consumer self.producer: AIOKafkaProducer = producer - self._process_task: asyncio.Task[None] | None = None - self._monitor_task: asyncio.Task[None] | None = None - # Topic-specific retry policies self._retry_policies: dict[str, RetryPolicy] = {} @@ -68,72 +70,42 @@ def __init__( self._dlq_events_topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DLQ_EVENTS}" self._event_metadata = EventMetadata(service_name="dlq-manager", service_version="1.0.0") - def _kafka_msg_to_message(self, msg: Any) -> DLQMessage: + def _kafka_msg_to_message(self, msg: ConsumerRecord[bytes, bytes]) -> DLQMessage: """Parse Kafka ConsumerRecord into DLQMessage.""" data = json.loads(msg.value) headers = {k: v.decode() for k, v in (msg.headers or [])} return DLQMessage(**data, dlq_offset=msg.offset, dlq_partition=msg.partition, headers=headers) - async def __aenter__(self) -> "DLQManager": - """Start DLQ manager.""" - # Start producer and consumer in parallel for faster startup - await asyncio.gather(self.producer.start(), self.consumer.start()) - - # Start processing tasks - self._process_task = asyncio.create_task(self._process_messages()) - self._monitor_task = asyncio.create_task(self._monitor_dlq()) - - self.logger.info("DLQ Manager started") - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop DLQ manager.""" - # Cancel tasks - for task in [self._process_task, self._monitor_task]: - if task: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # Stop Kafka clients - await self.consumer.stop() - await self.producer.stop() - - self.logger.info("DLQ Manager stopped") - - async def _process_messages(self) -> None: - """Process DLQ messages using async iteration.""" - async for msg in self.consumer: - try: - start = asyncio.get_running_loop().time() - dlq_msg = self._kafka_msg_to_message(msg) - - # Record metrics - self.metrics.record_dlq_message_received(dlq_msg.original_topic, dlq_msg.event.event_type) - self.metrics.record_dlq_message_age((datetime.now(timezone.utc) - dlq_msg.failed_at).total_seconds()) - - # Process with tracing - ctx = extract_trace_context(dlq_msg.headers) - with get_tracer().start_as_current_span( - name="dlq.consume", - context=ctx, - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: self.dlq_topic, - EventAttributes.EVENT_TYPE: dlq_msg.event.event_type, - EventAttributes.EVENT_ID: dlq_msg.event.event_id, - }, - ): - await self._process_dlq_message(dlq_msg) - - # Commit and record duration - await self.consumer.commit() - self.metrics.record_dlq_processing_duration(asyncio.get_running_loop().time() - start, "process") + async def process_consumer_message(self, msg: ConsumerRecord[bytes, bytes]) -> None: + """Process a single DLQ message. 
Called by DI provider's background task.""" + try: + start = asyncio.get_running_loop().time() + dlq_msg = self._kafka_msg_to_message(msg) + + # Record metrics + self.metrics.record_dlq_message_received(dlq_msg.original_topic, dlq_msg.event.event_type) + self.metrics.record_dlq_message_age((datetime.now(timezone.utc) - dlq_msg.failed_at).total_seconds()) + + # Process with tracing + ctx = extract_trace_context(dlq_msg.headers) + with get_tracer().start_as_current_span( + name="dlq.consume", + context=ctx, + kind=SpanKind.CONSUMER, + attributes={ + EventAttributes.KAFKA_TOPIC: self.dlq_topic, + EventAttributes.EVENT_TYPE: dlq_msg.event.event_type, + EventAttributes.EVENT_ID: dlq_msg.event.event_id, + }, + ): + await self._process_dlq_message(dlq_msg) + + # Commit and record duration + await self.consumer.commit() + self.metrics.record_dlq_processing_duration(asyncio.get_running_loop().time() - start, "process") - except Exception as e: - self.logger.error(f"Error processing DLQ message: {e}") + except Exception as e: + self.logger.error(f"Error processing DLQ message: {e}") async def _process_dlq_message(self, message: DLQMessage) -> None: # Apply filters @@ -186,7 +158,10 @@ async def _update_message_status(self, event_id: str, update: DLQMessageUpdate) if not doc: return - update_dict: dict[str, Any] = {"status": update.status, "last_updated": datetime.now(timezone.utc)} + update_dict: dict[str, DLQMessageStatus | datetime | int | str] = { + "status": update.status, + "last_updated": datetime.now(timezone.utc), + } if update.next_retry_at is not None: update_dict["next_retry_at"] = update.next_retry_at if update.retried_at is not None: @@ -271,51 +246,44 @@ async def _discard_message(self, message: DLQMessage, reason: str) -> None: self.logger.warning("Discarded message", extra={"event_id": message.event.event_id, "reason": reason}) - async def _monitor_dlq(self) -> None: + async def check_scheduled_retries(self) -> None: + """Check for and process scheduled retries. 
Called periodically by DI provider.""" try: - while True: - try: - # Find messages ready for retry using Beanie - now = datetime.now(timezone.utc) - - docs = ( - await DLQMessageDocument.find( - { - "status": DLQMessageStatus.SCHEDULED, - "next_retry_at": {"$lte": now}, - } - ) - .limit(100) - .to_list() - ) - - for doc in docs: - message = DLQMessage.model_validate(doc, from_attributes=True) - await self._retry_message(message) - - # Update queue size metrics - await self._update_queue_metrics() - - # Sleep before next check - await asyncio.sleep(10) - - except asyncio.CancelledError: - raise - except Exception as e: - self.logger.error(f"Error in DLQ monitor: {e}") - await asyncio.sleep(60) - except asyncio.CancelledError: - self.logger.info("DLQ monitor cancelled") + # Find messages ready for retry using Beanie + now = datetime.now(timezone.utc) + + docs = ( + await DLQMessageDocument.find( + { + "status": DLQMessageStatus.SCHEDULED, + "next_retry_at": {"$lte": now}, + } + ) + .limit(100) + .to_list() + ) + + for doc in docs: + message = DLQMessage.model_validate(doc, from_attributes=True) + await self._retry_message(message) + + # Update queue size metrics + await self._update_queue_metrics() + + except Exception as e: + self.logger.error(f"Error in DLQ monitor: {e}") async def _update_queue_metrics(self) -> None: # Get counts by topic using Beanie aggregation - pipeline: list[dict[str, Any]] = [ + pipeline: list[dict[str, object]] = [ {"$match": {"status": {"$in": [DLQMessageStatus.PENDING, DLQMessageStatus.SCHEDULED]}}}, {"$group": {"_id": "$original_topic", "count": {"$sum": 1}}}, ] async for result in DLQMessageDocument.aggregate(pipeline): - self.metrics.update_dlq_queue_size(result["_id"], result["count"]) + topic = str(result["_id"]) + count = int(result["count"]) + self.metrics.update_dlq_queue_size(topic, count) def set_retry_policy(self, topic: str, policy: RetryPolicy) -> None: self._retry_policies[topic] = policy diff --git a/backend/app/events/core/__init__.py b/backend/app/events/core/__init__.py index 3b12df76..d8902ed1 100644 --- a/backend/app/events/core/__init__.py +++ b/backend/app/events/core/__init__.py @@ -1,5 +1,3 @@ -from .consumer import UnifiedConsumer -from .dispatcher import EventDispatcher from .dlq_handler import ( create_dlq_error_handler, create_immediate_dlq_handler, @@ -22,8 +20,6 @@ "ConsumerMetrics", # Core components "UnifiedProducer", - "UnifiedConsumer", - "EventDispatcher", # Helpers "create_dlq_error_handler", "create_immediate_dlq_handler", diff --git a/backend/app/events/core/consumer.py b/backend/app/events/core/consumer.py deleted file mode 100644 index 8a051429..00000000 --- a/backend/app/events/core/consumer.py +++ /dev/null @@ -1,126 +0,0 @@ -import logging -from collections.abc import Awaitable, Callable - -from aiokafka import AIOKafkaConsumer, TopicPartition -from opentelemetry.trace import SpanKind - -from app.core.metrics import EventMetrics -from app.core.tracing import EventAttributes -from app.core.tracing.utils import extract_trace_context, get_tracer -from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import DomainEvent -from app.events.schema.schema_registry import SchemaRegistryManager -from app.settings import Settings - -from .dispatcher import EventDispatcher -from .types import ConsumerConfig - - -class UnifiedConsumer: - """Kafka consumer with framework-style run(). - - No loops in user code. Register handlers, call run(), handlers get called. 
- - Usage: - dispatcher = EventDispatcher() - dispatcher.register(EventType.FOO, handle_foo) - - consumer = UnifiedConsumer(..., dispatcher=dispatcher) - await consumer.run() # Blocks, calls handlers when events arrive - """ - - def __init__( - self, - config: ConsumerConfig, - dispatcher: EventDispatcher, - schema_registry: SchemaRegistryManager, - settings: Settings, - logger: logging.Logger, - event_metrics: EventMetrics, - topics: list[KafkaTopic], - error_callback: Callable[[Exception, DomainEvent], Awaitable[None]] | None = None, - ): - self._config = config - self._dispatcher = dispatcher - self._schema_registry = schema_registry - self._event_metrics = event_metrics - self._topics = [f"{settings.KAFKA_TOPIC_PREFIX}{t}" for t in topics] - self._error_callback = error_callback - self.logger = logger - self._consumer: AIOKafkaConsumer | None = None - - async def run(self) -> None: - """Run the consumer. Blocks until stopped. Calls registered handlers.""" - tracer = get_tracer() - - self._consumer = AIOKafkaConsumer( - *self._topics, - bootstrap_servers=self._config.bootstrap_servers, - group_id=self._config.group_id, - client_id=self._config.client_id, - auto_offset_reset=self._config.auto_offset_reset, - enable_auto_commit=self._config.enable_auto_commit, - session_timeout_ms=self._config.session_timeout_ms, - heartbeat_interval_ms=self._config.heartbeat_interval_ms, - max_poll_interval_ms=self._config.max_poll_interval_ms, - request_timeout_ms=self._config.request_timeout_ms, - fetch_min_bytes=self._config.fetch_min_bytes, - fetch_max_wait_ms=self._config.fetch_max_wait_ms, - ) - - await self._consumer.start() - self.logger.info(f"Consumer running for topics: {self._topics}") - - try: - async for msg in self._consumer: - if not msg.value: - continue - - try: - event = await self._schema_registry.deserialize_event(msg.value, msg.topic) - - headers = {k: v.decode() if isinstance(v, bytes) else v for k, v in (msg.headers or [])} - ctx = extract_trace_context(headers) - - with tracer.start_as_current_span( - "kafka.consume", - context=ctx, - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: msg.topic, - EventAttributes.KAFKA_PARTITION: msg.partition, - EventAttributes.KAFKA_OFFSET: msg.offset, - EventAttributes.EVENT_TYPE: event.event_type, - EventAttributes.EVENT_ID: event.event_id, - }, - ): - await self._dispatcher.dispatch(event) - - if not self._config.enable_auto_commit: - await self._consumer.commit() - - self._event_metrics.record_kafka_message_consumed(msg.topic, self._config.group_id) - - except Exception as e: - self.logger.error(f"Error processing message: {e}", exc_info=True) - self._event_metrics.record_kafka_consumption_error( - msg.topic, self._config.group_id, type(e).__name__ - ) - if self._error_callback: - await self._error_callback(e, event) - - finally: - await self._consumer.stop() - self.logger.info("Consumer stopped") - - async def seek_to_beginning(self) -> None: - if self._consumer and (assignment := self._consumer.assignment()): - await self._consumer.seek_to_beginning(*assignment) - - async def seek_to_end(self) -> None: - if self._consumer and (assignment := self._consumer.assignment()): - await self._consumer.seek_to_end(*assignment) - - async def seek_to_offset(self, topic: str, partition: int, offset: int) -> None: - if self._consumer: - self._consumer.seek(TopicPartition(topic, partition), offset) diff --git a/backend/app/events/core/dispatcher.py b/backend/app/events/core/dispatcher.py deleted file mode 100644 index 
bc69a4a3..00000000 --- a/backend/app/events/core/dispatcher.py +++ /dev/null @@ -1,177 +0,0 @@ -import asyncio -import logging -from collections import defaultdict -from collections.abc import Awaitable, Callable -from typing import TypeAlias, TypeVar - -from app.domain.enums.events import EventType -from app.domain.events.typed import DomainEvent -from app.infrastructure.kafka.mappings import get_event_class_for_type - -T = TypeVar("T", bound=DomainEvent) -EventHandler: TypeAlias = Callable[[DomainEvent], Awaitable[None]] - - -class EventDispatcher: - """ - Type-safe event dispatcher with automatic routing. - - This dispatcher eliminates the need for manual if/elif routing by maintaining - a direct mapping from event types to their handlers. - """ - - def __init__(self, logger: logging.Logger) -> None: - self.logger = logger - # Map event types to their handlers - self._handlers: dict[EventType, list[Callable[[DomainEvent], Awaitable[None]]]] = defaultdict(list) - - # Map topics to event types that can appear on them - self._topic_event_types: dict[str, set[type[DomainEvent]]] = defaultdict(set) - - # Metrics per event type - self._event_metrics: dict[EventType, dict[str, int]] = defaultdict( - lambda: {"processed": 0, "failed": 0, "skipped": 0} - ) - - def register( - self, event_type: EventType - ) -> Callable[[Callable[[T], Awaitable[None]]], Callable[[T], Awaitable[None]]]: - """ - Decorator for registering type-safe event handlers. - - Generic over T (any DomainEvent subtype) - accepts handlers with specific - event types while preserving their type signature for callers. - - Usage: - @dispatcher.register(EventType.EXECUTION_REQUESTED) - async def handle_execution(event: ExecutionRequestedEvent) -> None: - # Handler logic here - event is properly typed - """ - - def decorator(handler: Callable[[T], Awaitable[None]]) -> Callable[[T], Awaitable[None]]: - self.logger.info(f"Registering handler '{handler.__name__}' for event type '{event_type}'") - # Safe: dispatch() routes by event_type, guaranteeing correct types at runtime - self._handlers[event_type].append(handler) # type: ignore[arg-type] - return handler - - return decorator - - def register_handler(self, event_type: EventType, handler: EventHandler) -> None: - """ - Direct registration method for handlers. - - Args: - event_type: The event type this handler processes - handler: The async handler function - """ - self.logger.info(f"Registering handler '{handler.__name__}' for event type '{event_type}'") - self._handlers[event_type].append(handler) - - def remove_handler(self, event_type: EventType, handler: EventHandler) -> bool: - """ - Remove a specific handler for an event type. - - Args: - event_type: The event type to remove handler from - handler: The handler function to remove - - Returns: - True if handler was found and removed, False otherwise - """ - if event_type in self._handlers and handler in self._handlers[event_type]: - self._handlers[event_type].remove(handler) - self.logger.info(f"Removed handler '{handler.__name__}' for event type '{event_type}'") - # Clean up empty lists - if not self._handlers[event_type]: - del self._handlers[event_type] - return True - return False - - async def dispatch(self, event: DomainEvent) -> None: - """ - Dispatch an event to all registered handlers for its type. 
- - Args: - event: The event to dispatch - """ - event_type = event.event_type - handlers = self._handlers.get(event_type, []) - self.logger.debug(f"Dispatcher has {len(self._handlers)} event types registered") - self.logger.debug( - f"For event type {event_type}, found {len(handlers)} handlers: {[h.__class__.__name__ for h in handlers]}" - ) - - if not handlers: - self._event_metrics[event_type]["skipped"] += 1 - self.logger.debug(f"No handlers registered for event type {event_type}") - return - - self.logger.debug(f"Dispatching {event_type} to {len(handlers)} handler(s)") - - # Run handlers concurrently for better performance - tasks = [] - for handler in handlers: - tasks.append(self._execute_handler(handler, event)) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Count successes and failures - for result in results: - if isinstance(result, Exception): - self._event_metrics[event_type]["failed"] += 1 - else: - self._event_metrics[event_type]["processed"] += 1 - - async def _execute_handler(self, handler: EventHandler, event: DomainEvent) -> None: - """ - Execute a single handler with error handling. - - Args: - handler: The handler function - event: The event to process - """ - try: - self.logger.debug(f"Executing handler {handler.__class__.__name__} for event {event.event_id}") - await handler(event) - self.logger.debug(f"Handler {handler.__class__.__name__} completed") - except Exception as e: - self.logger.error( - f"Handler '{handler.__class__.__name__}' failed for event {event.event_id}: {e}", exc_info=True - ) - raise - - def get_topics_for_registered_handlers(self) -> set[str]: - """ - Get all topics that have registered handlers. - - Returns: - Set of topic names that should be subscribed to - """ - topics = set() - for event_type in self._handlers.keys(): - # Find event class for this type - event_class = get_event_class_for_type(event_type) - if event_class and hasattr(event_class, "topic"): - topics.add(str(event_class.topic)) - return topics - - def get_metrics(self) -> dict[str, dict[str, int]]: - """Get processing metrics for all event types.""" - return {event_type: metrics for event_type, metrics in self._event_metrics.items()} - - def clear_handlers(self) -> None: - """Clear all registered handlers (useful for testing).""" - self._handlers.clear() - self.logger.info("All event handlers cleared") - - def get_handlers(self, event_type: EventType) -> list[Callable[[DomainEvent], Awaitable[None]]]: - """Get all handlers for a specific event type.""" - return self._handlers.get(event_type, []).copy() - - def get_all_handlers(self) -> dict[EventType, list[Callable[[DomainEvent], Awaitable[None]]]]: - """Get all registered handlers (returns a copy).""" - return {k: v.copy() for k, v in self._handlers.items()} - - def replace_handlers(self, event_type: EventType, handlers: list[Callable[[DomainEvent], Awaitable[None]]]) -> None: - """Replace all handlers for a specific event type.""" - self._handlers[event_type] = handlers diff --git a/backend/app/events/core/producer.py b/backend/app/events/core/producer.py index 69e136ff..d4c7a432 100644 --- a/backend/app/events/core/producer.py +++ b/backend/app/events/core/producer.py @@ -3,7 +3,6 @@ import logging import socket from datetime import datetime, timezone -from typing import Any from aiokafka import AIOKafkaProducer from aiokafka.errors import KafkaError @@ -16,98 +15,36 @@ from app.infrastructure.kafka.mappings import EVENT_TYPE_TO_TOPIC from app.settings import Settings -from .types import 
ProducerMetrics, ProducerState +from .types import ProducerMetrics class UnifiedProducer: - """Fully async Kafka producer using aiokafka.""" + """Kafka producer wrapper with schema registry integration. + + Pure logic class - lifecycle managed by DI provider. + """ def __init__( self, - schema_registry_manager: SchemaRegistryManager, - logger: logging.Logger, + producer: AIOKafkaProducer, + metrics: ProducerMetrics, + schema_registry: SchemaRegistryManager, settings: Settings, + logger: logging.Logger, event_metrics: EventMetrics, ): + self._producer = producer + self._metrics = metrics + self._schema_registry = schema_registry self._settings = settings - self._schema_registry = schema_registry_manager self.logger = logger - self._producer: AIOKafkaProducer | None = None - self._state = ProducerState.STOPPED - self._metrics = ProducerMetrics() self._event_metrics = event_metrics self._topic_prefix = settings.KAFKA_TOPIC_PREFIX - @property - def state(self) -> ProducerState: - return self._state - - @property - def metrics(self) -> ProducerMetrics: - return self._metrics - - @property - def producer(self) -> AIOKafkaProducer | None: - return self._producer - - async def __aenter__(self) -> "UnifiedProducer": - """Start the Kafka producer.""" - self._state = ProducerState.STARTING - self.logger.info("Starting producer...") - - self._producer = AIOKafkaProducer( - bootstrap_servers=self._settings.KAFKA_BOOTSTRAP_SERVERS, - client_id=f"{self._settings.SERVICE_NAME}-producer", - acks="all", - compression_type="gzip", - max_batch_size=16384, - linger_ms=10, - enable_idempotence=True, - ) - - await self._producer.start() - self._state = ProducerState.RUNNING - self.logger.info(f"Producer started: {self._settings.KAFKA_BOOTSTRAP_SERVERS}") - return self - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop the Kafka producer.""" - self._state = ProducerState.STOPPING - self.logger.info("Stopping producer...") - - if self._producer: - await self._producer.stop() - self._producer = None - - self._state = ProducerState.STOPPED - self.logger.info("Producer stopped") - - def get_status(self) -> dict[str, Any]: - return { - "state": self._state, - "config": { - "bootstrap_servers": self._settings.KAFKA_BOOTSTRAP_SERVERS, - "client_id": f"{self._settings.SERVICE_NAME}-producer", - }, - "metrics": { - "messages_sent": self._metrics.messages_sent, - "messages_failed": self._metrics.messages_failed, - "bytes_sent": self._metrics.bytes_sent, - "queue_size": self._metrics.queue_size, - "avg_latency_ms": self._metrics.avg_latency_ms, - "last_error": self._metrics.last_error, - "last_error_time": self._metrics.last_error_time.isoformat() if self._metrics.last_error_time else None, - }, - } - async def produce( self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None ) -> None: """Produce a message to Kafka.""" - if not self._producer: - self.logger.error("Producer not running") - return - try: serialized_value = await self._schema_registry.serialize_event(event_to_produce) topic = f"{self._topic_prefix}{EVENT_TYPE_TO_TOPIC[event_to_produce.event_type]}" @@ -143,10 +80,6 @@ async def send_to_dlq( self, original_event: DomainEvent, original_topic: str, error: Exception, retry_count: int = 0 ) -> None: """Send a failed event to the Dead Letter Queue.""" - if not self._producer: - self.logger.error("Producer not running, cannot send to DLQ") - return - try: # Get producer ID (hostname + task name) current_task = asyncio.current_task() 
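
The same DI-managed lifecycle shape recurs across the providers in this patch (DLQ manager, event store consumer, event bus, notification service): an async-generator provider starts the boundary client, spawns the processing loop as a background task, yields the pure-logic service, and cancels the task and stops the client when the container closes. A minimal sketch of that pattern follows; FooClient and FooService are hypothetical stand-ins, not names from this codebase, and the wiring mirrors the dishka constructs already used above (Provider, Scope.APP, @provide, AsyncIterable factories, make_async_container).

    import asyncio
    import logging
    from collections.abc import AsyncIterable

    from dishka import Provider, Scope, make_async_container, provide


    class FooClient:
        """Stand-in for an external boundary client (e.g. a Kafka consumer)."""

        async def start(self) -> None: ...
        async def stop(self) -> None: ...
        async def poll(self) -> None:
            await asyncio.sleep(0.1)


    class FooService:
        """Pure logic class: no start/stop methods, no owned tasks."""

        def __init__(self, client: FooClient, logger: logging.Logger) -> None:
            self._client = client
            self._logger = logger

        async def process_once(self) -> None:
            await self._client.poll()


    class FooProvider(Provider):
        scope = Scope.APP

        @provide
        def get_logger(self) -> logging.Logger:
            return logging.getLogger("foo")

        @provide
        async def get_foo_service(self, logger: logging.Logger) -> AsyncIterable[FooService]:
            # Setup: the provider owns the boundary client and the background loop.
            client = FooClient()
            await client.start()
            service = FooService(client, logger)

            async def loop() -> None:
                while True:
                    await service.process_once()

            task = asyncio.create_task(loop())

            yield service  # container hands out the ready-to-use service

            # Teardown: runs when the container is closed.
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)
            await client.stop()


    async def main() -> None:
        container = make_async_container(FooProvider())
        service = await container.get(FooService)  # triggers the setup above
        await service.process_once()
        await container.close()  # cancels the loop and stops the client


    if __name__ == "__main__":
        asyncio.run(main())

Because nothing in the logic class owns a connection or a task, tests can construct it directly with fakes, which is what the new tests/helpers/fakes providers rely on.
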
diff --git a/backend/app/events/event_store_consumer.py b/backend/app/events/event_store_consumer.py index 01764c82..d7b0497a 100644 --- a/backend/app/events/event_store_consumer.py +++ b/backend/app/events/event_store_consumer.py @@ -1,45 +1,36 @@ -import asyncio import logging -from aiokafka import AIOKafkaConsumer +from aiokafka import AIOKafkaConsumer, ConsumerRecord, TopicPartition from opentelemetry.trace import SpanKind from app.core.metrics import EventMetrics from app.core.tracing.utils import trace_span -from app.domain.enums.kafka import GroupId, KafkaTopic +from app.domain.enums.kafka import GroupId from app.domain.events.typed import DomainEvent from app.events.event_store import EventStore from app.events.schema.schema_registry import SchemaRegistryManager -from app.settings import Settings class EventStoreConsumer: """Consumes events from Kafka and stores them in MongoDB. - Uses Kafka's native batching via getmany() - no application-level buffering. - Kafka's fetch_max_wait_ms controls batch timing at the protocol level. - - Usage: - consumer = EventStoreConsumer(...) - await consumer.run() # Blocks until cancelled + Pure logic class - lifecycle managed by DI provider. + Uses Kafka's native batching via getmany(). """ def __init__( self, event_store: EventStore, - topics: list[KafkaTopic], + consumer: AIOKafkaConsumer, schema_registry_manager: SchemaRegistryManager, - settings: Settings, logger: logging.Logger, event_metrics: EventMetrics, group_id: GroupId = GroupId.EVENT_STORE_CONSUMER, batch_size: int = 100, batch_timeout_ms: int = 5000, ): - """Store dependencies. All work happens in run().""" self.event_store = event_store - self.topics = topics - self.settings = settings + self.consumer = consumer self.group_id = group_id self.batch_size = batch_size self.batch_timeout_ms = batch_timeout_ms @@ -47,67 +38,35 @@ def __init__( self.event_metrics = event_metrics self.schema_registry_manager = schema_registry_manager - async def run(self) -> None: - """Run the consumer. Blocks until cancelled. + async def process_batch(self) -> None: + """Process a single batch of messages from Kafka. - Creates consumer, starts it, runs batch loop, stops on cancellation. - Uses getmany() which blocks on Kafka's fetch - no polling, no timers. + Called repeatedly by DI provider's background task. 
""" - self.logger.info("Event store consumer starting...") - - topic_strings = [f"{self.settings.KAFKA_TOPIC_PREFIX}{topic}" for topic in self.topics] - - consumer = AIOKafkaConsumer( - *topic_strings, - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{self.group_id}.{self.settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - max_poll_records=self.batch_size, - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - fetch_max_wait_ms=self.batch_timeout_ms, + batch_data: dict[TopicPartition, list[ConsumerRecord[bytes, bytes]]] = await self.consumer.getmany( + timeout_ms=self.batch_timeout_ms, + max_records=self.batch_size, ) - await consumer.start() - self.logger.info(f"Event store consumer initialized for topics: {topic_strings}") - - try: - while True: - # getmany() blocks until Kafka has data OR fetch_max_wait_ms expires - # This is NOT polling - it's async waiting on the network socket - batch_data = await consumer.getmany( - timeout_ms=self.batch_timeout_ms, - max_records=self.batch_size, - ) - - if not batch_data: - continue - - # Deserialize all messages in the batch - events: list[DomainEvent] = [] - for tp, messages in batch_data.items(): - for msg in messages: - try: - event = await self.schema_registry_manager.deserialize_event(msg.value, msg.topic) - events.append(event) - self.event_metrics.record_kafka_message_consumed( - topic=msg.topic, - consumer_group=str(self.group_id), - ) - except Exception as e: - self.logger.error(f"Failed to deserialize message from {tp}: {e}", exc_info=True) - - if events: - await self._store_batch(events) - await consumer.commit() - - except asyncio.CancelledError: - self.logger.info("Event store consumer cancelled") - finally: - await consumer.stop() - self.logger.info("Event store consumer stopped") + if not batch_data: + return + + events: list[DomainEvent] = [] + for tp, messages in batch_data.items(): + for msg in messages: + try: + event = await self.schema_registry_manager.deserialize_event(msg.value, msg.topic) + events.append(event) + self.event_metrics.record_kafka_message_consumed( + topic=msg.topic, + consumer_group=str(self.group_id), + ) + except Exception as e: + self.logger.error(f"Failed to deserialize message from {tp}: {e}", exc_info=True) + + if events: + await self._store_batch(events) + await self.consumer.commit() async def _store_batch(self, events: list[DomainEvent]) -> None: """Store a batch of events.""" diff --git a/backend/app/services/event_bus.py b/backend/app/services/event_bus.py index 6ae60f87..a96e74de 100644 --- a/backend/app/services/event_bus.py +++ b/backend/app/services/event_bus.py @@ -4,7 +4,7 @@ import logging from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Any, Callable, Optional +from typing import Any, Callable from uuid import uuid4 from aiokafka import AIOKafkaConsumer, AIOKafkaProducer @@ -37,8 +37,9 @@ class Subscription: class EventBus: - """ - Distributed event bus for cross-instance communication via Kafka. + """Distributed event bus for cross-instance communication via Kafka. + + Pure logic class - lifecycle managed by DI provider. Publishers send events to Kafka. Subscribers receive events from OTHER instances only - self-published messages are filtered out. 
This design means: @@ -51,79 +52,49 @@ class EventBus: - *.completed - matches all completed events """ - def __init__(self, settings: Settings, logger: logging.Logger, connection_metrics: ConnectionMetrics) -> None: + def __init__( + self, + producer: AIOKafkaProducer, + consumer: AIOKafkaConsumer, + settings: Settings, + logger: logging.Logger, + connection_metrics: ConnectionMetrics, + ) -> None: + self.producer = producer + self.consumer = consumer self.logger = logger self.settings = settings self.metrics = connection_metrics - self.producer: Optional[AIOKafkaProducer] = None - self.consumer: Optional[AIOKafkaConsumer] = None self._subscriptions: dict[str, Subscription] = {} # id -> Subscription self._pattern_index: dict[str, set[str]] = {} # pattern -> set of subscription ids - self._consumer_task: Optional[asyncio.Task[None]] = None self._lock = asyncio.Lock() self._topic = f"{self.settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" self._instance_id = str(uuid4()) # Unique ID for filtering self-published messages - async def __aenter__(self) -> "EventBus": - """Start the event bus with Kafka backing.""" - await self._initialize_kafka() - self._consumer_task = asyncio.create_task(self._kafka_listener()) - self.logger.info("Event bus started with Kafka backing") - return self - - async def _initialize_kafka(self) -> None: - """Initialize Kafka producer and consumer.""" - # Producer setup - self.producer = AIOKafkaProducer( - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - client_id=f"event-bus-producer-{uuid4()}", - linger_ms=10, - max_batch_size=16384, - enable_idempotence=True, - ) + async def process_kafka_message(self) -> None: + """Process a single message from Kafka consumer. - # Consumer setup - self.consumer = AIOKafkaConsumer( - self._topic, - bootstrap_servers=self.settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"event-bus-{uuid4()}", - auto_offset_reset="latest", - enable_auto_commit=True, - client_id=f"event-bus-consumer-{uuid4()}", - session_timeout_ms=self.settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=self.settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=self.settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=self.settings.KAFKA_REQUEST_TIMEOUT_MS, - ) + Called by DI provider's background task. Filters out self-published messages. 
+ """ + try: + msg = await asyncio.wait_for(self.consumer.getone(), timeout=0.1) - # Start both in parallel for faster startup - await asyncio.gather(self.producer.start(), self.consumer.start()) - - async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: - """Stop the event bus and clean up resources.""" - # Cancel consumer task - if self._consumer_task and not self._consumer_task.done(): - self._consumer_task.cancel() - try: - await self._consumer_task - except asyncio.CancelledError: - pass - - # Stop Kafka components - if self.consumer: - await self.consumer.stop() - self.consumer = None - - if self.producer: - await self.producer.stop() - self.producer = None - - # Clear subscriptions - async with self._lock: - self._subscriptions.clear() - self._pattern_index.clear() + # Skip messages from this instance - publisher handles its own state + headers = dict(msg.headers) if msg.headers else {} + source = headers.get("source_instance", b"").decode("utf-8") + if source == self._instance_id: + return - self.logger.info("Event bus stopped") + event_dict = json.loads(msg.value.decode("utf-8")) + event = EventBusEvent.model_validate(event_dict) + await self._distribute_event(event.event_type, event) + + except asyncio.TimeoutError: + pass + except KafkaError as e: + self.logger.error(f"Consumer error: {e}") + except Exception as e: + self.logger.error(f"Error processing Kafka message: {e}") async def publish(self, event_type: str, data: dict[str, Any]) -> None: """ @@ -137,21 +108,19 @@ async def publish(self, event_type: str, data: dict[str, Any]) -> None: data: Event data payload """ event = self._create_event(event_type, data) - - if self.producer: - try: - value = event.model_dump_json().encode("utf-8") - key = event_type.encode("utf-8") if event_type else None - headers = [("source_instance", self._instance_id.encode("utf-8"))] - - await self.producer.send_and_wait( - topic=self._topic, - value=value, - key=key, - headers=headers, - ) - except Exception as e: - self.logger.error(f"Failed to publish to Kafka: {e}") + try: + value = event.model_dump_json().encode("utf-8") + key = event_type.encode("utf-8") if event_type else None + headers = [("source_instance", self._instance_id.encode("utf-8"))] + + await self.producer.send_and_wait( + topic=self._topic, + value=value, + key=key, + headers=headers, + ) + except Exception as e: + self.logger.error(f"Failed to publish to Kafka: {e}") def _create_event(self, event_type: str, data: dict[str, Any]) -> EventBusEvent: """Create a standardized event object.""" @@ -186,6 +155,7 @@ async def subscribe(self, pattern: str, handler: Callable[[EventBusEvent], Any]) # Update metrics self._update_metrics(pattern) + self.metrics.increment_event_bus_subscriptions() self.logger.debug(f"Created subscription {subscription.id} for pattern: {pattern}") return subscription.id @@ -221,6 +191,7 @@ async def _remove_subscription(self, subscription_id: str) -> None: # Update metrics self._update_metrics(pattern) + self.metrics.decrement_event_bus_subscriptions() self.logger.debug(f"Removed subscription {subscription_id} for pattern: {pattern}") @@ -260,55 +231,10 @@ async def _invoke_handler(self, handler: Callable[[EventBusEvent], Any], event: else: await asyncio.to_thread(handler, event) - async def _kafka_listener(self) -> None: - """Listen for Kafka messages from OTHER instances and distribute to local subscribers.""" - if not self.consumer: - return - - self.logger.info("Kafka listener started") - - try: - while True: - try: - msg = await 
asyncio.wait_for(self.consumer.getone(), timeout=0.1) - - # Skip messages from this instance - publisher handles its own state - headers = dict(msg.headers) if msg.headers else {} - source = headers.get("source_instance", b"").decode("utf-8") - if source == self._instance_id: - continue - - try: - event_dict = json.loads(msg.value.decode("utf-8")) - event = EventBusEvent.model_validate(event_dict) - await self._distribute_event(event.event_type, event) - except Exception as e: - self.logger.error(f"Error processing Kafka message: {e}") - - except asyncio.TimeoutError: - continue - except KafkaError as e: - self.logger.error(f"Consumer error: {e}") - continue - - except asyncio.CancelledError: - self.logger.info("Kafka listener cancelled") - def _update_metrics(self, pattern: str) -> None: """Update metrics for a pattern (must be called within lock).""" if self.metrics: count = len(self._pattern_index.get(pattern, set())) self.metrics.update_event_bus_subscribers(count, pattern) - async def get_statistics(self) -> dict[str, Any]: - """Get event bus statistics.""" - async with self._lock: - return { - "patterns": list(self._pattern_index.keys()), - "total_patterns": len(self._pattern_index), - "total_subscriptions": len(self._subscriptions), - "kafka_enabled": self.producer is not None, - "consumer_task_active": self._consumer_task is not None and not self._consumer_task.done(), - } - diff --git a/backend/app/services/k8s_worker/worker_logic.py b/backend/app/services/k8s_worker/worker_logic.py index 848f4c48..1bd2331a 100644 --- a/backend/app/services/k8s_worker/worker_logic.py +++ b/backend/app/services/k8s_worker/worker_logic.py @@ -1,26 +1,22 @@ import asyncio import logging -import os import time from pathlib import Path from typing import Any from kubernetes import client as k8s_client -from kubernetes import config as k8s_config from kubernetes.client.rest import ApiException from app.core.metrics import EventMetrics, ExecutionMetrics, KubernetesMetrics -from app.domain.enums.events import EventType from app.domain.enums.storage import ExecutionErrorType from app.domain.events.typed import ( CreatePodCommandEvent, DeletePodCommandEvent, - DomainEvent, ExecutionFailedEvent, ExecutionStartedEvent, PodCreatedEvent, ) -from app.events.core import EventDispatcher, UnifiedProducer +from app.events.core import UnifiedProducer from app.runtime_registry import RUNTIME_REGISTRY from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.pod_builder import PodBuilder @@ -32,13 +28,12 @@ class K8sWorkerLogic: Business logic for Kubernetes pod management. Handles: - - K8s client initialization - Pod creation from command events - Pod deletion (compensation) - Image pre-puller daemonset management - Event publishing (PodCreated, ExecutionFailed) - This class is stateful and must be instantiated once per worker instance. + All dependencies including K8s clients are injected via constructor. 
""" def __init__( @@ -48,18 +43,21 @@ def __init__( settings: Settings, logger: logging.Logger, event_metrics: EventMetrics, + kubernetes_metrics: KubernetesMetrics, + execution_metrics: ExecutionMetrics, + k8s_v1: k8s_client.CoreV1Api, + k8s_apps_v1: k8s_client.AppsV1Api, ): self._event_metrics = event_metrics self.logger = logger - self.metrics = KubernetesMetrics(settings) - self.execution_metrics = ExecutionMetrics(settings) - self.config = config or K8sWorkerConfig() + self.metrics = kubernetes_metrics + self.execution_metrics = execution_metrics + self.config = config self._settings = settings - # Kubernetes clients (initialized in initialize()) - self.v1: k8s_client.CoreV1Api | None = None - self.networking_v1: k8s_client.NetworkingV1Api | None = None - self.apps_v1: k8s_client.AppsV1Api | None = None + # Kubernetes clients via DI + self.v1 = k8s_v1 + self.apps_v1 = k8s_apps_v1 # Components self.pod_builder = PodBuilder(namespace=self.config.namespace, config=self.config) @@ -69,79 +67,7 @@ def __init__( self._active_creations: set[str] = set() self._creation_semaphore = asyncio.Semaphore(self.config.max_concurrent_pods) - def initialize(self) -> None: - """Initialize Kubernetes clients. Must be called before handling events.""" - if self.config.namespace == "default": - raise RuntimeError( - "KubernetesWorker namespace 'default' is forbidden. Set K8S_NAMESPACE to a dedicated namespace." - ) - - self._initialize_kubernetes_client() - - def register_handlers(self, dispatcher: EventDispatcher) -> None: - """Register event handlers with the dispatcher.""" - dispatcher.register_handler(EventType.CREATE_POD_COMMAND, self._handle_create_pod_command_wrapper) - dispatcher.register_handler(EventType.DELETE_POD_COMMAND, self._handle_delete_pod_command_wrapper) - - def _initialize_kubernetes_client(self) -> None: - """Initialize Kubernetes API clients.""" - try: - # Load config - if self.config.in_cluster: - self.logger.info("Using in-cluster Kubernetes configuration") - k8s_config.load_incluster_config() - elif self.config.kubeconfig_path and os.path.exists(self.config.kubeconfig_path): - self.logger.info(f"Using kubeconfig from {self.config.kubeconfig_path}") - k8s_config.load_kube_config(config_file=self.config.kubeconfig_path) - else: - # Try default locations - if os.path.exists("/var/run/secrets/kubernetes.io/serviceaccount"): - self.logger.info("Detected in-cluster environment") - k8s_config.load_incluster_config() - else: - self.logger.info("Using default kubeconfig") - k8s_config.load_kube_config() - - # Get the default configuration that was set by load_kube_config - configuration = k8s_client.Configuration.get_default_copy() - - # Log the configuration for debugging - self.logger.info(f"Kubernetes API host: {configuration.host}") - self.logger.info(f"SSL CA cert configured: {configuration.ssl_ca_cert is not None}") - - # Create API clients with the configuration - api_client = k8s_client.ApiClient(configuration) - self.v1 = k8s_client.CoreV1Api(api_client) - self.networking_v1 = k8s_client.NetworkingV1Api(api_client) - self.apps_v1 = k8s_client.AppsV1Api(api_client) - - # Test connection with namespace-scoped operation - _ = self.v1.list_namespaced_pod(namespace=self.config.namespace, limit=1) - self.logger.info( - f"Successfully connected to Kubernetes API, namespace {self.config.namespace} accessible" - ) - - except Exception as e: - self.logger.error(f"Failed to initialize Kubernetes client: {e}") - raise - - async def _handle_create_pod_command_wrapper(self, event: DomainEvent) 
-> None: - """Wrapper for handling CreatePodCommandEvent with type safety.""" - assert isinstance(event, CreatePodCommandEvent) - self.logger.info( - f"Processing create_pod_command for execution {event.execution_id} from saga {event.saga_id}" - ) - await self._handle_create_pod_command(event) - - async def _handle_delete_pod_command_wrapper(self, event: DomainEvent) -> None: - """Wrapper for handling DeletePodCommandEvent.""" - assert isinstance(event, DeletePodCommandEvent) - self.logger.info( - f"Processing delete_pod_command for execution {event.execution_id} from saga {event.saga_id}" - ) - await self._handle_delete_pod_command(event) - - async def _handle_create_pod_command(self, command: CreatePodCommandEvent) -> None: + async def handle_create_pod_command(self, command: CreatePodCommandEvent) -> None: """Handle create pod command from saga orchestrator.""" execution_id = command.execution_id @@ -153,7 +79,7 @@ async def _handle_create_pod_command(self, command: CreatePodCommandEvent) -> No # Create pod asynchronously asyncio.create_task(self._create_pod_for_execution(command)) - async def _handle_delete_pod_command(self, command: DeletePodCommandEvent) -> None: + async def handle_delete_pod_command(self, command: DeletePodCommandEvent) -> None: """Handle delete pod command from saga orchestrator (compensation).""" execution_id = command.execution_id self.logger.info(f"Deleting pod for execution {execution_id} due to: {command.reason}") @@ -161,22 +87,20 @@ async def _handle_delete_pod_command(self, command: DeletePodCommandEvent) -> No try: # Delete the pod pod_name = f"executor-{execution_id}" - if self.v1: - await asyncio.to_thread( - self.v1.delete_namespaced_pod, - name=pod_name, - namespace=self.config.namespace, - grace_period_seconds=30, - ) - self.logger.info(f"Successfully deleted pod {pod_name}") + await asyncio.to_thread( + self.v1.delete_namespaced_pod, + name=pod_name, + namespace=self.config.namespace, + grace_period_seconds=30, + ) + self.logger.info(f"Successfully deleted pod {pod_name}") # Delete associated ConfigMap configmap_name = f"script-{execution_id}" - if self.v1: - await asyncio.to_thread( - self.v1.delete_namespaced_config_map, name=configmap_name, namespace=self.config.namespace - ) - self.logger.info(f"Successfully deleted ConfigMap {configmap_name}") + await asyncio.to_thread( + self.v1.delete_namespaced_config_map, name=configmap_name, namespace=self.config.namespace + ) + self.logger.info(f"Successfully deleted ConfigMap {configmap_name}") # NetworkPolicy cleanup is managed via a static cluster policy; no per-execution NP deletion @@ -263,8 +187,6 @@ async def _get_entrypoint_script(self) -> str: async def _create_config_map(self, config_map: k8s_client.V1ConfigMap) -> None: """Create ConfigMap in Kubernetes.""" - if not self.v1: - raise RuntimeError("Kubernetes client not initialized") try: await asyncio.to_thread( self.v1.create_namespaced_config_map, namespace=self.config.namespace, body=config_map @@ -281,8 +203,6 @@ async def _create_config_map(self, config_map: k8s_client.V1ConfigMap) -> None: async def _create_pod(self, pod: k8s_client.V1Pod) -> None: """Create Pod in Kubernetes.""" - if not self.v1: - raise RuntimeError("Kubernetes client not initialized") try: await asyncio.to_thread(self.v1.create_namespaced_pod, namespace=self.config.namespace, body=pod) self.logger.debug(f"Created Pod {pod.metadata.name}") @@ -302,9 +222,6 @@ async def _publish_execution_started(self, command: CreatePodCommandEvent, pod: container_id=None, # Will be set 
when container actually starts metadata=command.metadata, ) - if not self.producer: - self.logger.error("Producer not initialized") - return await self.producer.produce(event_to_produce=event) async def _publish_pod_created(self, command: CreatePodCommandEvent, pod: k8s_client.V1Pod) -> None: @@ -315,10 +232,6 @@ async def _publish_pod_created(self, command: CreatePodCommandEvent, pod: k8s_cl namespace=pod.metadata.namespace, metadata=command.metadata, ) - - if not self.producer: - self.logger.error("Producer not initialized") - return await self.producer.produce(event_to_produce=event) async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, error: str) -> None: @@ -332,10 +245,6 @@ async def _publish_pod_creation_failed(self, command: CreatePodCommandEvent, err metadata=command.metadata, error_message=str(error), ) - - if not self.producer: - self.logger.error("Producer not initialized") - return await self.producer.produce(event_to_produce=event) async def ensure_daemonset_task(self) -> None: @@ -347,10 +256,6 @@ async def ensure_daemonset_task(self) -> None: async def ensure_image_pre_puller_daemonset(self) -> None: """Ensure the runtime image pre-puller DaemonSet exists.""" - if not self.apps_v1: - self.logger.warning("Kubernetes AppsV1Api client not initialized. Skipping DaemonSet creation.") - return - daemonset_name = "runtime-image-pre-puller" namespace = self.config.namespace await asyncio.sleep(5) @@ -429,13 +334,3 @@ async def wait_for_active_creations(self, timeout: float = 30) -> None: if self._active_creations: self.logger.warning(f"Timeout, {len(self._active_creations)} pod creations still active") - async def get_status(self) -> dict[str, Any]: - """Get worker status.""" - return { - "active_creations": len(self._active_creations), - "config": { - "namespace": self.config.namespace, - "max_concurrent_pods": self.config.max_concurrent_pods, - "enable_network_policies": True, - }, - } diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index ffb3e10c..45863b75 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -126,20 +126,30 @@ def __init__( extra={"repository": type(notification_repository).__name__}, ) - async def run(self) -> None: - """Run background tasks. Blocks until cancelled. + async def process_pending_batch(self) -> None: + """Process one batch of pending notifications. - Runs: - - Pending notification processor (retries failed deliveries) - - Old notification cleanup (daily) + Called periodically by DI provider's background task. """ - self.logger.info("Starting NotificationService background tasks...") try: - async with asyncio.TaskGroup() as tg: - tg.create_task(self._process_pending_notifications()) - tg.create_task(self._cleanup_old_notifications()) - except* asyncio.CancelledError: - self.logger.info("NotificationService background tasks cancelled") + notifications = await self.repository.find_pending_notifications( + batch_size=self.settings.NOTIF_PENDING_BATCH_SIZE + ) + for notification in notifications: + await self._deliver_notification(notification) + except Exception as e: + self.logger.error(f"Error processing pending notifications: {e}") + + async def cleanup_old(self) -> None: + """Cleanup old notifications once. + + Called periodically by DI provider's background task. 
+ """ + try: + deleted_count = await self.repository.cleanup_old_notifications(self.settings.NOTIF_OLD_DAYS) + self.logger.info(f"Cleaned up {deleted_count} old notifications") + except Exception as e: + self.logger.error(f"Error cleaning up old notifications: {e}") async def create_notification( self, @@ -432,46 +442,6 @@ def _get_slack_color(self, priority: NotificationSeverity) -> str: NotificationSeverity.URGENT: "#990000", # Dark Red }.get(priority, "#808080") # Default gray - async def _process_pending_notifications(self) -> None: - """Process pending notifications in background.""" - while True: - try: - # Find pending notifications - notifications = await self.repository.find_pending_notifications( - batch_size=self.settings.NOTIF_PENDING_BATCH_SIZE - ) - - # Process each notification - for notification in notifications: - await self._deliver_notification(notification) - - # Sleep between batches - await asyncio.sleep(5) - - except asyncio.CancelledError: - raise - except Exception as e: - self.logger.error(f"Error processing pending notifications: {e}") - await asyncio.sleep(10) - - async def _cleanup_old_notifications(self) -> None: - """Cleanup old notifications periodically.""" - while True: - try: - # Run cleanup once per day - await asyncio.sleep(86400) # 24 hours - - # Delete old notifications - deleted_count = await self.repository.cleanup_old_notifications(self.settings.NOTIF_OLD_DAYS) - - self.logger.info(f"Cleaned up {deleted_count} old notifications") - - except asyncio.CancelledError: - raise - except Exception as e: - self.logger.error(f"Error cleaning up old notifications: {e}") - await asyncio.sleep(5) - async def mark_as_read(self, user_id: str, notification_id: str) -> bool: """Mark notification as read.""" success = await self.repository.mark_as_read(notification_id, user_id) diff --git a/backend/app/services/pod_monitor/monitor.py b/backend/app/services/pod_monitor/monitor.py index c43d53ad..6cd6ad36 100644 --- a/backend/app/services/pod_monitor/monitor.py +++ b/backend/app/services/pod_monitor/monitor.py @@ -6,9 +6,9 @@ from typing import Any from kubernetes import client as k8s_client +from kubernetes import watch as k8s_watch from kubernetes.client.rest import ApiException -from app.core.k8s_clients import K8sClients from app.core.metrics import KubernetesMetrics from app.core.utils import StringEnum from app.domain.events.typed import DomainEvent @@ -21,7 +21,6 @@ type ResourceVersion = str type EventType = str type KubeEvent = dict[str, Any] -type StatusDict = dict[str, Any] # Constants MAX_BACKOFF_SECONDS: int = 300 # 5 minutes @@ -82,7 +81,8 @@ def __init__( config: PodMonitorConfig, kafka_event_service: KafkaEventService, logger: logging.Logger, - k8s_clients: K8sClients, + k8s_v1: k8s_client.CoreV1Api, + k8s_watch: k8s_watch.Watch, event_mapper: PodEventMapper, kubernetes_metrics: KubernetesMetrics, ) -> None: @@ -91,9 +91,8 @@ def __init__( self.config = config # Kubernetes clients - self._clients = k8s_clients - self._v1 = k8s_clients.v1 - self._watch = k8s_clients.watch + self._v1 = k8s_v1 + self._watch = k8s_watch # Components self._event_mapper = event_mapper @@ -342,15 +341,3 @@ async def _reconcile(self) -> None: self.logger.error(f"Failed to reconcile state: {e}", exc_info=True) self._metrics.record_pod_monitor_reconciliation_run("failed") - async def get_status(self) -> StatusDict: - """Get monitor status.""" - return { - "tracked_pods": len(self._tracked_pods), - "reconnect_attempts": self._reconnect_attempts, - "last_resource_version": 
self._last_resource_version, - "config": { - "namespace": self.config.namespace, - "label_selector": self.config.label_selector, - "enable_reconciliation": self.config.enable_state_reconciliation, - }, - } diff --git a/backend/app/services/result_processor/processor_logic.py b/backend/app/services/result_processor/processor_logic.py index 1a61164e..d0b887ee 100644 --- a/backend/app/services/result_processor/processor_logic.py +++ b/backend/app/services/result_processor/processor_logic.py @@ -2,12 +2,10 @@ from app.core.metrics import ExecutionMetrics from app.db.repositories.execution_repository import ExecutionRepository -from app.domain.enums.events import EventType from app.domain.enums.execution import ExecutionStatus from app.domain.enums.kafka import GroupId from app.domain.enums.storage import ExecutionErrorType, StorageType from app.domain.events.typed import ( - DomainEvent, EventMetadata, ExecutionCompletedEvent, ExecutionFailedEvent, @@ -16,19 +14,20 @@ ResultStoredEvent, ) from app.domain.execution import ExecutionNotFoundError, ExecutionResultDomain, ResourceUsageDomainAdapter -from app.events.core import EventDispatcher, UnifiedProducer +from app.events.core import UnifiedProducer from app.settings import Settings class ProcessorLogic: - """ - Business logic for result processing. + """Business logic for result processing. Handles: - Processing execution completion events - Storing results in database - Publishing ResultStored/ResultFailed events - Recording metrics + + Used by the result processor FastStream worker (run_result_processor.py). """ def __init__( @@ -45,26 +44,6 @@ def __init__( self._metrics = execution_metrics self.logger = logger - def register_handlers(self, dispatcher: EventDispatcher) -> None: - """Register event handlers with the dispatcher.""" - dispatcher.register_handler(EventType.EXECUTION_COMPLETED, self._handle_completed_wrapper) - dispatcher.register_handler(EventType.EXECUTION_FAILED, self._handle_failed_wrapper) - dispatcher.register_handler(EventType.EXECUTION_TIMEOUT, self._handle_timeout_wrapper) - - # Wrappers accepting DomainEvent to satisfy dispatcher typing - - async def _handle_completed_wrapper(self, event: DomainEvent) -> None: - assert isinstance(event, ExecutionCompletedEvent) - await self._handle_completed(event) - - async def _handle_failed_wrapper(self, event: DomainEvent) -> None: - assert isinstance(event, ExecutionFailedEvent) - await self._handle_failed(event) - - async def _handle_timeout_wrapper(self, event: DomainEvent) -> None: - assert isinstance(event, ExecutionTimeoutEvent) - await self._handle_timeout(event) - async def _handle_completed(self, event: ExecutionCompletedEvent) -> None: """Handle execution completed event.""" exec_obj = await self._execution_repo.get_execution(event.execution_id) diff --git a/backend/app/services/saga/saga_logic.py b/backend/app/services/saga/saga_logic.py index ed1476ce..931f6f76 100644 --- a/backend/app/services/saga/saga_logic.py +++ b/backend/app/services/saga/saga_logic.py @@ -15,7 +15,7 @@ from app.domain.enums.saga import SagaState from app.domain.events.typed import DomainEvent, EventMetadata, SagaCancelledEvent from app.domain.saga.models import Saga, SagaConfig -from app.events.core import EventDispatcher, UnifiedProducer +from app.events.core import UnifiedProducer from app.infrastructure.kafka.mappings import get_topic_for_event from .base_saga import BaseSaga @@ -82,13 +82,6 @@ def get_trigger_event_types(self) -> set[EventType]: event_types.update(trigger_event_types) return 
event_types - def register_handlers(self, dispatcher: EventDispatcher) -> None: - """Register event handlers with the dispatcher.""" - event_types = self.get_trigger_event_types() - for event_type in event_types: - dispatcher.register_handler(event_type, self.handle_event) - self.logger.info(f"Registered handler for event type: {event_type}") - async def handle_event(self, event: DomainEvent) -> None: """Handle incoming event.""" self.logger.info(f"Saga orchestrator handling event: type={event.event_type}, id={event.event_id}") diff --git a/backend/app/services/sse/event_router.py b/backend/app/services/sse/event_router.py index c2c6ef81..276ab3a9 100644 --- a/backend/app/services/sse/event_router.py +++ b/backend/app/services/sse/event_router.py @@ -4,7 +4,6 @@ from app.domain.enums.events import EventType from app.domain.events.typed import DomainEvent -from app.events.core import EventDispatcher from app.services.sse.redis_bus import SSERedisBus # Events that should be routed to SSE clients @@ -33,6 +32,8 @@ class SSEEventRouter: Stateless service that extracts execution_id from events and publishes them to Redis via SSERedisBus. Each execution_id has its own channel. + + Used by the SSE bridge worker (run_sse_bridge.py) via FastStream. """ def __init__(self, sse_bus: SSERedisBus, logger: logging.Logger) -> None: @@ -56,8 +57,3 @@ async def route_event(self, event: DomainEvent) -> None: f"Failed to publish {event.event_type} to Redis for {execution_id}: {e}", exc_info=True, ) - - def register_handlers(self, dispatcher: EventDispatcher) -> None: - """Register routing handlers for all relevant event types.""" - for event_type in SSE_RELEVANT_EVENTS: - dispatcher.register_handler(event_type, self.route_event) diff --git a/backend/app/services/sse/sse_service.py b/backend/app/services/sse/sse_service.py index cf1cfcdf..43b6111b 100644 --- a/backend/app/services/sse/sse_service.py +++ b/backend/app/services/sse/sse_service.py @@ -33,7 +33,6 @@ class SSEService: def __init__( self, repository: SSERepository, - num_consumers: int, sse_bus: SSERedisBus, connection_registry: SSEConnectionRegistry, settings: Settings, @@ -41,7 +40,6 @@ def __init__( connection_metrics: ConnectionMetrics, ) -> None: self.repository = repository - self._num_consumers = num_consumers self.sse_bus = sse_bus self.connection_registry = connection_registry self.settings = settings @@ -243,7 +241,7 @@ async def get_health_status(self) -> SSEHealthDomain: kafka_enabled=True, active_connections=active_connections, active_executions=self.connection_registry.get_execution_count(), - active_consumers=self._num_consumers, + active_consumers=0, # Consumers run in separate SSE bridge worker max_connections_per_user=5, shutdown=ShutdownStatus( phase="ready", diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 7377b600..05c355ab 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -137,9 +137,11 @@ packages = ["app", "workers"] [dependency-groups] dev = [ "coverage==7.13.0", + "fakeredis>=2.33.0", "hypothesis==6.103.4", "iniconfig==2.0.0", "matplotlib==3.10.8", + "mongomock-motor>=0.0.36", "mypy==1.17.1", "mypy_extensions==1.1.0", "pipdeptree==2.23.4", diff --git a/backend/tests/e2e/test_k8s_worker_create_pod.py b/backend/tests/e2e/test_k8s_worker_create_pod.py index 5d95a931..331de874 100644 --- a/backend/tests/e2e/test_k8s_worker_create_pod.py +++ b/backend/tests/e2e/test_k8s_worker_create_pod.py @@ -2,13 +2,14 @@ import uuid import pytest -from app.core.metrics import EventMetrics +from 
app.core.metrics import EventMetrics, ExecutionMetrics, KubernetesMetrics from app.domain.events.typed import CreatePodCommandEvent, EventMetadata from app.events.core import UnifiedProducer from app.services.k8s_worker.config import K8sWorkerConfig from app.services.k8s_worker.worker_logic import K8sWorkerLogic from app.settings import Settings from dishka import AsyncContainer +from kubernetes import client as k8s_client from kubernetes.client.rest import ApiException pytestmark = [pytest.mark.e2e, pytest.mark.k8s] @@ -22,8 +23,15 @@ async def test_worker_creates_configmap_and_pod( ) -> None: ns = test_settings.K8S_NAMESPACE + if ns == "default": + pytest.fail("K8S_NAMESPACE is set to 'default', which is forbidden") + producer: UnifiedProducer = await scope.get(UnifiedProducer) event_metrics: EventMetrics = await scope.get(EventMetrics) + kubernetes_metrics: KubernetesMetrics = await scope.get(KubernetesMetrics) + execution_metrics: ExecutionMetrics = await scope.get(ExecutionMetrics) + k8s_v1: k8s_client.CoreV1Api = await scope.get(k8s_client.CoreV1Api) + k8s_apps_v1: k8s_client.AppsV1Api = await scope.get(k8s_client.AppsV1Api) cfg = K8sWorkerConfig(namespace=ns, max_concurrent_pods=1) logic = K8sWorkerLogic( @@ -32,19 +40,12 @@ async def test_worker_creates_configmap_and_pod( settings=test_settings, logger=_test_logger, event_metrics=event_metrics, + kubernetes_metrics=kubernetes_metrics, + execution_metrics=execution_metrics, + k8s_v1=k8s_v1, + k8s_apps_v1=k8s_apps_v1, ) - # Initialize k8s clients using logic's own method - try: - logic.initialize() - except RuntimeError as e: - if "default" in str(e): - pytest.skip("K8S_NAMESPACE is set to 'default', which is forbidden") - raise - - if logic.v1 is None: - pytest.skip("Kubernetes cluster not available") - exec_id = uuid.uuid4().hex[:8] cmd = CreatePodCommandEvent( saga_id=uuid.uuid4().hex, diff --git a/backend/tests/helpers/fakes/__init__.py b/backend/tests/helpers/fakes/__init__.py new file mode 100644 index 00000000..c6b1eb30 --- /dev/null +++ b/backend/tests/helpers/fakes/__init__.py @@ -0,0 +1,11 @@ +"""Fake implementations for external boundary clients used in tests.""" + +from .providers import FakeBoundaryClientProvider, FakeDatabaseProvider, FakeSchemaRegistryProvider +from .schema_registry import FakeSchemaRegistryManager + +__all__ = [ + "FakeBoundaryClientProvider", + "FakeDatabaseProvider", + "FakeSchemaRegistryManager", + "FakeSchemaRegistryProvider", +] diff --git a/backend/tests/helpers/fakes/kafka.py b/backend/tests/helpers/fakes/kafka.py new file mode 100644 index 00000000..23b7df30 --- /dev/null +++ b/backend/tests/helpers/fakes/kafka.py @@ -0,0 +1,78 @@ +"""Minimal fake Kafka clients for DI container testing.""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class FakeAIOKafkaProducer: + """Minimal fake for AIOKafkaProducer - satisfies DI container.""" + + bootstrap_servers: str = "localhost:9092" + sent_messages: list[tuple[str, bytes, list[tuple[str, bytes]]]] = field(default_factory=list) + _started: bool = False + + async def start(self) -> None: + self._started = True + + async def stop(self) -> None: + self._started = False + + async def send_and_wait( + self, + topic: str, + value: bytes, + headers: list[tuple[str, bytes]] | None = None, + **kwargs: Any, + ) -> None: + self.sent_messages.append((topic, value, headers or [])) + + async def __aenter__(self) -> "FakeAIOKafkaProducer": + await self.start() + return self + + async def __aexit__(self, *args: Any) -> None: + 
await self.stop() + + +@dataclass +class FakeAIOKafkaConsumer: + """Minimal fake for AIOKafkaConsumer - satisfies DI container.""" + + bootstrap_servers: str = "localhost:9092" + group_id: str = "test-group" + _topics: list[str] = field(default_factory=list) + _started: bool = False + _messages: list[Any] = field(default_factory=list) + + def subscribe(self, topics: list[str]) -> None: + self._topics = topics + + async def start(self) -> None: + self._started = True + + async def stop(self) -> None: + self._started = False + + async def commit(self) -> None: + pass + + def assignment(self) -> set[Any]: + return set() + + async def seek_to_beginning(self, *partitions: Any) -> None: + pass + + async def seek_to_end(self, *partitions: Any) -> None: + pass + + def seek(self, partition: Any, offset: int) -> None: + pass + + def __aiter__(self) -> "FakeAIOKafkaConsumer": + return self + + async def __anext__(self) -> Any: + if self._messages: + return self._messages.pop(0) + raise StopAsyncIteration diff --git a/backend/tests/helpers/fakes/kubernetes.py b/backend/tests/helpers/fakes/kubernetes.py new file mode 100644 index 00000000..64ad99ad --- /dev/null +++ b/backend/tests/helpers/fakes/kubernetes.py @@ -0,0 +1,54 @@ +"""Minimal fake Kubernetes clients for DI container testing.""" + +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import MagicMock + + +@dataclass +class FakeK8sApiClient: + """Minimal fake for k8s ApiClient - satisfies DI container.""" + + configuration: Any = None + + def close(self) -> None: + pass + + +@dataclass +class FakeK8sCoreV1Api: + """Minimal fake for k8s CoreV1Api - satisfies DI container.""" + + api_client: FakeK8sApiClient = field(default_factory=FakeK8sApiClient) + + def read_namespaced_pod(self, name: str, namespace: str) -> Any: + return MagicMock() + + def list_namespaced_pod(self, namespace: str, **kwargs: Any) -> Any: + return MagicMock(items=[]) + + def delete_namespaced_pod(self, name: str, namespace: str, **kwargs: Any) -> Any: + return MagicMock() + + def read_namespaced_pod_log(self, name: str, namespace: str, **kwargs: Any) -> str: + return "" + + +@dataclass +class FakeK8sAppsV1Api: + """Minimal fake for k8s AppsV1Api - satisfies DI container.""" + + api_client: FakeK8sApiClient = field(default_factory=FakeK8sApiClient) + + +@dataclass +class FakeK8sWatch: + """Minimal fake for k8s Watch - satisfies DI container.""" + + _events: list[dict[str, Any]] = field(default_factory=list) + + def stream(self, func: Any, *args: Any, **kwargs: Any) -> Any: + return iter(self._events) + + def stop(self) -> None: + pass diff --git a/backend/tests/helpers/fakes/providers.py b/backend/tests/helpers/fakes/providers.py new file mode 100644 index 00000000..14718873 --- /dev/null +++ b/backend/tests/helpers/fakes/providers.py @@ -0,0 +1,84 @@ +"""Fake providers for unit testing with DI container.""" + +import logging +from typing import Any + +import fakeredis.aioredis +import redis.asyncio as redis +from aiokafka import AIOKafkaProducer +from app.core.database_context import Database +from app.events.schema.schema_registry import SchemaRegistryManager +from app.settings import Settings +from dishka import Provider, Scope, provide +from kubernetes import client as k8s_client +from kubernetes import watch as k8s_watch +from mongomock_motor import AsyncMongoMockClient + +from tests.helpers.fakes.kafka import FakeAIOKafkaProducer +from tests.helpers.fakes.kubernetes import ( + FakeK8sApiClient, + FakeK8sAppsV1Api, + FakeK8sCoreV1Api, + 
FakeK8sWatch, +) +from tests.helpers.fakes.schema_registry import FakeSchemaRegistryManager + + +class FakeBoundaryClientProvider(Provider): + """Fake boundary clients for unit testing. + + Overrides BoundaryClientProvider - provides fake implementations + for Redis, Kafka, and K8s clients so tests can run without external deps. + """ + + scope = Scope.APP + + @provide + def get_redis_client(self, logger: logging.Logger) -> redis.Redis: + logger.info("Using FakeRedis for testing") + return fakeredis.aioredis.FakeRedis(decode_responses=False) + + @provide + def get_kafka_producer_client(self) -> AIOKafkaProducer: + return FakeAIOKafkaProducer() + + @provide + def get_k8s_api_client(self, logger: logging.Logger) -> k8s_client.ApiClient: + logger.info("Using FakeK8sApiClient for testing") + return FakeK8sApiClient() + + @provide + def get_k8s_core_v1_api(self, api_client: k8s_client.ApiClient) -> k8s_client.CoreV1Api: + return FakeK8sCoreV1Api(api_client=api_client) + + @provide + def get_k8s_apps_v1_api(self, api_client: k8s_client.ApiClient) -> k8s_client.AppsV1Api: + return FakeK8sAppsV1Api(api_client=api_client) + + @provide + def get_k8s_watch(self) -> k8s_watch.Watch: + return FakeK8sWatch() + + +class FakeDatabaseProvider(Provider): + """Fake MongoDB database for unit testing using mongomock-motor.""" + + scope = Scope.APP + + @provide + def get_database(self, settings: Settings, logger: logging.Logger) -> Database: + logger.info(f"Using AsyncMongoMockClient for testing: {settings.DATABASE_NAME}") + client: AsyncMongoMockClient[dict[str, Any]] = AsyncMongoMockClient() + # mongomock_motor returns AsyncIOMotorDatabase which is API-compatible with AsyncDatabase + return client[settings.DATABASE_NAME] # type: ignore[return-value] + + +class FakeSchemaRegistryProvider(Provider): + """Fake Schema Registry provider - must be placed after EventProvider to override.""" + + scope = Scope.APP + + @provide + def get_schema_registry(self, logger: logging.Logger) -> SchemaRegistryManager: + logger.info("Using FakeSchemaRegistryManager for testing") + return FakeSchemaRegistryManager(logger=logger) # type: ignore[return-value] diff --git a/backend/tests/helpers/fakes/schema_registry.py b/backend/tests/helpers/fakes/schema_registry.py new file mode 100644 index 00000000..644e5804 --- /dev/null +++ b/backend/tests/helpers/fakes/schema_registry.py @@ -0,0 +1,116 @@ +"""Fake Schema Registry Manager for unit testing.""" + +import io +import logging +import struct +from functools import lru_cache +from typing import Any, get_args, get_origin + +import fastavro +from app.domain.enums.events import EventType +from app.domain.events.typed import DomainEvent +from fastavro.types import Schema + +MAGIC_BYTE = b"\x00" + + +@lru_cache(maxsize=1) +def _get_all_event_classes() -> list[type[DomainEvent]]: + """Get all concrete event classes from DomainEvent union.""" + union_type = get_args(DomainEvent)[0] # Annotated[Union[...], Discriminator] -> Union + return list(get_args(union_type)) if get_origin(union_type) else [union_type] + + +class FakeSchemaRegistryManager: + """Fake schema registry manager for unit tests. + + Serializes/deserializes using fastavro without network calls. 
+ """ + + def __init__(self, settings: Any = None, logger: logging.Logger | None = None): + self.logger = logger or logging.getLogger(__name__) + self._schema_id_counter = 0 + self._schema_id_cache: dict[type[DomainEvent], int] = {} + self._id_to_class_cache: dict[int, type[DomainEvent]] = {} + self._parsed_schemas: dict[type[DomainEvent], Schema] = {} + + def _get_schema_id(self, event_class: type[DomainEvent]) -> int: + """Get or assign schema ID for event class.""" + if event_class not in self._schema_id_cache: + self._schema_id_counter += 1 + self._schema_id_cache[event_class] = self._schema_id_counter + self._id_to_class_cache[self._schema_id_counter] = event_class + return self._schema_id_cache[event_class] + + def _get_parsed_schema(self, event_class: type[DomainEvent]) -> Schema: + """Get or parse Avro schema for event class.""" + if event_class not in self._parsed_schemas: + avro_schema = event_class.avro_schema_to_python() + self._parsed_schemas[event_class] = fastavro.parse_schema(avro_schema) + return self._parsed_schemas[event_class] + + async def serialize_event(self, event: DomainEvent) -> bytes: + """Serialize event to Confluent wire format.""" + event_class = event.__class__ + schema_id = self._get_schema_id(event_class) + parsed_schema = self._get_parsed_schema(event_class) + + # Prepare payload + payload: dict[str, Any] = event.model_dump(mode="python", by_alias=False, exclude_unset=False) + payload.pop("event_type", None) + + # Convert datetime to microseconds for Avro logical type + if "timestamp" in payload and payload["timestamp"] is not None: + payload["timestamp"] = int(payload["timestamp"].timestamp() * 1_000_000) + + # Serialize with fastavro + buffer = io.BytesIO() + fastavro.schemaless_writer(buffer, parsed_schema, payload) + avro_bytes = buffer.getvalue() + + # Confluent wire format: [0x00][4-byte schema id BE][Avro binary] + return MAGIC_BYTE + struct.pack(">I", schema_id) + avro_bytes + + async def deserialize_event(self, data: bytes, topic: str) -> DomainEvent: + """Deserialize from Confluent wire format to DomainEvent.""" + if not data or len(data) < 5: + raise ValueError("Invalid message: too short for wire format") + if data[0:1] != MAGIC_BYTE: + raise ValueError(f"Unknown magic byte: {data[0]:#x}") + + schema_id = struct.unpack(">I", data[1:5])[0] + event_class = self._id_to_class_cache.get(schema_id) + if not event_class: + raise ValueError(f"Unknown schema ID: {schema_id}") + + parsed_schema = self._get_parsed_schema(event_class) + buffer = io.BytesIO(data[5:]) + obj = fastavro.schemaless_reader(buffer, parsed_schema, parsed_schema) + + if not isinstance(obj, dict): + raise ValueError(f"Deserialization returned {type(obj)}, expected dict") + + # Restore event_type if missing + if (f := event_class.model_fields.get("event_type")) and f.default and "event_type" not in obj: + obj["event_type"] = f.default + + return event_class.model_validate(obj) + + def deserialize_json(self, data: dict[str, Any]) -> DomainEvent: + """Deserialize JSON data to DomainEvent using event_type field.""" + if not (event_type_str := data.get("event_type")): + raise ValueError("Missing event_type in event data") + mapping = {cls.model_fields["event_type"].default: cls for cls in _get_all_event_classes()} + if not (event_class := mapping.get(EventType(event_type_str))): + raise ValueError(f"No event class found for event type: {event_type_str}") + return event_class.model_validate(data) + + async def set_compatibility(self, subject: str, mode: str) -> None: + """No-op for 
fake.""" + pass + + async def initialize_schemas(self) -> None: + """Pre-register all schemas.""" + for event_class in _get_all_event_classes(): + self._get_schema_id(event_class) + self._get_parsed_schema(event_class) diff --git a/backend/tests/integration/events/test_consume_roundtrip.py b/backend/tests/integration/events/test_consume_roundtrip.py deleted file mode 100644 index ff8d5554..00000000 --- a/backend/tests/integration/events/test_consume_roundtrip.py +++ /dev/null @@ -1,70 +0,0 @@ -import asyncio -import logging - -import pytest -from app.core.metrics import EventMetrics -from app.domain.enums.events import EventType -from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, UnifiedConsumer, UnifiedProducer -from app.events.core.dispatcher import EventDispatcher -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.settings import Settings -from dishka import AsyncContainer - -from tests.helpers import make_execution_requested_event - -# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers -# instantiate Consumer() objects simultaneously. Serial execution prevents this. -pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.xdist_group("kafka_consumers")] - -_test_logger = logging.getLogger("test.events.consume_roundtrip") - - -@pytest.mark.asyncio -async def test_produce_consume_roundtrip( - scope: AsyncContainer, - schema_registry: SchemaRegistryManager, - event_metrics: EventMetrics, - consumer_config: ConsumerConfig, - test_settings: Settings, -) -> None: - # Ensure schemas are registered - await initialize_event_schemas(schema_registry) - - # Real producer from DI - producer: UnifiedProducer = await scope.get(UnifiedProducer) - - # Build a consumer that handles EXECUTION_REQUESTED - dispatcher = EventDispatcher(logger=_test_logger) - received = asyncio.Event() - - @dispatcher.register(EventType.EXECUTION_REQUESTED) - async def _handle(_event: DomainEvent) -> None: - received.set() - - consumer = UnifiedConsumer( - consumer_config, - dispatcher, - schema_registry=schema_registry, - settings=test_settings, - logger=_test_logger, - event_metrics=event_metrics, - topics=[KafkaTopic.EXECUTION_EVENTS], - ) - - # Start consumer as background task - consumer_task = asyncio.create_task(consumer.run()) - - try: - # Produce a request event - execution_id = f"exec-{consumer_config.group_id}" - evt = make_execution_requested_event(execution_id=execution_id) - await producer.produce(evt, key=execution_id) - - # Wait for the handler to be called - await asyncio.wait_for(received.wait(), timeout=10.0) - finally: - consumer_task.cancel() - with pytest.raises(asyncio.CancelledError): - await consumer_task diff --git a/backend/tests/integration/events/test_consumer_lifecycle.py b/backend/tests/integration/events/test_consumer_lifecycle.py deleted file mode 100644 index 5ee140bf..00000000 --- a/backend/tests/integration/events/test_consumer_lifecycle.py +++ /dev/null @@ -1,56 +0,0 @@ -import asyncio -import logging - -import pytest -from app.core.metrics import EventMetrics -from app.domain.enums.kafka import KafkaTopic -from app.events.core import ConsumerConfig, EventDispatcher, UnifiedConsumer -from app.events.schema.schema_registry import SchemaRegistryManager -from app.settings import Settings - -# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers -# instantiate Consumer() objects 
simultaneously. Serial execution prevents this. -pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.xdist_group("kafka_consumers")] - -_test_logger = logging.getLogger("test.events.consumer_lifecycle") - - -@pytest.mark.asyncio -async def test_consumer_run_and_cancel( - schema_registry: SchemaRegistryManager, - event_metrics: EventMetrics, - consumer_config: ConsumerConfig, - test_settings: Settings, -) -> None: - """Test consumer run() blocks until cancelled and seek methods work.""" - disp = EventDispatcher(logger=_test_logger) - consumer = UnifiedConsumer( - consumer_config, - dispatcher=disp, - schema_registry=schema_registry, - settings=test_settings, - logger=_test_logger, - event_metrics=event_metrics, - topics=[KafkaTopic.EXECUTION_EVENTS], - ) - - # Track when consumer is running - consumer_started = asyncio.Event() - - async def run_with_signal() -> None: - consumer_started.set() - await consumer.run() - - task = asyncio.create_task(run_with_signal()) - - try: - # Wait for consumer to start - await asyncio.wait_for(consumer_started.wait(), timeout=5.0) - - # Exercise seek functions while consumer is running - await consumer.seek_to_beginning() - await consumer.seek_to_end() - finally: - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task diff --git a/backend/tests/integration/events/test_event_dispatcher.py b/backend/tests/integration/events/test_event_dispatcher.py deleted file mode 100644 index 92d6544e..00000000 --- a/backend/tests/integration/events/test_event_dispatcher.py +++ /dev/null @@ -1,72 +0,0 @@ -import asyncio -import logging - -import pytest -from app.core.metrics import EventMetrics -from app.domain.enums.events import EventType -from app.domain.enums.kafka import KafkaTopic -from app.domain.events.typed import DomainEvent -from app.events.core import ConsumerConfig, UnifiedConsumer, UnifiedProducer -from app.events.core.dispatcher import EventDispatcher -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.settings import Settings -from dishka import AsyncContainer - -from tests.helpers import make_execution_requested_event - -# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers -# instantiate Consumer() objects simultaneously. Serial execution prevents this. 
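The broker-backed round-trip tests removed here exercised produce/consume wiring against a live Kafka cluster. With the fakes introduced earlier in this patch, a comparable smoke check can run entirely in-process; the sketch below sticks to the FakeAIOKafkaProducer API defined in tests/helpers/fakes/kafka.py, while the topic name, payload, and header value are illustrative.

    import asyncio

    from tests.helpers.fakes.kafka import FakeAIOKafkaProducer


    async def main() -> None:
        producer = FakeAIOKafkaProducer()
        async with producer:  # __aenter__/__aexit__ call start()/stop() on the fake
            await producer.send_and_wait(
                topic="execution-events",  # illustrative topic name
                value=b"payload",
                headers=[("source_instance", b"test")],
            )
        # The fake records every send, so assertions need no broker.
        assert producer.sent_messages == [
            ("execution-events", b"payload", [("source_instance", b"test")])
        ]


    asyncio.run(main())

Because the fake appends every send to sent_messages, assertions about topics, payloads, and headers need no running broker.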
-pytestmark = [pytest.mark.integration, pytest.mark.kafka, pytest.mark.xdist_group("kafka_consumers")] - -_test_logger = logging.getLogger("test.events.event_dispatcher") - - -@pytest.mark.asyncio -async def test_dispatcher_with_multiple_handlers( - scope: AsyncContainer, - schema_registry: SchemaRegistryManager, - event_metrics: EventMetrics, - consumer_config: ConsumerConfig, - test_settings: Settings, -) -> None: - # Ensure schema registry is ready - await initialize_event_schemas(schema_registry) - - # Build dispatcher with two handlers for the same event - dispatcher = EventDispatcher(logger=_test_logger) - h1_called = asyncio.Event() - h2_called = asyncio.Event() - - @dispatcher.register(EventType.EXECUTION_REQUESTED) - async def h1(_e: DomainEvent) -> None: - h1_called.set() - - @dispatcher.register(EventType.EXECUTION_REQUESTED) - async def h2(_e: DomainEvent) -> None: - h2_called.set() - - # Real consumer against execution-events - consumer = UnifiedConsumer( - consumer_config, - dispatcher, - schema_registry=schema_registry, - settings=test_settings, - logger=_test_logger, - event_metrics=event_metrics, - topics=[KafkaTopic.EXECUTION_EVENTS], - ) - - # Start consumer as background task - consumer_task = asyncio.create_task(consumer.run()) - - # Produce a request event via DI - producer: UnifiedProducer = await scope.get(UnifiedProducer) - evt = make_execution_requested_event(execution_id=f"exec-{consumer_config.group_id}") - await producer.produce(evt, key="k") - - try: - await asyncio.wait_for(asyncio.gather(h1_called.wait(), h2_called.wait()), timeout=10.0) - finally: - consumer_task.cancel() - with pytest.raises(asyncio.CancelledError): - await consumer_task diff --git a/backend/tests/integration/events/test_producer_roundtrip.py b/backend/tests/integration/events/test_producer_roundtrip.py index eda0c299..bd30776f 100644 --- a/backend/tests/integration/events/test_producer_roundtrip.py +++ b/backend/tests/integration/events/test_producer_roundtrip.py @@ -2,8 +2,9 @@ from uuid import uuid4 import pytest +from aiokafka import AIOKafkaProducer from app.core.metrics import EventMetrics -from app.events.core import UnifiedProducer +from app.events.core import ProducerMetrics, UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager from app.infrastructure.kafka.mappings import get_topic_for_event from app.settings import Settings @@ -22,20 +23,34 @@ async def test_unified_producer_start_produce_send_to_dlq_stop( ) -> None: schema: SchemaRegistryManager = await scope.get(SchemaRegistryManager) event_metrics: EventMetrics = await scope.get(EventMetrics) - prod = UnifiedProducer( - schema, - logger=_test_logger, - settings=test_settings, - event_metrics=event_metrics, + + aiokafka_producer = AIOKafkaProducer( + bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS, + client_id=f"{test_settings.SERVICE_NAME}-producer-test", + acks="all", + compression_type="gzip", + max_batch_size=16384, + linger_ms=10, + enable_idempotence=True, ) - async with prod: + # Start the underlying producer (lifecycle managed externally, not by UnifiedProducer) + await aiokafka_producer.start() + try: + prod = UnifiedProducer( + producer=aiokafka_producer, + metrics=ProducerMetrics(), + schema_registry=schema, + settings=test_settings, + logger=_test_logger, + event_metrics=event_metrics, + ) + ev = make_execution_requested_event(execution_id=f"exec-{uuid4().hex[:8]}") await prod.produce(ev) # Exercise send_to_dlq path topic = str(get_topic_for_event(ev.event_type)) await 
prod.send_to_dlq(ev, original_topic=topic, error=RuntimeError("forced"), retry_count=1) - - st = prod.get_status() - assert st["state"] == "running" + finally: + await aiokafka_producer.stop() diff --git a/backend/tests/integration/result_processor/test_result_processor.py b/backend/tests/integration/result_processor/test_result_processor.py deleted file mode 100644 index 445c1588..00000000 --- a/backend/tests/integration/result_processor/test_result_processor.py +++ /dev/null @@ -1,165 +0,0 @@ -import asyncio -import logging -import uuid - -import pytest -from app.core.database_context import Database -from app.core.metrics import EventMetrics, ExecutionMetrics -from app.db.repositories.execution_repository import ExecutionRepository -from app.domain.enums.events import EventType -from app.domain.enums.execution import ExecutionStatus -from app.domain.enums.kafka import GroupId, KafkaTopic -from app.domain.events.typed import EventMetadata, ExecutionCompletedEvent, ResourceUsageAvro, ResultStoredEvent -from app.domain.execution import DomainExecutionCreate -from app.events.core import ConsumerConfig, UnifiedConsumer, UnifiedProducer -from app.events.core.dispatcher import EventDispatcher -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.result_processor import ProcessorLogic -from app.settings import Settings -from dishka import AsyncContainer - -# xdist_group: Kafka consumer creation can crash librdkafka when multiple workers -# instantiate Consumer() objects simultaneously. Serial execution prevents this. -pytestmark = [ - pytest.mark.integration, - pytest.mark.kafka, - pytest.mark.mongodb, - pytest.mark.xdist_group("kafka_consumers"), -] - -_test_logger = logging.getLogger("test.result_processor.processor") - - -@pytest.mark.asyncio -async def test_result_processor_persists_and_emits( - scope: AsyncContainer, - schema_registry: SchemaRegistryManager, - event_metrics: EventMetrics, - test_settings: Settings, -) -> None: - # Ensure schemas - execution_metrics: ExecutionMetrics = await scope.get(ExecutionMetrics) - await initialize_event_schemas(schema_registry) - - # Dependencies - db: Database = await scope.get(Database) - repo: ExecutionRepository = await scope.get(ExecutionRepository) - producer: UnifiedProducer = await scope.get(UnifiedProducer) - - # Create a base execution to satisfy ProcessorLogic lookup - created = await repo.create_execution(DomainExecutionCreate( - script="print('x')", - user_id="u1", - lang="python", - lang_version="3.11", - status=ExecutionStatus.RUNNING, - )) - execution_id = created.execution_id - - # Build the ProcessorLogic and wire up the consumer - logic = ProcessorLogic( - execution_repo=repo, - producer=producer, - settings=test_settings, - logger=_test_logger, - execution_metrics=execution_metrics, - ) - - # Create dispatcher and register handlers - processor_dispatcher = EventDispatcher(logger=_test_logger) - logic.register_handlers(processor_dispatcher) - - # Create consumer config with unique group id - processor_consumer_config = ConsumerConfig( - bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{GroupId.RESULT_PROCESSOR}.test.{uuid.uuid4().hex[:8]}", - max_poll_records=1, - enable_auto_commit=True, - auto_offset_reset="earliest", - session_timeout_ms=test_settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=test_settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=test_settings.KAFKA_MAX_POLL_INTERVAL_MS, - 
request_timeout_ms=test_settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - # Create processor consumer (idempotency is now handled by FastStream middleware in production) - processor_consumer = UnifiedConsumer( - processor_consumer_config, - dispatcher=processor_dispatcher, - schema_registry=schema_registry, - settings=test_settings, - logger=_test_logger, - event_metrics=event_metrics, - topics=[KafkaTopic.EXECUTION_COMPLETED, KafkaTopic.EXECUTION_FAILED, KafkaTopic.EXECUTION_TIMEOUT], - ) - - # Setup a separate consumer to capture ResultStoredEvent - stored_dispatcher = EventDispatcher(logger=_test_logger) - stored_received = asyncio.Event() - - @stored_dispatcher.register(EventType.RESULT_STORED) - async def _stored(event: ResultStoredEvent) -> None: - if event.execution_id == execution_id: - stored_received.set() - - stored_consumer_config = ConsumerConfig( - bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"test.result_stored.{uuid.uuid4().hex[:8]}", - max_poll_records=1, - enable_auto_commit=True, - auto_offset_reset="earliest", - ) - - stored_consumer = UnifiedConsumer( - stored_consumer_config, - stored_dispatcher, - schema_registry=schema_registry, - settings=test_settings, - logger=_test_logger, - event_metrics=event_metrics, - topics=[KafkaTopic.EXECUTION_RESULTS], - ) - - # Produce the event BEFORE starting consumers (auto_offset_reset="earliest" will read it) - usage = ResourceUsageAvro( - execution_time_wall_seconds=0.5, - cpu_time_jiffies=100, - clk_tck_hertz=100, - peak_memory_kb=1024, - ) - evt = ExecutionCompletedEvent( - execution_id=execution_id, - exit_code=0, - stdout="hello", - stderr="", - resource_usage=usage, - metadata=EventMetadata(service_name="tests", service_version="1.0.0"), - ) - await producer.produce(evt, key=execution_id) - - # Start consumers as background tasks - processor_task = asyncio.create_task(processor_consumer.run()) - stored_task = asyncio.create_task(stored_consumer.run()) - - try: - # Await the ResultStoredEvent - signals that processing is complete - await asyncio.wait_for(stored_received.wait(), timeout=12.0) - - # Now verify DB persistence - should be done since event was emitted - doc = await db.get_collection("executions").find_one({"execution_id": execution_id}) - assert doc is not None, f"Execution {execution_id} not found in DB after ResultStoredEvent" - assert doc.get("status") == ExecutionStatus.COMPLETED, ( - f"Expected COMPLETED status, got {doc.get('status')}" - ) - finally: - # Cancel and cleanup both consumers - processor_task.cancel() - stored_task.cancel() - try: - await processor_task - except asyncio.CancelledError: - pass - try: - await stored_task - except asyncio.CancelledError: - pass diff --git a/backend/tests/integration/services/sse/test_partitioned_event_router.py b/backend/tests/integration/services/sse/test_partitioned_event_router.py deleted file mode 100644 index fd8b046f..00000000 --- a/backend/tests/integration/services/sse/test_partitioned_event_router.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -import logging -from uuid import uuid4 - -import pytest -from app.events.core import EventDispatcher -from app.schemas_pydantic.sse import RedisSSEMessage -from app.services.sse.event_router import SSEEventRouter -from app.services.sse.redis_bus import SSERedisBus - -from tests.helpers import make_execution_requested_event - -pytestmark = [pytest.mark.integration, pytest.mark.redis] - -_test_logger = logging.getLogger("test.services.sse.event_router_integration") - - -@pytest.mark.asyncio -async 
def test_event_router_bridges_to_redis( - sse_redis_bus: SSERedisBus, -) -> None: - """Test that SSEEventRouter routes events to Redis correctly.""" - router = SSEEventRouter(sse_bus=sse_redis_bus, logger=_test_logger) - - # Register handlers with dispatcher - disp = EventDispatcher(logger=_test_logger) - router.register_handlers(disp) - - # Open Redis subscription for our execution id - execution_id = f"e-{uuid4().hex[:8]}" - subscription = await sse_redis_bus.open_subscription(execution_id) - - # Create and route an event - ev = make_execution_requested_event(execution_id=execution_id) - handler = disp.get_handlers(ev.event_type)[0] - await handler(ev) - - # Await the subscription - verify event arrived in Redis - msg = await asyncio.wait_for(subscription.get(RedisSSEMessage), timeout=2.0) - assert msg is not None - assert str(msg.event_type) == str(ev.event_type) diff --git a/backend/tests/unit/conftest.py b/backend/tests/unit/conftest.py index a81a26a2..24e1bf1b 100644 --- a/backend/tests/unit/conftest.py +++ b/backend/tests/unit/conftest.py @@ -1,7 +1,10 @@ import logging -from typing import NoReturn +from typing import AsyncGenerator, NoReturn import pytest +import pytest_asyncio +from aiokafka import AIOKafkaProducer +from app.core.database_context import Database from app.core.metrics import ( ConnectionMetrics, CoordinatorMetrics, @@ -16,14 +19,80 @@ ReplayMetrics, SecurityMetrics, ) +from app.core.providers import ( + CoreServicesProvider, + EventProvider, + KafkaServicesProvider, + LoggingProvider, + MessagingProvider, + MetricsProvider, + RedisServicesProvider, + RepositoryProvider, + SettingsProvider, +) +from app.db.docs import ALL_DOCUMENTS +from app.db.repositories import ( + EventRepository, + ExecutionRepository, + SagaRepository, + SSERepository, +) +from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository +from app.events.core import UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager +from app.services.kafka_event_service import KafkaEventService from app.services.pod_monitor.config import PodMonitorConfig +from app.services.pod_monitor.event_mapper import PodEventMapper +from app.services.pod_monitor.monitor import PodMonitor from app.settings import Settings +from beanie import init_beanie +from dishka import AsyncContainer, make_async_container + +from tests.helpers.fakes import FakeBoundaryClientProvider, FakeDatabaseProvider, FakeSchemaRegistryProvider +from tests.helpers.fakes.kafka import FakeAIOKafkaProducer +from tests.helpers.k8s_fakes import FakeApi, FakeV1Api, FakeWatch, make_k8s_clients _test_logger = logging.getLogger("test.unit") -# Metrics fixtures - provided via DI, not global context +@pytest_asyncio.fixture(scope="session") +async def unit_container(test_settings: Settings) -> AsyncGenerator[AsyncContainer, None]: + """DI container for unit tests with fake boundary clients. 
+ + Provides: + - Fake Redis, Kafka, K8s, MongoDB (boundary clients) + - Real metrics, repositories, services (internal) + """ + container = make_async_container( + SettingsProvider(), + LoggingProvider(), + FakeBoundaryClientProvider(), + FakeDatabaseProvider(), + RedisServicesProvider(), + MetricsProvider(), + EventProvider(), + FakeSchemaRegistryProvider(), # Override real schema registry with fake + MessagingProvider(), + CoreServicesProvider(), + KafkaServicesProvider(), + RepositoryProvider(), + context={Settings: test_settings}, + ) + + db = await container.get(Database) + await init_beanie(database=db, document_models=ALL_DOCUMENTS) + + yield container + await container.close() + + +@pytest_asyncio.fixture +async def unit_scope(unit_container: AsyncContainer) -> AsyncGenerator[AsyncContainer, None]: + """Request scope from unit test container.""" + async with unit_container() as scope: + yield scope + + @pytest.fixture def connection_metrics(test_settings: Settings) -> ConnectionMetrics: return ConnectionMetrics(test_settings) @@ -84,33 +153,117 @@ def security_metrics(test_settings: Settings) -> SecurityMetrics: return SecurityMetrics(test_settings) +@pytest_asyncio.fixture +async def saga_repository(unit_container: AsyncContainer) -> SagaRepository: + return await unit_container.get(SagaRepository) + + +@pytest_asyncio.fixture +async def execution_repository(unit_container: AsyncContainer) -> ExecutionRepository: + return await unit_container.get(ExecutionRepository) + + +@pytest_asyncio.fixture +async def event_repository(unit_container: AsyncContainer) -> EventRepository: + return await unit_container.get(EventRepository) + + +@pytest_asyncio.fixture +async def sse_repository(unit_container: AsyncContainer) -> SSERepository: + return await unit_container.get(SSERepository) + + +@pytest_asyncio.fixture +async def resource_allocation_repository(unit_container: AsyncContainer) -> ResourceAllocationRepository: + return await unit_container.get(ResourceAllocationRepository) + + +@pytest_asyncio.fixture +async def unified_producer(unit_container: AsyncContainer) -> UnifiedProducer: + return await unit_container.get(UnifiedProducer) + + +@pytest_asyncio.fixture +async def schema_registry(unit_container: AsyncContainer) -> SchemaRegistryManager: + return await unit_container.get(SchemaRegistryManager) + + +@pytest_asyncio.fixture +async def test_logger(unit_container: AsyncContainer) -> logging.Logger: + return await unit_container.get(logging.Logger) + + +@pytest_asyncio.fixture +async def kafka_event_service(unit_container: AsyncContainer) -> KafkaEventService: + """Real KafkaEventService wired with fake backends.""" + return await unit_container.get(KafkaEventService) + + +@pytest_asyncio.fixture +async def fake_kafka_producer(unit_container: AsyncContainer) -> FakeAIOKafkaProducer: + """Access to fake Kafka producer for verifying sent messages.""" + producer = await unit_container.get(AIOKafkaProducer) + assert isinstance(producer, FakeAIOKafkaProducer) + return producer + + @pytest.fixture -def db() -> NoReturn: - raise RuntimeError("Unit tests should not access DB - use mocks or move to integration/") +def pod_monitor_config() -> PodMonitorConfig: + return PodMonitorConfig() @pytest.fixture -def redis_client() -> NoReturn: - raise RuntimeError("Unit tests should not access Redis - use mocks or move to integration/") +def k8s_v1() -> FakeV1Api: + """Default fake CoreV1Api for tests.""" + v1, _ = make_k8s_clients() + return v1 @pytest.fixture -def client() -> NoReturn: - raise 
RuntimeError("Unit tests should not use HTTP client - use mocks or move to integration/") +def k8s_watch() -> FakeWatch: + """Default fake Watch for tests.""" + _, watch = make_k8s_clients() + return watch + + +@pytest_asyncio.fixture +async def pod_monitor( + unit_container: AsyncContainer, + pod_monitor_config: PodMonitorConfig, + kubernetes_metrics: KubernetesMetrics, + k8s_v1: FakeV1Api, + k8s_watch: FakeWatch, +) -> PodMonitor: + """Fully wired PodMonitor ready for testing.""" + kafka_service = await unit_container.get(KafkaEventService) + event_mapper = PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")) + + return PodMonitor( + config=pod_monitor_config, + kafka_event_service=kafka_service, + logger=_test_logger, + k8s_v1=k8s_v1, + k8s_watch=k8s_watch, + event_mapper=event_mapper, + kubernetes_metrics=kubernetes_metrics, + ) @pytest.fixture -def app() -> NoReturn: - raise RuntimeError("Unit tests should not use full app - use mocks or move to integration/") +def db() -> NoReturn: + raise RuntimeError("Use 'unit_container' fixture for DB access in unit tests") -# Config fixtures - fresh instance per test (can be customized by tests) @pytest.fixture -def pod_monitor_config() -> PodMonitorConfig: - return PodMonitorConfig() +def redis_client() -> NoReturn: + raise RuntimeError("Use 'unit_container' fixture for Redis in unit tests") @pytest.fixture -def schema_registry(test_settings: Settings) -> SchemaRegistryManager: - """Provide SchemaRegistryManager for unit tests (no external connections).""" - return SchemaRegistryManager(test_settings, logger=_test_logger) +def client() -> NoReturn: + raise RuntimeError("Unit tests should not use HTTP client - move to integration/") + + +@pytest.fixture +def app() -> NoReturn: + raise RuntimeError("Unit tests should not use full app - move to integration/") diff --git a/backend/tests/unit/events/test_event_dispatcher.py b/backend/tests/unit/events/test_event_dispatcher.py deleted file mode 100644 index 6bda67e8..00000000 --- a/backend/tests/unit/events/test_event_dispatcher.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging - -from app.domain.enums.events import EventType -from app.domain.events.typed import DomainEvent -from app.events.core import EventDispatcher - -from tests.helpers import make_execution_requested_event - -_test_logger = logging.getLogger("test.events.event_dispatcher") - - -def make_event() -> DomainEvent: - return make_execution_requested_event(execution_id="e1") - - -async def _async_noop(_: DomainEvent) -> None: - return None - - -def test_register_and_remove_handler() -> None: - disp = EventDispatcher(logger=_test_logger) - - # Register via direct method - disp.register_handler(EventType.EXECUTION_REQUESTED, _async_noop) - assert len(disp.get_handlers(EventType.EXECUTION_REQUESTED)) == 1 - - # Remove - ok = disp.remove_handler(EventType.EXECUTION_REQUESTED, _async_noop) - assert ok is True - assert len(disp.get_handlers(EventType.EXECUTION_REQUESTED)) == 0 - - -def test_decorator_registration() -> None: - disp = EventDispatcher(logger=_test_logger) - - @disp.register(EventType.EXECUTION_REQUESTED) - async def handler(ev: DomainEvent) -> None: # noqa: ARG001 - return None - - assert len(disp.get_handlers(EventType.EXECUTION_REQUESTED)) == 1 - - -async def test_dispatch_metrics_processed_and_skipped() -> None: - disp = EventDispatcher(logger=_test_logger) - called = {"n": 0} - - @disp.register(EventType.EXECUTION_REQUESTED) - async def handler(_: DomainEvent) -> None: - called["n"] += 1 - - await 
disp.dispatch(make_event()) - # Dispatch event with no handlers (different type) - # Reuse base event but fake type by replacing value - e = make_event() - e.event_type = EventType.EXECUTION_FAILED - await disp.dispatch(e) - - metrics = disp.get_metrics() - assert called["n"] == 1 - assert metrics[EventType.EXECUTION_REQUESTED]["processed"] >= 1 - assert metrics[EventType.EXECUTION_FAILED]["skipped"] >= 1 diff --git a/backend/tests/unit/services/coordinator/__init__.py b/backend/tests/unit/services/coordinator/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index ab063129..2399c779 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -2,10 +2,8 @@ import logging import types from typing import Any -from unittest.mock import MagicMock import pytest -from app.core.k8s_clients import K8sClients from app.core.metrics import EventMetrics, KubernetesMetrics from app.db.repositories.event_repository import EventRepository from app.domain.events.typed import ( @@ -25,8 +23,10 @@ WatchEventType, ) from app.settings import Settings +from dishka import AsyncContainer from kubernetes.client.rest import ApiException +from tests.helpers.fakes.kafka import FakeAIOKafkaProducer from tests.helpers.k8s_fakes import ( FakeApi, FakeV1Api, @@ -42,56 +42,221 @@ _test_logger = logging.getLogger("test.pod_monitor") -# ===== Test doubles for KafkaEventService dependencies ===== +# ===== Tests using default pod_monitor fixture ===== -class FakeEventRepository(EventRepository): - """In-memory event repository for testing.""" +@pytest.mark.asyncio +async def test_process_raw_event_invalid_and_backoff(pod_monitor: PodMonitor) -> None: + await pod_monitor._process_raw_event({}) - def __init__(self) -> None: - super().__init__(_test_logger) - self.stored_events: list[DomainEvent] = [] + pod_monitor.config.watch_reconnect_delay = 0 + pod_monitor._reconnect_attempts = 0 + await pod_monitor._backoff() + await pod_monitor._backoff() + assert pod_monitor._reconnect_attempts >= 2 - async def store_event(self, event: DomainEvent) -> str: - self.stored_events.append(event) - return event.event_id +@pytest.mark.asyncio +async def test_backoff_max_attempts(pod_monitor: PodMonitor) -> None: + pod_monitor.config.max_reconnect_attempts = 2 + pod_monitor._reconnect_attempts = 2 -class FakeUnifiedProducer(UnifiedProducer): - """Fake producer that captures events without Kafka.""" + with pytest.raises(RuntimeError, match="Max reconnect attempts exceeded"): + await pod_monitor._backoff() - def __init__(self) -> None: - # Don't call super().__init__ - we don't need real Kafka - self.produced_events: list[tuple[DomainEvent, str | None]] = [] - self.logger = _test_logger - async def produce( - self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None - ) -> None: - self.produced_events.append((event_to_produce, key)) +@pytest.mark.asyncio +async def test_watch_loop_with_cancellation(pod_monitor: PodMonitor) -> None: + pod_monitor.config.enable_state_reconciliation = False + watch_count: list[int] = [] + + async def mock_run_watch() -> None: + watch_count.append(1) + if len(watch_count) >= 3: + raise asyncio.CancelledError() + + pod_monitor._run_watch = mock_run_watch # type: ignore[method-assign] + + with pytest.raises(asyncio.CancelledError): + await pod_monitor._watch_loop() + + 
assert len(watch_count) == 3 + + +@pytest.mark.asyncio +async def test_watch_loop_api_exception_410(pod_monitor: PodMonitor) -> None: + pod_monitor.config.enable_state_reconciliation = False + pod_monitor._last_resource_version = "v123" + call_count = 0 + + async def mock_run_watch() -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ApiException(status=410) + raise asyncio.CancelledError() - async def aclose(self) -> None: + async def mock_backoff() -> None: pass + pod_monitor._run_watch = mock_run_watch # type: ignore[method-assign] + pod_monitor._backoff = mock_backoff # type: ignore[method-assign] -def create_test_kafka_event_service( - event_metrics: EventMetrics, settings: Settings -) -> tuple[KafkaEventService, FakeUnifiedProducer]: - """Create real KafkaEventService with fake dependencies for testing.""" - fake_producer = FakeUnifiedProducer() - fake_repo = FakeEventRepository() + with pytest.raises(asyncio.CancelledError): + await pod_monitor._watch_loop() - service = KafkaEventService( - event_repository=fake_repo, - kafka_producer=fake_producer, - settings=settings, - logger=_test_logger, - event_metrics=event_metrics, + assert pod_monitor._last_resource_version is None + + +@pytest.mark.asyncio +async def test_watch_loop_generic_exception(pod_monitor: PodMonitor) -> None: + pod_monitor.config.enable_state_reconciliation = False + call_count = 0 + backoff_count = 0 + + async def mock_run_watch() -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("Unexpected error") + raise asyncio.CancelledError() + + async def mock_backoff() -> None: + nonlocal backoff_count + backoff_count += 1 + + pod_monitor._run_watch = mock_run_watch # type: ignore[method-assign] + pod_monitor._backoff = mock_backoff # type: ignore[method-assign] + + with pytest.raises(asyncio.CancelledError): + await pod_monitor._watch_loop() + + assert backoff_count == 1 + + +@pytest.mark.asyncio +async def test_cleanup_on_cancel(pod_monitor: PodMonitor) -> None: + """Test cleanup of tracked pods on cancellation.""" + watch_started = asyncio.Event() + + async def _blocking_watch() -> None: + pod_monitor._tracked_pods = {"pod1"} + watch_started.set() + await asyncio.sleep(10) + + pod_monitor._watch_loop = _blocking_watch # type: ignore[method-assign] + + task = asyncio.create_task(pod_monitor.run()) + await asyncio.wait_for(watch_started.wait(), timeout=1.0) + assert "pod1" in pod_monitor._tracked_pods + + task.cancel() + await task + + assert len(pod_monitor._tracked_pods) == 0 + + +@pytest.mark.asyncio +async def test_process_raw_event_with_metadata(pod_monitor: PodMonitor) -> None: + processed: list[PodEvent] = [] + + async def mock_process(event: PodEvent) -> None: + processed.append(event) + + pod_monitor._process_pod_event = mock_process # type: ignore[method-assign] + + raw_event = { + "type": "ADDED", + "object": types.SimpleNamespace(metadata=types.SimpleNamespace(resource_version="v1")), + } + + await pod_monitor._process_raw_event(raw_event) + assert len(processed) == 1 + assert processed[0].resource_version == "v1" + + raw_event_no_meta = {"type": "MODIFIED", "object": types.SimpleNamespace(metadata=None)} + + await pod_monitor._process_raw_event(raw_event_no_meta) + assert len(processed) == 2 + assert processed[1].resource_version is None + + +@pytest.mark.asyncio +async def test_watch_loop_api_exception_other_status(pod_monitor: PodMonitor) -> None: + pod_monitor.config.enable_state_reconciliation = False + call_count = 0 + backoff_count = 
0 + + async def mock_run_watch() -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ApiException(status=500) + raise asyncio.CancelledError() + + async def mock_backoff() -> None: + nonlocal backoff_count + backoff_count += 1 + + pod_monitor._run_watch = mock_run_watch # type: ignore[method-assign] + pod_monitor._backoff = mock_backoff # type: ignore[method-assign] + + with pytest.raises(asyncio.CancelledError): + await pod_monitor._watch_loop() + + assert backoff_count == 1 + + +@pytest.mark.asyncio +async def test_watch_loop_with_reconciliation(pod_monitor: PodMonitor) -> None: + """Test that reconciliation is called before each watch restart.""" + pod_monitor.config.enable_state_reconciliation = True + + reconcile_count = 0 + watch_count = 0 + + async def mock_reconcile() -> None: + nonlocal reconcile_count + reconcile_count += 1 + + async def mock_run_watch() -> None: + nonlocal watch_count + watch_count += 1 + if watch_count >= 2: + raise asyncio.CancelledError() + + pod_monitor._reconcile = mock_reconcile # type: ignore[method-assign] + pod_monitor._run_watch = mock_run_watch # type: ignore[method-assign] + + with pytest.raises(asyncio.CancelledError): + await pod_monitor._watch_loop() + + assert reconcile_count == 2 + assert watch_count == 2 + + +@pytest.mark.asyncio +async def test_publish_event_full_flow( + pod_monitor: PodMonitor, + fake_kafka_producer: FakeAIOKafkaProducer, +) -> None: + initial_count = len(fake_kafka_producer.sent_messages) + + event = ExecutionCompletedEvent( + execution_id="exec1", + aggregate_id="exec1", + exit_code=0, + resource_usage=ResourceUsageAvro(), + metadata=EventMetadata(service_name="test", service_version="1.0"), ) - return service, fake_producer + pod = make_pod(name="test-pod", phase="Succeeded", labels={"execution-id": "exec1"}) + await pod_monitor._publish_event(event, pod) + + assert len(fake_kafka_producer.sent_messages) > initial_count -# ===== Helpers to create test instances with pure DI ===== + +# ===== Tests requiring custom setup ===== class SpyMapper: @@ -107,61 +272,28 @@ def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: # n return [] -def make_k8s_clients_di( - events: list[dict[str, Any]] | None = None, - resource_version: str = "rv1", - pods: list[Any] | None = None, -) -> K8sClients: - """Create K8sClients for DI with fakes.""" - v1, watch = make_k8s_clients(events=events, resource_version=resource_version, pods=pods) - return K8sClients( - api_client=MagicMock(), - v1=v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=watch, - ) - - -def make_pod_monitor( - event_metrics: EventMetrics, - kubernetes_metrics: KubernetesMetrics, - settings: Settings, - config: PodMonitorConfig | None = None, - kafka_service: KafkaEventService | None = None, - k8s_clients: K8sClients | None = None, - event_mapper: PodEventMapper | None = None, -) -> PodMonitor: - """Create PodMonitor with sensible test defaults.""" - cfg = config or PodMonitorConfig() - clients = k8s_clients or make_k8s_clients_di() - mapper = event_mapper or PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")) - service = kafka_service or create_test_kafka_event_service(event_metrics, settings)[0] - return PodMonitor( - config=cfg, - kafka_event_service=service, - logger=_test_logger, - k8s_clients=clients, - event_mapper=mapper, - kubernetes_metrics=kubernetes_metrics, - ) - - -# ===== Tests ===== - - @pytest.mark.asyncio async def test_run_and_cancel_lifecycle( - event_metrics: EventMetrics, 
kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + kafka_event_service: KafkaEventService, + kubernetes_metrics: KubernetesMetrics, pod_monitor_config: PodMonitorConfig, + k8s_v1: FakeV1Api, + k8s_watch: FakeWatch, ) -> None: """Test that run() blocks until cancelled and cleans up on cancellation.""" pod_monitor_config.enable_state_reconciliation = False spy = SpyMapper() - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, event_mapper=spy) # type: ignore[arg-type] + pm = PodMonitor( + config=pod_monitor_config, + kafka_event_service=kafka_event_service, + logger=_test_logger, + k8s_v1=k8s_v1, + k8s_watch=k8s_watch, + event_mapper=spy, # type: ignore[arg-type] + kubernetes_metrics=kubernetes_metrics, + ) - # Track when watch_loop is entered watch_started = asyncio.Event() async def _blocking_watch() -> None: @@ -170,80 +302,44 @@ async def _blocking_watch() -> None: pm._watch_loop = _blocking_watch # type: ignore[method-assign] - # Start run() as a task task = asyncio.create_task(pm.run()) - - # Wait until we're actually in the watch loop await asyncio.wait_for(watch_started.wait(), timeout=1.0) - # Cancel it - run() catches CancelledError and exits gracefully task.cancel() - await task # Should complete without raising (graceful shutdown) + await task - # Verify cleanup happened assert spy.cleared is True @pytest.mark.asyncio async def test_run_watch_flow_and_publish( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + kafka_event_service: KafkaEventService, + kubernetes_metrics: KubernetesMetrics, pod_monitor_config: PodMonitorConfig, ) -> None: pod_monitor_config.enable_state_reconciliation = False pod = make_pod(name="p", phase="Succeeded", labels={"execution-id": "e1"}, term_exit=0, resource_version="rv1") - k8s_clients = make_k8s_clients_di(events=[{"type": "MODIFIED", "object": pod}], resource_version="rv2") + v1, watch = make_k8s_clients(events=[{"type": "MODIFIED", "object": pod}], resource_version="rv2") - pm = make_pod_monitor( - event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, k8s_clients=k8s_clients + pm = PodMonitor( + config=pod_monitor_config, + kafka_event_service=kafka_event_service, + logger=_test_logger, + k8s_v1=v1, + k8s_watch=watch, + event_mapper=PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")), + kubernetes_metrics=kubernetes_metrics, ) await pm._run_watch() assert pm._last_resource_version == "rv2" -@pytest.mark.asyncio -async def test_process_raw_event_invalid_and_backoff( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - - await pm._process_raw_event({}) - - pm.config.watch_reconnect_delay = 0 - pm._reconnect_attempts = 0 - await pm._backoff() - await pm._backoff() - assert pm._reconnect_attempts >= 2 - - -@pytest.mark.asyncio -async def test_get_status( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - pod_monitor_config.namespace = "test-ns" - pod_monitor_config.label_selector = "app=test" - pod_monitor_config.enable_state_reconciliation = True - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - pm._tracked_pods = {"pod1", "pod2"} - pm._reconnect_attempts = 3 - 
pm._last_resource_version = "v123" - - status = await pm.get_status() - assert status["tracked_pods"] == 2 - assert status["reconnect_attempts"] == 3 - assert status["last_resource_version"] == "v123" - assert status["config"]["namespace"] == "test-ns" - assert status["config"]["label_selector"] == "app=test" - assert status["config"]["enable_reconciliation"] is True - - @pytest.mark.asyncio async def test_reconcile_success( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + kafka_event_service: KafkaEventService, + kubernetes_metrics: KubernetesMetrics, pod_monitor_config: PodMonitorConfig, ) -> None: pod_monitor_config.namespace = "test" @@ -251,10 +347,16 @@ async def test_reconcile_success( pod1 = make_pod(name="pod1", phase="Running", resource_version="v1") pod2 = make_pod(name="pod2", phase="Running", resource_version="v1") - k8s_clients = make_k8s_clients_di(pods=[pod1, pod2]) + v1, watch = make_k8s_clients(pods=[pod1, pod2]) - pm = make_pod_monitor( - event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, k8s_clients=k8s_clients + pm = PodMonitor( + config=pod_monitor_config, + kafka_event_service=kafka_event_service, + logger=_test_logger, + k8s_v1=v1, + k8s_watch=watch, + event_mapper=PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")), + kubernetes_metrics=kubernetes_metrics, ) pm._tracked_pods = {"pod2", "pod3"} @@ -267,32 +369,28 @@ async def mock_process(event: PodEvent) -> None: await pm._reconcile() - # pod1 was missing and should have been processed assert "pod1" in processed - # pod3 was extra and should have been removed from tracking assert "pod3" not in pm._tracked_pods @pytest.mark.asyncio async def test_reconcile_exception( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + kafka_event_service: KafkaEventService, + kubernetes_metrics: KubernetesMetrics, pod_monitor_config: PodMonitorConfig, ) -> None: class FailV1(FakeV1Api): def list_namespaced_pod(self, namespace: str, label_selector: str) -> Any: raise RuntimeError("API error") - fail_v1 = FailV1() - k8s_clients = K8sClients( - api_client=MagicMock(), - v1=fail_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=make_watch([]), - ) - - pm = make_pod_monitor( - event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, k8s_clients=k8s_clients + pm = PodMonitor( + config=pod_monitor_config, + kafka_event_service=kafka_event_service, + logger=_test_logger, + k8s_v1=FailV1(), + k8s_watch=make_watch([]), + event_mapper=PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")), + kubernetes_metrics=kubernetes_metrics, ) # Should not raise - errors are caught and logged @@ -301,8 +399,11 @@ def list_namespaced_pod(self, namespace: str, label_selector: str) -> Any: @pytest.mark.asyncio async def test_process_pod_event_full_flow( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + kafka_event_service: KafkaEventService, + kubernetes_metrics: KubernetesMetrics, pod_monitor_config: PodMonitorConfig, + k8s_v1: FakeV1Api, + k8s_watch: FakeWatch, ) -> None: pod_monitor_config.ignored_pod_phases = ["Unknown"] @@ -318,10 +419,14 @@ class Event: def clear_cache(self) -> None: pass - pm = make_pod_monitor( - event_metrics, kubernetes_metrics, test_settings, + pm = PodMonitor( config=pod_monitor_config, + kafka_event_service=kafka_event_service, + logger=_test_logger, + k8s_v1=k8s_v1, + k8s_watch=k8s_watch, event_mapper=MockMapper(), # 
type: ignore[arg-type] + kubernetes_metrics=kubernetes_metrics, ) published: list[Any] = [] @@ -365,8 +470,11 @@ async def mock_publish(event: Any, pod: Any) -> None: # noqa: ARG001 @pytest.mark.asyncio async def test_process_pod_event_exception_handling( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + kafka_event_service: KafkaEventService, + kubernetes_metrics: KubernetesMetrics, pod_monitor_config: PodMonitorConfig, + k8s_v1: FakeV1Api, + k8s_watch: FakeWatch, ) -> None: class FailMapper: def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: @@ -375,10 +483,14 @@ def map_pod_event(self, pod: Any, event_type: WatchEventType) -> list[Any]: def clear_cache(self) -> None: pass - pm = make_pod_monitor( - event_metrics, kubernetes_metrics, test_settings, + pm = PodMonitor( config=pod_monitor_config, + kafka_event_service=kafka_event_service, + logger=_test_logger, + k8s_v1=k8s_v1, + k8s_watch=k8s_watch, event_mapper=FailMapper(), # type: ignore[arg-type] + kubernetes_metrics=kubernetes_metrics, ) event = PodEvent( @@ -387,59 +499,47 @@ def clear_cache(self) -> None: resource_version=None, ) - # Should not raise - errors are caught and logged + # Should not raise await pm._process_pod_event(event) @pytest.mark.asyncio -async def test_publish_event_full_flow( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, +async def test_publish_event_exception_handling( + unit_container: AsyncContainer, + event_metrics: EventMetrics, + kubernetes_metrics: KubernetesMetrics, + test_settings: Settings, pod_monitor_config: PodMonitorConfig, + k8s_v1: FakeV1Api, + k8s_watch: FakeWatch, ) -> None: - service, fake_producer = create_test_kafka_event_service(event_metrics, test_settings) - pm = make_pod_monitor( - event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, kafka_service=service - ) - - event = ExecutionCompletedEvent( - execution_id="exec1", - aggregate_id="exec1", - exit_code=0, - resource_usage=ResourceUsageAvro(), - metadata=EventMetadata(service_name="test", service_version="1.0"), - ) - - pod = make_pod(name="test-pod", phase="Succeeded", labels={"execution-id": "exec1"}) - await pm._publish_event(event, pod) - - assert len(fake_producer.produced_events) == 1 - assert fake_producer.produced_events[0][1] == "exec1" + event_repo = await unit_container.get(EventRepository) + class FailingProducer(UnifiedProducer): + def __init__(self) -> None: + self.logger = _test_logger -@pytest.mark.asyncio -async def test_publish_event_exception_handling( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - class FailingProducer(FakeUnifiedProducer): async def produce( self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None ) -> None: raise RuntimeError("Publish failed") - # Create service with failing producer - failing_producer = FailingProducer() - fake_repo = FakeEventRepository() failing_service = KafkaEventService( - event_repository=fake_repo, - kafka_producer=failing_producer, + event_repository=event_repo, + kafka_producer=FailingProducer(), settings=test_settings, logger=_test_logger, event_metrics=event_metrics, ) - pm = make_pod_monitor( - event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, kafka_service=failing_service + pm = PodMonitor( + config=pod_monitor_config, + kafka_event_service=failing_service, 
+ logger=_test_logger, + k8s_v1=k8s_v1, + k8s_watch=k8s_watch, + event_mapper=PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")), + kubernetes_metrics=kubernetes_metrics, ) event = ExecutionStartedEvent( @@ -448,119 +548,17 @@ async def produce( metadata=EventMetadata(service_name="test", service_version="1.0"), ) - # Use pod with no metadata to exercise edge case pod = make_pod(name="no-meta-pod", phase="Pending") pod.metadata = None # type: ignore[assignment] - # Should not raise - errors are caught and logged + # Should not raise await pm._publish_event(event, pod) -@pytest.mark.asyncio -async def test_backoff_max_attempts( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - pod_monitor_config.max_reconnect_attempts = 2 - - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - pm._reconnect_attempts = 2 - - with pytest.raises(RuntimeError, match="Max reconnect attempts exceeded"): - await pm._backoff() - - -@pytest.mark.asyncio -async def test_watch_loop_with_cancellation( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - pod_monitor_config.enable_state_reconciliation = False - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - - watch_count: list[int] = [] - - async def mock_run_watch() -> None: - watch_count.append(1) - if len(watch_count) >= 3: - raise asyncio.CancelledError() - - pm._run_watch = mock_run_watch # type: ignore[method-assign] - - # watch_loop propagates CancelledError (correct behavior for structured concurrency) - with pytest.raises(asyncio.CancelledError): - await pm._watch_loop() - - assert len(watch_count) == 3 - - -@pytest.mark.asyncio -async def test_watch_loop_api_exception_410( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - pod_monitor_config.enable_state_reconciliation = False - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - - pm._last_resource_version = "v123" - call_count = 0 - - async def mock_run_watch() -> None: - nonlocal call_count - call_count += 1 - if call_count == 1: - raise ApiException(status=410) - raise asyncio.CancelledError() - - async def mock_backoff() -> None: - pass - - pm._run_watch = mock_run_watch # type: ignore[method-assign] - pm._backoff = mock_backoff # type: ignore[method-assign] - - # watch_loop propagates CancelledError - with pytest.raises(asyncio.CancelledError): - await pm._watch_loop() - - # Resource version should be reset on 410 - assert pm._last_resource_version is None - - -@pytest.mark.asyncio -async def test_watch_loop_generic_exception( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - pod_monitor_config.enable_state_reconciliation = False - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - - call_count = 0 - backoff_count = 0 - - async def mock_run_watch() -> None: - nonlocal call_count - call_count += 1 - if call_count == 1: - raise RuntimeError("Unexpected error") - raise asyncio.CancelledError() - - async def mock_backoff() -> None: - nonlocal backoff_count - backoff_count += 1 - - pm._run_watch = mock_run_watch # type: 
ignore[method-assign] - pm._backoff = mock_backoff # type: ignore[method-assign] - - # watch_loop propagates CancelledError - with pytest.raises(asyncio.CancelledError): - await pm._watch_loop() - - assert backoff_count == 1 - - @pytest.mark.asyncio async def test_pod_monitor_run_lifecycle( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + kafka_event_service: KafkaEventService, + kubernetes_metrics: KubernetesMetrics, pod_monitor_config: PodMonitorConfig, ) -> None: """Test PodMonitor lifecycle via run() method.""" @@ -568,31 +566,20 @@ async def test_pod_monitor_run_lifecycle( mock_v1 = FakeV1Api() mock_watch = make_watch([]) - mock_k8s_clients = K8sClients( - api_client=MagicMock(), - v1=mock_v1, - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=mock_watch, - ) - - service, _ = create_test_kafka_event_service(event_metrics, test_settings) event_mapper = PodEventMapper(logger=_test_logger, k8s_api=mock_v1) monitor = PodMonitor( config=pod_monitor_config, - kafka_event_service=service, + kafka_event_service=kafka_event_service, logger=_test_logger, - k8s_clients=mock_k8s_clients, + k8s_v1=mock_v1, + k8s_watch=mock_watch, event_mapper=event_mapper, kubernetes_metrics=kubernetes_metrics, ) - # Verify DI wiring - assert monitor._clients is mock_k8s_clients assert monitor._v1 is mock_v1 - # Track when watch_loop is entered watch_started = asyncio.Event() async def _blocking_watch() -> None: @@ -601,48 +588,31 @@ async def _blocking_watch() -> None: monitor._watch_loop = _blocking_watch # type: ignore[method-assign] - # Start and cancel - run() exits gracefully on cancel task = asyncio.create_task(monitor.run()) await asyncio.wait_for(watch_started.wait(), timeout=1.0) task.cancel() - await task # Should complete without raising - - -@pytest.mark.asyncio -async def test_cleanup_on_cancel( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - """Test cleanup of tracked pods on cancellation.""" - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - - watch_started = asyncio.Event() - - # Replace _watch_loop to add tracked pods and wait - async def _blocking_watch() -> None: - pm._tracked_pods = {"pod1"} - watch_started.set() - await asyncio.sleep(10) - - pm._watch_loop = _blocking_watch # type: ignore[method-assign] - - task = asyncio.create_task(pm.run()) - await asyncio.wait_for(watch_started.wait(), timeout=1.0) - assert "pod1" in pm._tracked_pods - - # Cancel - run() exits gracefully - task.cancel() - await task # Should complete without raising - - # Cleanup should have cleared tracked pods - assert len(pm._tracked_pods) == 0 + await task def test_update_resource_version( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + kubernetes_metrics: KubernetesMetrics, pod_monitor_config: PodMonitorConfig, + k8s_v1: FakeV1Api, + k8s_watch: FakeWatch, ) -> None: - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) + # Sync test needs minimal mock + class MockKafkaService: + pass + + pm = PodMonitor( + config=pod_monitor_config, + kafka_event_service=MockKafkaService(), # type: ignore[arg-type] + logger=_test_logger, + k8s_v1=k8s_v1, + k8s_watch=k8s_watch, + event_mapper=PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")), + kubernetes_metrics=kubernetes_metrics, + ) class Stream: _stop_event = 
types.SimpleNamespace(resource_version="v123") @@ -656,71 +626,10 @@ class BadStream: pm._update_resource_version(BadStream()) -@pytest.mark.asyncio -async def test_process_raw_event_with_metadata( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - - processed: list[PodEvent] = [] - - async def mock_process(event: PodEvent) -> None: - processed.append(event) - - pm._process_pod_event = mock_process # type: ignore[method-assign] - - raw_event = { - "type": "ADDED", - "object": types.SimpleNamespace(metadata=types.SimpleNamespace(resource_version="v1")), - } - - await pm._process_raw_event(raw_event) - assert len(processed) == 1 - assert processed[0].resource_version == "v1" - - raw_event_no_meta = {"type": "MODIFIED", "object": types.SimpleNamespace(metadata=None)} - - await pm._process_raw_event(raw_event_no_meta) - assert len(processed) == 2 - assert processed[1].resource_version is None - - -@pytest.mark.asyncio -async def test_watch_loop_api_exception_other_status( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - pod_monitor_config.enable_state_reconciliation = False - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - - call_count = 0 - backoff_count = 0 - - async def mock_run_watch() -> None: - nonlocal call_count - call_count += 1 - if call_count == 1: - raise ApiException(status=500) - raise asyncio.CancelledError() - - async def mock_backoff() -> None: - nonlocal backoff_count - backoff_count += 1 - - pm._run_watch = mock_run_watch # type: ignore[method-assign] - pm._backoff = mock_backoff # type: ignore[method-assign] - - # watch_loop propagates CancelledError - with pytest.raises(asyncio.CancelledError): - await pm._watch_loop() - - assert backoff_count == 1 - - @pytest.mark.asyncio async def test_run_watch_with_field_selector( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, + kafka_event_service: KafkaEventService, + kubernetes_metrics: KubernetesMetrics, pod_monitor_config: PodMonitorConfig, ) -> None: pod_monitor_config.field_selector = "status.phase=Running" @@ -738,52 +647,16 @@ def stream(self, func: Any, **kwargs: Any) -> FakeWatchStream: watch_kwargs.append(kwargs) return FakeWatchStream([], "rv1") - k8s_clients = K8sClients( - api_client=MagicMock(), - v1=TrackingV1(), - apps_v1=MagicMock(), - networking_v1=MagicMock(), - watch=TrackingWatch([], "rv1"), - ) - - pm = make_pod_monitor( - event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config, k8s_clients=k8s_clients + pm = PodMonitor( + config=pod_monitor_config, + kafka_event_service=kafka_event_service, + logger=_test_logger, + k8s_v1=TrackingV1(), + k8s_watch=TrackingWatch([], "rv1"), + event_mapper=PodEventMapper(logger=_test_logger, k8s_api=FakeApi("{}")), + kubernetes_metrics=kubernetes_metrics, ) await pm._run_watch() assert any("field_selector" in kw for kw in watch_kwargs) - - -@pytest.mark.asyncio -async def test_watch_loop_with_reconciliation( - event_metrics: EventMetrics, kubernetes_metrics: KubernetesMetrics, test_settings: Settings, - pod_monitor_config: PodMonitorConfig, -) -> None: - """Test that reconciliation is called before each watch restart.""" - 
pod_monitor_config.enable_state_reconciliation = True - pm = make_pod_monitor(event_metrics, kubernetes_metrics, test_settings, config=pod_monitor_config) - - reconcile_count = 0 - watch_count = 0 - - async def mock_reconcile() -> None: - nonlocal reconcile_count - reconcile_count += 1 - - async def mock_run_watch() -> None: - nonlocal watch_count - watch_count += 1 - if watch_count >= 2: - raise asyncio.CancelledError() - - pm._reconcile = mock_reconcile # type: ignore[method-assign] - pm._run_watch = mock_run_watch # type: ignore[method-assign] - - # watch_loop propagates CancelledError - with pytest.raises(asyncio.CancelledError): - await pm._watch_loop() - - # Reconcile should be called before each watch restart - assert reconcile_count == 2 - assert watch_count == 2 diff --git a/backend/tests/unit/services/result_processor/test_processor.py b/backend/tests/unit/services/result_processor/test_processor.py deleted file mode 100644 index e230e5a9..00000000 --- a/backend/tests/unit/services/result_processor/test_processor.py +++ /dev/null @@ -1,28 +0,0 @@ -import logging -from unittest.mock import MagicMock - -import pytest -from app.core.metrics import ExecutionMetrics -from app.domain.enums.events import EventType -from app.events.core import EventDispatcher -from app.services.result_processor.processor_logic import ProcessorLogic - -pytestmark = pytest.mark.unit - -_test_logger = logging.getLogger("test.services.result_processor.processor") - - -def test_register_handlers_registers_expected_event_types(execution_metrics: ExecutionMetrics) -> None: - logic = ProcessorLogic( - execution_repo=MagicMock(), - producer=MagicMock(), - settings=MagicMock(), - logger=_test_logger, - execution_metrics=execution_metrics, - ) - dispatcher = EventDispatcher(logger=_test_logger) - logic.register_handlers(dispatcher) - - assert EventType.EXECUTION_COMPLETED in dispatcher._handlers - assert EventType.EXECUTION_FAILED in dispatcher._handlers - assert EventType.EXECUTION_TIMEOUT in dispatcher._handlers diff --git a/backend/tests/unit/services/saga/test_execution_saga_steps.py b/backend/tests/unit/services/saga/test_execution_saga_steps.py index 8c235076..ecbee035 100644 --- a/backend/tests/unit/services/saga/test_execution_saga_steps.py +++ b/backend/tests/unit/services/saga/test_execution_saga_steps.py @@ -1,7 +1,8 @@ import pytest +from aiokafka import AIOKafkaProducer +from app.db.docs import ResourceAllocationDocument from app.db.repositories.resource_allocation_repository import ResourceAllocationRepository -from app.domain.events.typed import DomainEvent, ExecutionRequestedEvent -from app.domain.saga import DomainResourceAllocation, DomainResourceAllocationCreate +from app.domain.events.typed import ExecutionRequestedEvent from app.events.core import UnifiedProducer from app.services.saga.execution_saga import ( AllocateResourcesStep, @@ -14,14 +15,15 @@ ValidateExecutionStep, ) from app.services.saga.saga_step import SagaContext +from dishka import AsyncContainer from tests.helpers import make_execution_requested_event pytestmark = pytest.mark.unit -def _req(timeout: int = 30, script: str = "print('x')") -> ExecutionRequestedEvent: - return make_execution_requested_event(execution_id="e1", script=script, timeout_seconds=timeout) +def _req(timeout: int = 30, script: str = "print('x')", execution_id: str = "e1") -> ExecutionRequestedEvent: + return make_execution_requested_event(execution_id=execution_id, script=script, timeout_seconds=timeout) @pytest.mark.asyncio @@ -42,50 +44,41 @@ async def 
test_validate_execution_step_success_and_failures() -> None: assert ok3 is False and ctx3.error is not None -class _FakeAllocRepo(ResourceAllocationRepository): - """Fake ResourceAllocationRepository for testing.""" - - def __init__(self, active: int = 0, alloc_id: str = "alloc-1") -> None: - self.active = active - self.alloc_id = alloc_id - self.released: list[str] = [] - - async def count_active(self, language: str) -> int: - return self.active - - async def create_allocation(self, create_data: DomainResourceAllocationCreate) -> DomainResourceAllocation: - return DomainResourceAllocation( - allocation_id=self.alloc_id, - execution_id=create_data.execution_id, - language=create_data.language, - cpu_request=create_data.cpu_request, - memory_request=create_data.memory_request, - cpu_limit=create_data.cpu_limit, - memory_limit=create_data.memory_limit, - ) - - async def release_allocation(self, allocation_id: str) -> bool: - self.released.append(allocation_id) - return True - - @pytest.mark.asyncio -async def test_allocate_resources_step_paths() -> None: - ctx = SagaContext("s1", "e1") - ctx.set("execution_id", "e1") - ok = await AllocateResourcesStep(alloc_repo=_FakeAllocRepo(active=0, alloc_id="alloc-1")).execute(ctx, _req()) - assert ok is True and ctx.get("resources_allocated") is True and ctx.get("allocation_id") == "alloc-1" +async def test_allocate_resources_step_paths(unit_container: AsyncContainer) -> None: + alloc_repo = await unit_container.get(ResourceAllocationRepository) + + # Test 1: Success path with clean repo + ctx = SagaContext("s1", "alloc-test-1") + ctx.set("execution_id", "alloc-test-1") + ok = await AllocateResourcesStep(alloc_repo=alloc_repo).execute(ctx, _req(execution_id="alloc-test-1")) + assert ok is True + assert ctx.get("resources_allocated") is True + assert ctx.get("allocation_id") is not None + + # Test 2: Limit exceeded (insert 100 active allocations) + for i in range(100): + doc = ResourceAllocationDocument( + allocation_id=f"limit-test-alloc-{i}", + execution_id=f"limit-test-exec-{i}", + language="python", + cpu_request="100m", + memory_request="128Mi", + cpu_limit="500m", + memory_limit="512Mi", + status="active", + ) + await doc.insert() - # Limit exceeded - ctx2 = SagaContext("s2", "e2") - ctx2.set("execution_id", "e2") - ok2 = await AllocateResourcesStep(alloc_repo=_FakeAllocRepo(active=100)).execute(ctx2, _req()) + ctx2 = SagaContext("s2", "limit-test-main") + ctx2.set("execution_id", "limit-test-main") + ok2 = await AllocateResourcesStep(alloc_repo=alloc_repo).execute(ctx2, _req(execution_id="limit-test-main")) assert ok2 is False - # Missing repo + # Test 3: Missing repo ctx3 = SagaContext("s3", "e3") ctx3.set("execution_id", "e3") - ok3 = await AllocateResourcesStep(alloc_repo=None).execute(ctx3, _req()) + ok3 = await AllocateResourcesStep(alloc_repo=None).execute(ctx3, _req(execution_id="e3")) assert ok3 is False @@ -109,80 +102,105 @@ def set(self, key: str, value: object) -> None: assert await MonitorExecutionStep().execute(bad, _req()) is False -class _FakeProducer(UnifiedProducer): - """Fake UnifiedProducer for testing.""" - - def __init__(self) -> None: - self.events: list[DomainEvent] = [] - - async def produce(self, event_to_produce: DomainEvent, key: str | None = None, - headers: dict[str, str] | None = None) -> None: - self.events.append(event_to_produce) - - @pytest.mark.asyncio -async def test_create_pod_step_publish_flag_and_compensation() -> None: - ctx = SagaContext("s1", "e1") - ctx.set("execution_id", "e1") +async def 
test_create_pod_step_publish_flag_and_compensation(unit_container: AsyncContainer) -> None: + producer = await unit_container.get(UnifiedProducer) + kafka_producer = await unit_container.get(AIOKafkaProducer) + # Skip publish path + ctx = SagaContext("s1", "skip-publish-test") + ctx.set("execution_id", "skip-publish-test") s1 = CreatePodStep(producer=None, publish_commands=False) - ok1 = await s1.execute(ctx, _req()) + ok1 = await s1.execute(ctx, _req(execution_id="skip-publish-test")) assert ok1 is True and ctx.get("pod_creation_triggered") is False # Publish path succeeds - ctx2 = SagaContext("s2", "e2") - ctx2.set("execution_id", "e2") - prod = _FakeProducer() - s2 = CreatePodStep(producer=prod, publish_commands=True) - ok2 = await s2.execute(ctx2, _req()) - assert ok2 is True and ctx2.get("pod_creation_triggered") is True and prod.events + ctx2 = SagaContext("s2", "publish-test") + ctx2.set("execution_id", "publish-test") + initial_count = len(kafka_producer.sent_messages) + s2 = CreatePodStep(producer=producer, publish_commands=True) + ok2 = await s2.execute(ctx2, _req(execution_id="publish-test")) + assert ok2 is True + assert ctx2.get("pod_creation_triggered") is True + assert len(kafka_producer.sent_messages) > initial_count # Missing producer -> failure - ctx3 = SagaContext("s3", "e3") - ctx3.set("execution_id", "e3") + ctx3 = SagaContext("s3", "missing-producer-test") + ctx3.set("execution_id", "missing-producer-test") s3 = CreatePodStep(producer=None, publish_commands=True) - ok3 = await s3.execute(ctx3, _req()) + ok3 = await s3.execute(ctx3, _req(execution_id="missing-producer-test")) assert ok3 is False and ctx3.error is not None # DeletePod compensation triggers only when flagged and producer exists - comp = DeletePodCompensation(producer=prod) + comp = DeletePodCompensation(producer=producer) ctx2.set("pod_creation_triggered", True) assert await comp.compensate(ctx2) is True @pytest.mark.asyncio -async def test_release_resources_compensation() -> None: - repo = _FakeAllocRepo() - comp = ReleaseResourcesCompensation(alloc_repo=repo) - ctx = SagaContext("s1", "e1") - ctx.set("allocation_id", "alloc-1") - assert await comp.compensate(ctx) is True and repo.released == ["alloc-1"] +async def test_release_resources_compensation(unit_container: AsyncContainer) -> None: + alloc_repo = await unit_container.get(ResourceAllocationRepository) + + # Create an allocation via repo + from app.domain.saga import DomainResourceAllocationCreate + + create_data = DomainResourceAllocationCreate( + execution_id="release-comp-test", + language="python", + cpu_request="100m", + memory_request="128Mi", + cpu_limit="500m", + memory_limit="512Mi", + ) + allocation = await alloc_repo.create_allocation(create_data) + + # Verify allocation was created with status="active" + doc = await ResourceAllocationDocument.find_one( + ResourceAllocationDocument.allocation_id == allocation.allocation_id + ) + assert doc is not None + assert doc.status == "active" + + # Release via compensation + comp = ReleaseResourcesCompensation(alloc_repo=alloc_repo) + ctx = SagaContext("s1", "release-comp-test") + ctx.set("allocation_id", allocation.allocation_id) + assert await comp.compensate(ctx) is True + + # Verify allocation was released + doc_after = await ResourceAllocationDocument.find_one( + ResourceAllocationDocument.allocation_id == allocation.allocation_id + ) + assert doc_after is not None + assert doc_after.status == "released" # Missing repo -> failure comp2 = ReleaseResourcesCompensation(alloc_repo=None) assert 
await comp2.compensate(ctx) is False + # Missing allocation_id -> True short-circuit ctx2 = SagaContext("sX", "eX") - assert await ReleaseResourcesCompensation(alloc_repo=repo).compensate(ctx2) is True + assert await ReleaseResourcesCompensation(alloc_repo=alloc_repo).compensate(ctx2) is True @pytest.mark.asyncio -async def test_delete_pod_compensation_variants() -> None: +async def test_delete_pod_compensation_variants(unit_container: AsyncContainer) -> None: # Not triggered -> True early comp_none = DeletePodCompensation(producer=None) - ctx = SagaContext("s", "e") + ctx = SagaContext("s", "delete-pod-test-1") ctx.set("pod_creation_triggered", False) assert await comp_none.compensate(ctx) is True # Triggered but missing producer -> False - ctx2 = SagaContext("s2", "e2") + ctx2 = SagaContext("s2", "delete-pod-test-2") ctx2.set("pod_creation_triggered", True) - ctx2.set("execution_id", "e2") + ctx2.set("execution_id", "delete-pod-test-2") assert await comp_none.compensate(ctx2) is False # Exercise get_compensation methods return types (coverage for lines returning comps/None) + alloc_repo = await unit_container.get(ResourceAllocationRepository) assert ValidateExecutionStep().get_compensation() is None - assert isinstance(AllocateResourcesStep(_FakeAllocRepo()).get_compensation(), ReleaseResourcesCompensation) + assert isinstance(AllocateResourcesStep(alloc_repo).get_compensation(), ReleaseResourcesCompensation) assert isinstance(QueueExecutionStep().get_compensation(), type(DeletePodCompensation(None)).__bases__[0]) or True assert CreatePodStep(None, publish_commands=False).get_compensation() is not None assert MonitorExecutionStep().get_compensation() is None diff --git a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py index 848cc21d..198f02ad 100644 --- a/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py +++ b/backend/tests/unit/services/saga/test_saga_orchestrator_unit.py @@ -6,12 +6,13 @@ from app.db.repositories.saga_repository import SagaRepository from app.domain.enums.events import EventType from app.domain.enums.saga import SagaState -from app.domain.events.typed import DomainEvent, ExecutionRequestedEvent +from app.domain.events.typed import ExecutionRequestedEvent from app.domain.saga.models import Saga, SagaConfig from app.events.core import UnifiedProducer from app.services.saga.base_saga import BaseSaga from app.services.saga.saga_logic import SagaLogic from app.services.saga.saga_step import CompensationStep, SagaContext, SagaStep +from dishka import AsyncContainer from tests.helpers import make_execution_requested_event @@ -20,40 +21,6 @@ _test_logger = logging.getLogger("test.services.saga.orchestrator") -class _FakeRepo(SagaRepository): - """Fake SagaRepository for testing.""" - - def __init__(self) -> None: - self.saved: list[Saga] = [] - self.existing: dict[tuple[str, str], Saga] = {} - - async def get_saga_by_execution_and_name(self, execution_id: str, saga_name: str) -> Saga | None: - return self.existing.get((execution_id, saga_name)) - - async def upsert_saga(self, saga: Saga) -> bool: - self.saved.append(saga) - return True - - -class _FakeProd(UnifiedProducer): - """Fake UnifiedProducer for testing.""" - - def __init__(self) -> None: - pass # Skip parent __init__ - - async def produce( - self, event_to_produce: DomainEvent, key: str | None = None, headers: dict[str, str] | None = None - ) -> None: - return None - - -class _FakeAlloc(ResourceAllocationRepository): - 
"""Fake ResourceAllocationRepository for testing.""" - - def __init__(self) -> None: - pass # No special attributes needed - - class _StepOK(SagaStep[ExecutionRequestedEvent]): def __init__(self) -> None: super().__init__("ok") @@ -78,42 +45,63 @@ def get_steps(self) -> list[SagaStep[ExecutionRequestedEvent]]: return [_StepOK()] -def _logic(event_metrics: EventMetrics) -> SagaLogic: - return SagaLogic( +@pytest.mark.asyncio +async def test_min_success_flow( + unit_container: AsyncContainer, + event_metrics: EventMetrics, +) -> None: + saga_repo = await unit_container.get(SagaRepository) + producer = await unit_container.get(UnifiedProducer) + alloc_repo = await unit_container.get(ResourceAllocationRepository) + + logic = SagaLogic( config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), - saga_repository=_FakeRepo(), - producer=_FakeProd(), - resource_allocation_repository=_FakeAlloc(), + saga_repository=saga_repo, + producer=producer, + resource_allocation_repository=alloc_repo, logger=_test_logger, event_metrics=event_metrics, ) - - -@pytest.mark.asyncio -async def test_min_success_flow(event_metrics: EventMetrics) -> None: - logic = _logic(event_metrics) logic.register_saga(_Saga) - # Handle the event + await logic.handle_event(make_execution_requested_event(execution_id="e")) - # basic sanity; deep behavior covered by integration + assert len(logic._sagas) > 0 # noqa: SLF001 @pytest.mark.asyncio -async def test_should_trigger_and_existing_short_circuit(event_metrics: EventMetrics) -> None: - fake_repo = _FakeRepo() +async def test_should_trigger_and_existing_short_circuit( + unit_container: AsyncContainer, + event_metrics: EventMetrics, +) -> None: + saga_repo = await unit_container.get(SagaRepository) + producer = await unit_container.get(UnifiedProducer) + alloc_repo = await unit_container.get(ResourceAllocationRepository) + logic = SagaLogic( config=SagaConfig(name="t", enable_compensation=True, store_events=True, publish_commands=False), - saga_repository=fake_repo, - producer=_FakeProd(), - resource_allocation_repository=_FakeAlloc(), + saga_repository=saga_repo, + producer=producer, + resource_allocation_repository=alloc_repo, logger=_test_logger, event_metrics=event_metrics, ) logic.register_saga(_Saga) - assert logic._should_trigger_saga(_Saga, make_execution_requested_event(execution_id="e")) is True # noqa: SLF001 + + # Use unique execution_id to avoid conflicts with other tests + exec_id = "test-short-circuit-exec" + + assert logic._should_trigger_saga(_Saga, make_execution_requested_event(execution_id=exec_id)) is True # noqa: SLF001 + + # Create existing saga in real repo + existing_saga = Saga(saga_id="sX", saga_name="s", execution_id=exec_id, state=SagaState.RUNNING) + await saga_repo.upsert_saga(existing_saga) + + # Verify it was saved correctly + found = await saga_repo.get_saga_by_execution_and_name(exec_id, "s") + assert found is not None, "Saga should be found after upsert" + assert found.saga_id == "sX" + # Existing short-circuit returns existing ID - s = Saga(saga_id="sX", saga_name="s", execution_id="e", state=SagaState.RUNNING) - fake_repo.existing[("e", "s")] = s - sid = await logic._start_saga("s", make_execution_requested_event(execution_id="e")) # noqa: SLF001 + sid = await logic._start_saga("s", make_execution_requested_event(execution_id=exec_id)) # noqa: SLF001 assert sid == "sX" diff --git a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py b/backend/tests/unit/services/sse/test_kafka_redis_bridge.py 
deleted file mode 100644 index a1204957..00000000 --- a/backend/tests/unit/services/sse/test_kafka_redis_bridge.py +++ /dev/null @@ -1,50 +0,0 @@ -import logging - -import pytest -from app.domain.enums.events import EventType -from app.domain.events.typed import DomainEvent, EventMetadata, ExecutionStartedEvent -from app.events.core import EventDispatcher -from app.services.sse.event_router import SSEEventRouter -from app.services.sse.redis_bus import SSERedisBus - -pytestmark = pytest.mark.unit - -_test_logger = logging.getLogger("test.services.sse.event_router") - - -class _FakeBus(SSERedisBus): - """Fake SSERedisBus for testing.""" - - def __init__(self) -> None: - self.published: list[tuple[str, DomainEvent]] = [] - - async def publish_event(self, execution_id: str, event: DomainEvent) -> None: - self.published.append((execution_id, event)) - - -def _make_metadata() -> EventMetadata: - return EventMetadata(service_name="test", service_version="1.0") - - -@pytest.mark.asyncio -async def test_event_router_registers_and_routes_events() -> None: - """Test that SSEEventRouter registers handlers and routes events to Redis.""" - fake_bus = _FakeBus() - router = SSEEventRouter(sse_bus=fake_bus, logger=_test_logger) - - # Register handlers with dispatcher - disp = EventDispatcher(_test_logger) - router.register_handlers(disp) - - # Verify handler was registered - handlers = disp.get_handlers(EventType.EXECUTION_STARTED) - assert len(handlers) > 0 - - # Event with empty execution_id is ignored - h = handlers[0] - await h(ExecutionStartedEvent(execution_id="", pod_name="p", metadata=_make_metadata())) - assert fake_bus.published == [] - - # Proper event is published - await h(ExecutionStartedEvent(execution_id="exec-123", pod_name="p", metadata=_make_metadata())) - assert fake_bus.published and fake_bus.published[-1][0] == "exec-123" diff --git a/backend/tests/unit/services/sse/test_sse_service.py b/backend/tests/unit/services/sse/test_sse_service.py index 4174ee57..35725ce2 100644 --- a/backend/tests/unit/services/sse/test_sse_service.py +++ b/backend/tests/unit/services/sse/test_sse_service.py @@ -115,7 +115,7 @@ async def test_execution_stream_closes_on_failed_event(connection_metrics: Conne repo = _FakeRepo() bus = _FakeBus() registry = _FakeRegistry() - svc = SSEService(repository=repo, num_consumers=3, sse_bus=bus, connection_registry=registry, + svc = SSEService(repository=repo, sse_bus=bus, connection_registry=registry, settings=_make_fake_settings(), logger=_test_logger, connection_metrics=connection_metrics) agen = svc.create_execution_stream("exec-1", user_id="u1") @@ -159,7 +159,7 @@ async def test_execution_stream_result_stored_includes_result_payload(connection ) bus = _FakeBus() registry = _FakeRegistry() - svc = SSEService(repository=repo, num_consumers=3, sse_bus=bus, connection_registry=registry, + svc = SSEService(repository=repo, sse_bus=bus, connection_registry=registry, settings=_make_fake_settings(), logger=_test_logger, connection_metrics=connection_metrics) agen = svc.create_execution_stream("exec-2", user_id="u1") @@ -184,7 +184,7 @@ async def test_notification_stream_connected_and_heartbeat_and_message(connectio registry = _FakeRegistry() settings = _make_fake_settings() settings.SSE_HEARTBEAT_INTERVAL = 0 # emit immediately - svc = SSEService(repository=repo, num_consumers=3, sse_bus=bus, connection_registry=registry, settings=settings, + svc = SSEService(repository=repo, sse_bus=bus, connection_registry=registry, settings=settings, logger=_test_logger, 
connection_metrics=connection_metrics) agen = svc.create_notification_stream("u1") @@ -221,10 +221,11 @@ async def test_notification_stream_connected_and_heartbeat_and_message(connectio async def test_health_status_shape(connection_metrics: ConnectionMetrics) -> None: # Create registry with 2 active connections and 2 executions for testing registry = _FakeRegistry(active_connections=2, active_executions=2) - svc = SSEService(repository=_FakeRepo(), num_consumers=3, sse_bus=_FakeBus(), connection_registry=registry, + svc = SSEService(repository=_FakeRepo(), sse_bus=_FakeBus(), connection_registry=registry, settings=_make_fake_settings(), logger=_test_logger, connection_metrics=connection_metrics) h = await svc.get_health_status() assert isinstance(h, SSEHealthDomain) - assert h.active_consumers == 3 + # Consumers now run in separate SSE bridge worker + assert h.active_consumers == 0 assert h.active_connections == 2 assert h.active_executions == 2 diff --git a/backend/uv.lock b/backend/uv.lock index b058c135..8cee706e 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -687,6 +687,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, ] +[[package]] +name = "fakeredis" +version = "2.33.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "redis" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5f/f9/57464119936414d60697fcbd32f38909bb5688b616ae13de6e98384433e0/fakeredis-2.33.0.tar.gz", hash = "sha256:d7bc9a69d21df108a6451bbffee23b3eba432c21a654afc7ff2d295428ec5770", size = 175187, upload-time = "2025-12-16T19:45:52.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/78/a850fed8aeef96d4a99043c90b818b2ed5419cd5b24a4049fd7cfb9f1471/fakeredis-2.33.0-py3-none-any.whl", hash = "sha256:de535f3f9ccde1c56672ab2fdd6a8efbc4f2619fc2f1acc87b8737177d71c965", size = 119605, upload-time = "2025-12-16T19:45:51.08Z" }, +] + [[package]] name = "fast-depends" version = "3.0.5" @@ -1222,9 +1235,11 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "coverage" }, + { name = "fakeredis" }, { name = "hypothesis" }, { name = "iniconfig" }, { name = "matplotlib" }, + { name = "mongomock-motor" }, { name = "mypy" }, { name = "mypy-extensions" }, { name = "pipdeptree" }, @@ -1366,9 +1381,11 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "coverage", specifier = "==7.13.0" }, + { name = "fakeredis", specifier = ">=2.33.0" }, { name = "hypothesis", specifier = "==6.103.4" }, { name = "iniconfig", specifier = "==2.0.0" }, { name = "matplotlib", specifier = "==3.10.8" }, + { name = "mongomock-motor", specifier = ">=0.0.36" }, { name = "mypy", specifier = "==1.17.1" }, { name = "mypy-extensions", specifier = "==1.1.0" }, { name = "pipdeptree", specifier = "==2.23.4" }, @@ -1678,6 +1695,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/77/2f37358731fdf228379fb9bdc0c736371691c0493075c393ee431e42b908/monggregate-0.22.1-py3-none-any.whl", hash = "sha256:4eef7839109ce4b1bb1172b6643fa22e2dc284a45e645ea55fd4efd848aedfb2", size = 169108, upload-time = "2025-08-24T15:00:55.959Z" }, ] +[[package]] +name = "mongomock" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { 
name = "pytz" }, + { name = "sentinels" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4d/a4/4a560a9f2a0bec43d5f63104f55bc48666d619ca74825c8ae156b08547cf/mongomock-4.3.0.tar.gz", hash = "sha256:32667b79066fabc12d4f17f16a8fd7361b5f4435208b3ba32c226e52212a8c30", size = 135862, upload-time = "2024-11-16T11:23:25.957Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/4d/8bea712978e3aff017a2ab50f262c620e9239cc36f348aae45e48d6a4786/mongomock-4.3.0-py2.py3-none-any.whl", hash = "sha256:5ef86bd12fc8806c6e7af32f21266c61b6c4ba96096f85129852d1c4fec1327e", size = 64891, upload-time = "2024-11-16T11:23:24.748Z" }, +] + +[[package]] +name = "mongomock-motor" +version = "0.0.36" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mongomock" }, + { name = "motor" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/18/9f/38e42a34ebad323addaf6296d6b5d83eaf2c423adf206b757c68315e196a/mongomock_motor-0.0.36.tar.gz", hash = "sha256:3cf62352ece5af2f02e04d2f252393f88b5fe0487997da00584020cee4b8efba", size = 5754, upload-time = "2025-05-16T22:52:27.214Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d6/99/f5fdbbdc96bfd03e5f9c36339547a9076f5dbb5882900b7621526d41a38d/mongomock_motor-0.0.36-py3-none-any.whl", hash = "sha256:3ecb7949662b8986ff9c267fa0b1402b5b75a6afd57f03850cd6e13a067e3691", size = 7334, upload-time = "2025-05-16T22:52:25.417Z" }, +] + +[[package]] +name = "motor" +version = "3.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pymongo" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/93/ae/96b88362d6a84cb372f7977750ac2a8aed7b2053eed260615df08d5c84f4/motor-3.7.1.tar.gz", hash = "sha256:27b4d46625c87928f331a6ca9d7c51c2f518ba0e270939d395bc1ddc89d64526", size = 280997, upload-time = "2025-05-14T18:56:33.653Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/01/9a/35e053d4f442addf751ed20e0e922476508ee580786546d699b0567c4c67/motor-3.7.1-py3-none-any.whl", hash = "sha256:8a63b9049e38eeeb56b4fdd57c3312a6d1f25d01db717fe7d82222393c410298", size = 74996, upload-time = "2025-05-14T18:56:31.665Z" }, +] + [[package]] name = "msgpack" version = "1.1.0" @@ -2695,6 +2751,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/c1/abd18fc3c23dbe09321fcd812091320d4dc954046f95cb431ef2926cb11c/python_schema_registry_client-2.6.1-py3-none-any.whl", hash = "sha256:05950ca8f9a3409247514bef3fdb421839d6e1ae544b32dfd3b7b16237673303", size = 23095, upload-time = "2025-04-04T15:07:49.592Z" }, ] +[[package]] +name = "pytz" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, +] + [[package]] name = "pyyaml" version = "6.0.2" @@ -2998,6 +3063,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/74/31/b0e29d572670dca3674eeee78e418f20bdf97fa8aa9ea71380885e175ca0/ruff-0.14.10-py3-none-win_arm64.whl", hash = 
"sha256:e51d046cf6dda98a4633b8a8a771451107413b0f07183b2bef03f075599e44e6", size = 13729839, upload-time = "2025-12-18T19:28:48.636Z" }, ] +[[package]] +name = "sentinels" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/9b/07195878aa25fe6ed209ec74bc55ae3e3d263b60a489c6e73fdca3c8fe05/sentinels-1.1.1.tar.gz", hash = "sha256:3c2f64f754187c19e0a1a029b148b74cf58dd12ec27b4e19c0e5d6e22b5a9a86", size = 4393, upload-time = "2025-08-12T07:57:50.26Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/65/dea992c6a97074f6d8ff9eab34741298cac2ce23e2b6c74fb7d08afdf85c/sentinels-1.1.1-py3-none-any.whl", hash = "sha256:835d3b28f3b47f5284afa4bf2db6e00f2dc5f80f9923d4b7e7aeeeccf6146a11", size = 3744, upload-time = "2025-08-12T07:57:48.858Z" }, +] + [[package]] name = "setuptools" version = "80.9.0" diff --git a/backend/workers/run_event_replay.py b/backend/workers/run_event_replay.py index 95c38dad..51bfbaa1 100644 --- a/backend/workers/run_event_replay.py +++ b/backend/workers/run_event_replay.py @@ -34,14 +34,15 @@ async def run_replay_service(settings: Settings) -> None: db = await container.get(Database) await init_beanie(database=db, document_models=ALL_DOCUMENTS) - producer = await container.get(UnifiedProducer) + # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) + await container.get(UnifiedProducer) replay_service = await container.get(EventReplayService) logger.info("Event replay service initialized") async with AsyncExitStack() as stack: + # Container close stops Kafka producer via DI provider stack.push_async_callback(container.close) - await stack.enter_async_context(producer) task = asyncio.create_task(cleanup_task(replay_service, logger)) diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 9037ec42..7eb8a63e 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -18,12 +18,13 @@ from app.core.logging import setup_logger from app.core.providers import ( + BoundaryClientProvider, EventProvider, K8sWorkerProvider, LoggingProvider, MessagingProvider, MetricsProvider, - RedisProvider, + RedisServicesProvider, SettingsProvider, ) from app.core.tracing import init_tracing @@ -34,6 +35,7 @@ DeletePodCommandEvent, DomainEvent, ) +from app.events.core import UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.k8s_worker.worker_logic import K8sWorkerLogic @@ -74,7 +76,8 @@ def main() -> None: container = make_async_container( SettingsProvider(), LoggingProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), MetricsProvider(), EventProvider(), MessagingProvider(), @@ -106,6 +109,10 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) + # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) + await container.get(UnifiedProducer) + app_logger.info("Kafka producer ready") + # Get worker logic and ensure daemonset (one-time initialization) logic = await container.get(K8sWorkerLogic) await logic.ensure_image_pre_puller_daemonset() @@ -128,14 +135,14 @@ async def handle_create_pod_command( event: CreatePodCommandEvent, worker_logic: FromDishka[K8sWorkerLogic], ) -> None: - await 
worker_logic._handle_create_pod_command(event) + await worker_logic.handle_create_pod_command(event) @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.DELETE_POD_COMMAND.encode()) async def handle_delete_pod_command( event: DeletePodCommandEvent, worker_logic: FromDishka[K8sWorkerLogic], ) -> None: - await worker_logic._handle_delete_pod_command(event) + await worker_logic.handle_delete_pod_command(event) # Default handler for unmatched events (prevents message loss) @subscriber() @@ -144,12 +151,15 @@ async def handle_other(event: DomainEvent) -> None: app_logger.info("Infrastructure initialized, starting event processing...") - yield - - # Graceful shutdown: wait for active pod creations - app_logger.info("KubernetesWorker shutting down...") - await logic.wait_for_active_creations() - await container.close() + try: + yield + finally: + # Graceful shutdown: wait for active pod creations + app_logger.info("KubernetesWorker shutting down...") + await logic.wait_for_active_creations() + # Container close stops Kafka producer via DI provider + await container.close() + app_logger.info("KubernetesWorker shutdown complete") # Create FastStream app app = FastStream(broker, lifespan=lifespan) diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index 9675755b..cf436bbc 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -15,20 +15,21 @@ from app.core.database_context import Database from app.core.logging import setup_logger from app.core.providers import ( + BoundaryClientProvider, DatabaseProvider, EventProvider, - KubernetesProvider, LoggingProvider, MessagingProvider, MetricsProvider, PodMonitorProvider, - RedisProvider, + RedisServicesProvider, RepositoryProvider, SettingsProvider, ) from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId +from app.events.core import UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.pod_monitor.monitor import PodMonitor from app.settings import Settings @@ -41,13 +42,13 @@ async def run_pod_monitor(settings: Settings) -> None: container = make_async_container( SettingsProvider(), LoggingProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), DatabaseProvider(), MetricsProvider(), EventProvider(), MessagingProvider(), RepositoryProvider(), - KubernetesProvider(), PodMonitorProvider(), context={Settings: settings}, ) @@ -61,6 +62,10 @@ async def run_pod_monitor(settings: Settings) -> None: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) + # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) + await container.get(UnifiedProducer) + logger.info("Kafka producer ready") + monitor = await container.get(PodMonitor) # Signal handling with minimal boilerplate @@ -88,7 +93,9 @@ async def run_pod_monitor(settings: Settings) -> None: finally: logger.info("Initiating graceful shutdown...") + # Container close stops Kafka producer via DI provider await container.close() + logger.info("PodMonitor shutdown complete") def main() -> None: diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index 9b483684..ad2ec5b2 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -19,12 +19,13 @@ from app.core.database_context import Database 
from app.core.logging import setup_logger from app.core.providers import ( + BoundaryClientProvider, DatabaseProvider, EventProvider, LoggingProvider, MessagingProvider, MetricsProvider, - RedisProvider, + RedisServicesProvider, RepositoryProvider, ResultProcessorProvider, SettingsProvider, @@ -39,6 +40,7 @@ ExecutionFailedEvent, ExecutionTimeoutEvent, ) +from app.events.core import UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.result_processor.processor_logic import ProcessorLogic @@ -80,7 +82,8 @@ def main() -> None: container = make_async_container( SettingsProvider(), LoggingProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), DatabaseProvider(), MetricsProvider(), EventProvider(), @@ -118,6 +121,10 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) + # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) + await container.get(UnifiedProducer) + app_logger.info("Kafka producer ready") + # Decoder: Avro bytes → typed DomainEvent async def decode_avro(body: bytes) -> DomainEvent: return await schema_registry.deserialize_event(body, "result_processor") @@ -159,10 +166,13 @@ async def handle_other(event: DomainEvent) -> None: app_logger.info("Infrastructure initialized, starting event processing...") - yield - - app_logger.info("ResultProcessor shutting down...") - await container.close() + try: + yield + finally: + app_logger.info("ResultProcessor shutting down...") + # Container close stops Kafka producer via DI provider + await container.close() + app_logger.info("ResultProcessor shutdown complete") # Create FastStream app app = FastStream(broker, lifespan=lifespan) diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 80a42240..2b3dec13 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -20,12 +20,13 @@ from app.core.database_context import Database from app.core.logging import setup_logger from app.core.providers import ( + BoundaryClientProvider, DatabaseProvider, EventProvider, LoggingProvider, MessagingProvider, MetricsProvider, - RedisProvider, + RedisServicesProvider, RepositoryProvider, SagaOrchestratorProvider, SettingsProvider, @@ -34,6 +35,7 @@ from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId from app.domain.events.typed import DomainEvent +from app.events.core import UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.saga.saga_logic import SagaLogic @@ -75,7 +77,8 @@ def main() -> None: container = make_async_container( SettingsProvider(), LoggingProvider(), - RedisProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), DatabaseProvider(), MetricsProvider(), EventProvider(), @@ -112,6 +115,10 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: schema_registry = await container.get(SchemaRegistryManager) await initialize_event_schemas(schema_registry) + # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) + await container.get(UnifiedProducer) + app_logger.info("Kafka producer ready") + # Get 
saga logic to determine topics logic = await container.get(SagaLogic) trigger_topics = logic.get_trigger_topics() @@ -160,10 +167,13 @@ async def handle_saga_event( app_logger.info(f"Subscribing to topics: {topics}") app_logger.info("Infrastructure initialized, starting event processing...") - yield - - app_logger.info("SagaOrchestrator shutting down...") - await container.close() + try: + yield + finally: + app_logger.info("SagaOrchestrator shutting down...") + # Container close stops Kafka producer via DI provider + await container.close() + app_logger.info("SagaOrchestrator shutdown complete") # Create FastStream app app = FastStream(broker, lifespan=lifespan) diff --git a/backend/workers/run_sse_bridge.py b/backend/workers/run_sse_bridge.py new file mode 100644 index 00000000..73f8263e --- /dev/null +++ b/backend/workers/run_sse_bridge.py @@ -0,0 +1,137 @@ +import asyncio +import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import redis.asyncio as redis +from app.core.logging import setup_logger +from app.core.providers import ( + BoundaryClientProvider, + EventProvider, + LoggingProvider, + MetricsProvider, + RedisServicesProvider, + SettingsProvider, +) +from app.core.tracing import init_tracing +from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId +from app.domain.events.typed import DomainEvent +from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas +from app.services.sse.event_router import SSE_RELEVANT_EVENTS, SSEEventRouter +from app.services.sse.redis_bus import SSERedisBus +from app.settings import Settings +from dishka import Provider, Scope, make_async_container, provide +from dishka.integrations.faststream import FromDishka, setup_dishka +from faststream import FastStream +from faststream.kafka import KafkaBroker + + +class SSEBridgeProvider(Provider): + """Provides SSE bridge specific dependencies.""" + + scope = Scope.APP + + @provide + def get_sse_event_router( + self, + sse_redis_bus: SSERedisBus, + logger: logging.Logger, + ) -> SSEEventRouter: + return SSEEventRouter(sse_bus=sse_redis_bus, logger=logger) + + @provide + def get_sse_redis_bus( + self, + redis_client: redis.Redis, + logger: logging.Logger, + ) -> SSERedisBus: + return SSERedisBus(redis_client, logger) + + +def main() -> None: + """Entry point for SSE bridge worker.""" + settings = Settings() + + logger = setup_logger(settings.LOG_LEVEL) + logger.info("Starting SSE Bridge (FastStream)...") + + if settings.ENABLE_TRACING: + init_tracing( + service_name="sse-bridge", + settings=settings, + logger=logger, + service_version=settings.TRACING_SERVICE_VERSION, + enable_console_exporter=False, + sampling_rate=settings.TRACING_SAMPLING_RATE, + ) + + # DI container - no database needed, just Redis and Kafka + container = make_async_container( + SettingsProvider(), + LoggingProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), + MetricsProvider(), + EventProvider(), + SSEBridgeProvider(), + context={Settings: settings}, + ) + + # Topics from config + topics = [ + f"{settings.KAFKA_TOPIC_PREFIX}{t}" + for t in CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.WEBSOCKET_GATEWAY] + ] + group_id = f"{GroupId.WEBSOCKET_GATEWAY}.{settings.KAFKA_GROUP_SUFFIX}" + + broker = KafkaBroker( + settings.KAFKA_BOOTSTRAP_SERVERS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) + + @asynccontextmanager + async def lifespan(app: FastStream) -> AsyncIterator[None]: + app_logger = await 
container.get(logging.Logger) + app_logger.info("SSE Bridge starting...") + + # Initialize schema registry + schema_registry = await container.get(SchemaRegistryManager) + await initialize_event_schemas(schema_registry) + + # Decoder: Avro bytes → typed DomainEvent + async def decode_avro(body: bytes) -> DomainEvent: + return await schema_registry.deserialize_event(body, "sse_bridge") + + # Single handler for all SSE-relevant events + # No filter needed - we check event_type in handler since route_event handles all types + @broker.subscriber( + *topics, + group_id=group_id, + auto_commit=True, # SSE bridge is idempotent (Redis pubsub) + decoder=decode_avro, + ) + async def handle_sse_event( + event: DomainEvent, + router: FromDishka[SSEEventRouter], + ) -> None: + """Route domain events to Redis for SSE delivery.""" + if event.event_type in SSE_RELEVANT_EVENTS: + await router.route_event(event) + + app_logger.info(f"Subscribing to topics: {topics}") + app_logger.info("SSE Bridge ready") + + try: + yield + finally: + app_logger.info("SSE Bridge shutting down...") + await container.close() + + app = FastStream(broker, lifespan=lifespan) + setup_dishka(container=container, app=app, auto_inject=True) + + asyncio.run(app.run()) + + +if __name__ == "__main__": + main() From a994a84bc9faf8e85cee674c6b9d919c850caaf8 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Wed, 21 Jan 2026 16:51:41 +0100 Subject: [PATCH 10/21] moved all inits to providers, simplified all other stuff --- backend/app/core/dishka_lifespan.py | 105 +------------------ backend/app/core/providers.py | 20 +++- backend/app/core/startup.py | 42 -------- backend/app/events/schema/schema_registry.py | 5 - backend/workers/run_k8s_worker.py | 5 +- backend/workers/run_pod_monitor.py | 12 +-- backend/workers/run_result_processor.py | 14 +-- backend/workers/run_saga_orchestrator.py | 14 +-- backend/workers/run_sse_bridge.py | 5 +- 9 files changed, 34 insertions(+), 188 deletions(-) delete mode 100644 backend/app/core/startup.py diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index b186eba7..6ede442c 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -1,110 +1,11 @@ -import asyncio -import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -import redis.asyncio as redis -from beanie import init_beanie -from dishka import AsyncContainer from fastapi import FastAPI -from app.core.database_context import Database -from app.core.metrics import RateLimitMetrics -from app.core.startup import initialize_rate_limits -from app.core.tracing import init_tracing -from app.db.docs import ALL_DOCUMENTS -from app.events.core import UnifiedProducer -from app.events.event_store_consumer import EventStoreConsumer -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas -from app.services.event_bus import EventBus -from app.services.notification_service import NotificationService -from app.settings import Settings - @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - """Application lifespan with dishka dependency injection. - - All service lifecycle (start/stop, background tasks) is managed by DI providers. - Lifespan only: - 1. Resolves dependencies (triggers provider lifecycle setup) - 2. Initializes schemas and beanie - 3. On shutdown, container cleanup handles everything - - Note: SSE Kafka consumers are now in the separate SSE bridge worker (run_sse_bridge.py). 
- """ - container: AsyncContainer = app.state.dishka_container - settings = await container.get(Settings) - logger = await container.get(logging.Logger) - - logger.info( - "Starting application with dishka DI", - extra={ - "project_name": settings.PROJECT_NAME, - "environment": "test" if settings.TESTING else "production", - }, - ) - - # Initialize tracing only when enabled - if settings.ENABLE_TRACING and not settings.TESTING: - instrumentation_report = init_tracing( - service_name=settings.TRACING_SERVICE_NAME, - settings=settings, - logger=logger, - service_version=settings.TRACING_SERVICE_VERSION, - sampling_rate=settings.TRACING_SAMPLING_RATE, - enable_console_exporter=settings.TESTING, - adaptive_sampling=settings.TRACING_ADAPTIVE_SAMPLING, - ) - - if instrumentation_report.has_failures(): - logger.warning( - "Some instrumentation libraries failed to initialize", - extra={"instrumentation_summary": instrumentation_report.get_summary()}, - ) - else: - logger.info( - "Distributed tracing initialized successfully", - extra={"instrumentation_summary": instrumentation_report.get_summary()}, - ) - else: - logger.info( - "Distributed tracing disabled", - extra={"testing": settings.TESTING, "enable_tracing": settings.ENABLE_TRACING}, - ) - - # Phase 1: Resolve all DI dependencies in parallel - # This triggers async generator providers which start services and background tasks - ( - schema_registry, - database, - redis_client, - rate_limit_metrics, - _event_store_consumer, - _notification_service, - _kafka_producer, - _event_bus, - ) = await asyncio.gather( - container.get(SchemaRegistryManager), - container.get(Database), - container.get(redis.Redis), - container.get(RateLimitMetrics), - container.get(EventStoreConsumer), - container.get(NotificationService), - container.get(UnifiedProducer), - container.get(EventBus), - ) - - # Phase 2: Initialize infrastructure - await asyncio.gather( - initialize_event_schemas(schema_registry), - init_beanie(database=database, document_models=ALL_DOCUMENTS), - initialize_rate_limits(redis_client, settings, logger, rate_limit_metrics), - ) - logger.info("Application started - all services running") - - try: - yield - finally: - # Container cleanup handles all service shutdown via async generator cleanup - logger.info("Application shutting down - container cleanup will stop all services") + """Minimal lifespan - container.close() triggers all provider cleanup.""" + yield + await app.state.dishka_container.close() diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index c02c5e9b..7c17c181 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -4,6 +4,7 @@ import redis.asyncio as redis from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from beanie import init_beanie from dishka import Provider, Scope, from_context, provide from kubernetes import client as k8s_client from kubernetes import config as k8s_config @@ -27,6 +28,7 @@ ) from app.core.security import SecurityService from app.core.tracing import TracerManager +from app.db.docs import ALL_DOCUMENTS from app.db.repositories import ( EventRepository, ExecutionRepository, @@ -180,12 +182,14 @@ class DatabaseProvider(Provider): scope = Scope.APP @provide - def get_database(self, settings: Settings, logger: logging.Logger) -> Database: + async def get_database(self, settings: Settings, logger: logging.Logger) -> AsyncIterable[Database]: client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( settings.MONGODB_URL, tz_aware=True, 
serverSelectionTimeoutMS=5000 ) - logger.info(f"MongoDB configured: {settings.DATABASE_NAME}") - return client[settings.DATABASE_NAME] + db = client[settings.DATABASE_NAME] + await init_beanie(database=db, document_models=ALL_DOCUMENTS) + logger.info(f"MongoDB + Beanie initialized: {settings.DATABASE_NAME}") + yield db class CoreServicesProvider(Provider): @@ -316,8 +320,14 @@ class EventProvider(Provider): scope = Scope.APP @provide - def get_schema_registry(self, settings: Settings, logger: logging.Logger) -> SchemaRegistryManager: - return SchemaRegistryManager(settings, logger) + async def get_schema_registry( + self, settings: Settings, logger: logging.Logger + ) -> AsyncIterable[SchemaRegistryManager]: + """Provide SchemaRegistryManager with DI-managed initialization.""" + registry = SchemaRegistryManager(settings, logger) + await registry.initialize_schemas() + logger.info("Schema registry initialized") + yield registry @provide def get_event_store( diff --git a/backend/app/core/startup.py b/backend/app/core/startup.py deleted file mode 100644 index 549c3cb8..00000000 --- a/backend/app/core/startup.py +++ /dev/null @@ -1,42 +0,0 @@ -import logging - -import redis.asyncio as redis - -from app.core.metrics import RateLimitMetrics -from app.domain.rate_limit import RateLimitConfig -from app.services.rate_limit_service import RateLimitService -from app.settings import Settings - - -async def initialize_rate_limits( - redis_client: redis.Redis, - settings: Settings, - logger: logging.Logger, - rate_limit_metrics: RateLimitMetrics, -) -> None: - """ - Initialize default rate limits in Redis on application startup. - This ensures default limits are always available. - """ - try: - service = RateLimitService(redis_client, settings, rate_limit_metrics) - - # Check if config already exists - config_key = f"{settings.RATE_LIMIT_REDIS_PREFIX}config" - existing_config = await redis_client.get(config_key) - - if not existing_config: - logger.info("Initializing default rate limit configuration in Redis") - - # Get default config and save it - default_config = RateLimitConfig.get_default_config() - await service.update_config(default_config) - - logger.info(f"Initialized {len(default_config.default_rules)} default rate limit rules") - else: - logger.info("Rate limit configuration already exists in Redis") - - except Exception as e: - logger.error(f"Failed to initialize rate limits: {e}") - # Don't fail startup if rate limit init fails - # The service will use defaults if Redis is unavailable diff --git a/backend/app/events/schema/schema_registry.py b/backend/app/events/schema/schema_registry.py index a53306b6..e5d62c1a 100644 --- a/backend/app/events/schema/schema_registry.py +++ b/backend/app/events/schema/schema_registry.py @@ -146,8 +146,3 @@ async def initialize_schemas(self) -> None: await self.set_compatibility(subject, "FORWARD") await self._ensure_schema_registered(event_class) self.logger.info(f"Initialized {len(_get_all_event_classes())} event schemas") - - -async def initialize_event_schemas(registry: SchemaRegistryManager) -> None: - """Initialize all event schemas in the registry.""" - await registry.initialize_schemas() diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 7eb8a63e..f35baab1 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -36,7 +36,7 @@ DomainEvent, ) from app.events.core import UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas 
+from app.events.schema.schema_registry import SchemaRegistryManager from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.k8s_worker.worker_logic import K8sWorkerLogic from app.settings import Settings @@ -105,9 +105,8 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: app_logger = await container.get(logging.Logger) app_logger.info("KubernetesWorker starting...") - # Initialize schema registry + # Resolve schema registry (initialization handled by provider) schema_registry = await container.get(SchemaRegistryManager) - await initialize_event_schemas(schema_registry) # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) await container.get(UnifiedProducer) diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index cf436bbc..db090e38 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -27,13 +27,11 @@ SettingsProvider, ) from app.core.tracing import init_tracing -from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId from app.events.core import UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas +from app.events.schema.schema_registry import SchemaRegistryManager from app.services.pod_monitor.monitor import PodMonitor from app.settings import Settings -from beanie import init_beanie from dishka import make_async_container @@ -56,11 +54,9 @@ async def run_pod_monitor(settings: Settings) -> None: logger = await container.get(logging.Logger) logger.info("Starting PodMonitor with DI container...") - db = await container.get(Database) - await init_beanie(database=db, document_models=ALL_DOCUMENTS) - - schema_registry = await container.get(SchemaRegistryManager) - await initialize_event_schemas(schema_registry) + # Resolve dependencies (initialization handled by providers) + await container.get(Database) # Triggers init_beanie via DatabaseProvider + await container.get(SchemaRegistryManager) # Triggers initialize_schemas # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) await container.get(UnifiedProducer) diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index ad2ec5b2..29ed0fe6 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -31,7 +31,6 @@ SettingsProvider, ) from app.core.tracing import init_tracing -from app.db.docs import ALL_DOCUMENTS from app.domain.enums.events import EventType from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId from app.domain.events.typed import ( @@ -41,11 +40,10 @@ ExecutionTimeoutEvent, ) from app.events.core import UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas +from app.events.schema.schema_registry import SchemaRegistryManager from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.result_processor.processor_logic import ProcessorLogic from app.settings import Settings -from beanie import init_beanie from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream @@ -113,13 +111,9 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: app_logger = await container.get(logging.Logger) app_logger.info("ResultProcessor starting...") - # Initialize database - db = await 
container.get(Database) - await init_beanie(database=db, document_models=ALL_DOCUMENTS) - - # Initialize schema registry - schema_registry = await container.get(SchemaRegistryManager) - await initialize_event_schemas(schema_registry) + # Resolve dependencies (initialization handled by providers) + await container.get(Database) # Triggers init_beanie via DatabaseProvider + schema_registry = await container.get(SchemaRegistryManager) # Triggers initialize_schemas # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) await container.get(UnifiedProducer) diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 2b3dec13..662cda67 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -32,15 +32,13 @@ SettingsProvider, ) from app.core.tracing import init_tracing -from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId from app.domain.events.typed import DomainEvent from app.events.core import UnifiedProducer -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas +from app.events.schema.schema_registry import SchemaRegistryManager from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.saga.saga_logic import SagaLogic from app.settings import Settings -from beanie import init_beanie from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream @@ -107,13 +105,9 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: app_logger = await container.get(logging.Logger) app_logger.info("SagaOrchestrator starting...") - # Initialize database - db = await container.get(Database) - await init_beanie(database=db, document_models=ALL_DOCUMENTS) - - # Initialize schema registry - schema_registry = await container.get(SchemaRegistryManager) - await initialize_event_schemas(schema_registry) + # Resolve dependencies (initialization handled by providers) + await container.get(Database) # Triggers init_beanie via DatabaseProvider + schema_registry = await container.get(SchemaRegistryManager) # Triggers initialize_schemas # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) await container.get(UnifiedProducer) diff --git a/backend/workers/run_sse_bridge.py b/backend/workers/run_sse_bridge.py index 73f8263e..2a650aa7 100644 --- a/backend/workers/run_sse_bridge.py +++ b/backend/workers/run_sse_bridge.py @@ -16,7 +16,7 @@ from app.core.tracing import init_tracing from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId from app.domain.events.typed import DomainEvent -from app.events.schema.schema_registry import SchemaRegistryManager, initialize_event_schemas +from app.events.schema.schema_registry import SchemaRegistryManager from app.services.sse.event_router import SSE_RELEVANT_EVENTS, SSEEventRouter from app.services.sse.redis_bus import SSERedisBus from app.settings import Settings @@ -94,9 +94,8 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: app_logger = await container.get(logging.Logger) app_logger.info("SSE Bridge starting...") - # Initialize schema registry + # Resolve schema registry (initialization handled by provider) schema_registry = await container.get(SchemaRegistryManager) - await initialize_event_schemas(schema_registry) # Decoder: Avro bytes → typed DomainEvent async def decode_avro(body: bytes) -> DomainEvent: From 
e7f22284294f80e9943d1375258cd1c122c9ecc7 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Wed, 21 Jan 2026 20:31:09 +0100 Subject: [PATCH 11/21] test infra updates - using beanie instead of mongodb driver --- backend/app/core/container.py | 9 --------- backend/app/core/database_context.py | 14 -------------- backend/app/core/dishka_lifespan.py | 18 ++++++++++++++++-- backend/app/core/providers.py | 18 ------------------ backend/tests/conftest.py | 19 +++++-------------- backend/tests/e2e/conftest.py | 5 ++--- backend/tests/helpers/cleanup.py | 10 ++++------ backend/tests/helpers/fakes/__init__.py | 3 +-- backend/tests/helpers/fakes/providers.py | 17 ----------------- backend/tests/integration/conftest.py | 5 ++--- backend/tests/unit/conftest.py | 18 +++++++++++------- backend/workers/dlq_processor.py | 9 ++++++--- backend/workers/run_event_replay.py | 10 ++++++---- backend/workers/run_k8s_worker.py | 14 +++++--------- backend/workers/run_pod_monitor.py | 19 ++++++++++++------- backend/workers/run_result_processor.py | 19 ++++++++++++------- backend/workers/run_saga_orchestrator.py | 20 +++++++++++++------- 17 files changed, 95 insertions(+), 132 deletions(-) delete mode 100644 backend/app/core/database_context.py diff --git a/backend/app/core/container.py b/backend/app/core/container.py index 9b5bddd6..6febb91f 100644 --- a/backend/app/core/container.py +++ b/backend/app/core/container.py @@ -7,7 +7,6 @@ BoundaryClientProvider, BusinessServicesProvider, CoreServicesProvider, - DatabaseProvider, EventProvider, EventReplayProvider, K8sWorkerProvider, @@ -37,7 +36,6 @@ def create_app_container(settings: Settings) -> AsyncContainer: return make_async_container( SettingsProvider(), LoggingProvider(), - DatabaseProvider(), BoundaryClientProvider(), RedisServicesProvider(), CoreServicesProvider(), @@ -67,7 +65,6 @@ def create_result_processor_container(settings: Settings) -> AsyncContainer: return make_async_container( SettingsProvider(), LoggingProvider(), - DatabaseProvider(), BoundaryClientProvider(), RedisServicesProvider(), CoreServicesProvider(), @@ -85,12 +82,10 @@ def create_k8s_worker_container(settings: Settings) -> AsyncContainer: return make_async_container( SettingsProvider(), LoggingProvider(), - DatabaseProvider(), BoundaryClientProvider(), RedisServicesProvider(), CoreServicesProvider(), MetricsProvider(), - RepositoryProvider(), MessagingProvider(), EventProvider(), K8sWorkerProvider(), @@ -103,7 +98,6 @@ def create_pod_monitor_container(settings: Settings) -> AsyncContainer: return make_async_container( SettingsProvider(), LoggingProvider(), - DatabaseProvider(), BoundaryClientProvider(), RedisServicesProvider(), CoreServicesProvider(), @@ -122,7 +116,6 @@ def create_saga_orchestrator_container(settings: Settings) -> AsyncContainer: return make_async_container( SettingsProvider(), LoggingProvider(), - DatabaseProvider(), BoundaryClientProvider(), RedisServicesProvider(), CoreServicesProvider(), @@ -140,7 +133,6 @@ def create_event_replay_container(settings: Settings) -> AsyncContainer: return make_async_container( SettingsProvider(), LoggingProvider(), - DatabaseProvider(), BoundaryClientProvider(), RedisServicesProvider(), CoreServicesProvider(), @@ -158,7 +150,6 @@ def create_dlq_processor_container(settings: Settings) -> AsyncContainer: return make_async_container( SettingsProvider(), LoggingProvider(), - DatabaseProvider(), BoundaryClientProvider(), RedisServicesProvider(), CoreServicesProvider(), diff --git a/backend/app/core/database_context.py 
b/backend/app/core/database_context.py deleted file mode 100644 index 06913e03..00000000 --- a/backend/app/core/database_context.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Any - -from pymongo.asynchronous.client_session import AsyncClientSession -from pymongo.asynchronous.collection import AsyncCollection -from pymongo.asynchronous.cursor import AsyncCursor -from pymongo.asynchronous.database import AsyncDatabase -from pymongo.asynchronous.mongo_client import AsyncMongoClient - -type MongoDocument = dict[str, Any] -type DBClient = AsyncMongoClient[MongoDocument] -type Database = AsyncDatabase[MongoDocument] -type Collection = AsyncCollection[MongoDocument] -type Cursor = AsyncCursor[MongoDocument] -type DBSession = AsyncClientSession diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index 6ede442c..b68c41d6 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -1,11 +1,25 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from beanie import init_beanie from fastapi import FastAPI +from pymongo.asynchronous.mongo_client import AsyncMongoClient + +from app.db.docs import ALL_DOCUMENTS +from app.settings import Settings @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: - """Minimal lifespan - container.close() triggers all provider cleanup.""" + container = app.state.dishka_container + settings = await container.get(Settings) + + client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( + settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 + ) + await init_beanie(database=client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) + yield - await app.state.dishka_container.close() + + await client.close() + await container.close() diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 7c17c181..473e925e 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -4,14 +4,11 @@ import redis.asyncio as redis from aiokafka import AIOKafkaConsumer, AIOKafkaProducer -from beanie import init_beanie from dishka import Provider, Scope, from_context, provide from kubernetes import client as k8s_client from kubernetes import config as k8s_config from kubernetes import watch as k8s_watch -from pymongo.asynchronous.mongo_client import AsyncMongoClient -from app.core.database_context import Database from app.core.logging import setup_logger from app.core.metrics import ( ConnectionMetrics, @@ -28,7 +25,6 @@ ) from app.core.security import SecurityService from app.core.tracing import TracerManager -from app.db.docs import ALL_DOCUMENTS from app.db.repositories import ( EventRepository, ExecutionRepository, @@ -178,20 +174,6 @@ def get_rate_limit_service( return RateLimitService(redis_client, settings, rate_limit_metrics) -class DatabaseProvider(Provider): - scope = Scope.APP - - @provide - async def get_database(self, settings: Settings, logger: logging.Logger) -> AsyncIterable[Database]: - client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( - settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 - ) - db = client[settings.DATABASE_NAME] - await init_beanie(database=db, document_models=ALL_DOCUMENTS) - logger.info(f"MongoDB + Beanie initialized: {settings.DATABASE_NAME}") - yield db - - class CoreServicesProvider(Provider): scope = Scope.APP diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index aeadecd4..c090249a 100644 --- 
a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -8,7 +8,7 @@ import pytest import pytest_asyncio import redis.asyncio as redis -from app.core.database_context import Database +from app.db.docs import ALL_DOCUMENTS from app.main import create_app from app.settings import Settings from dishka import AsyncContainer @@ -88,18 +88,15 @@ async def app(test_settings: Settings) -> AsyncGenerator[FastAPI, None]: Uses lifespan_context to trigger startup/shutdown events, which initializes Beanie, metrics, and other services through the normal DI flow. - Cleanup: Best-effort drop of test database. May not always succeed due to - known MongoDB driver behavior when client stays connected, but ulimits on - MongoDB container (65536) prevent file descriptor exhaustion regardless. + Cleanup: Delete all documents via Beanie models. Preserves indexes and avoids + file descriptor exhaustion issues from dropping/recreating databases. """ application = create_app(settings=test_settings) async with application.router.lifespan_context(application): yield application - # Best-effort cleanup (may fail silently due to MongoDB driver behavior) - container: AsyncContainer = application.state.dishka_container - db: Database = await container.get(Database) - await db.client.drop_database(test_settings.DATABASE_NAME) + for doc_class in ALL_DOCUMENTS: + await doc_class.delete_all() @pytest_asyncio.fixture(scope="session") @@ -133,12 +130,6 @@ async def scope(app_container: AsyncContainer) -> AsyncGenerator[AsyncContainer, yield s -@pytest_asyncio.fixture -async def db(scope: AsyncContainer) -> AsyncGenerator[Database, None]: - database: Database = await scope.get(Database) - yield database - - @pytest_asyncio.fixture async def redis_client(scope: AsyncContainer) -> AsyncGenerator[redis.Redis, None]: # Dishka's RedisProvider handles cleanup when scope exits diff --git a/backend/tests/e2e/conftest.py b/backend/tests/e2e/conftest.py index 648dfaef..0196009b 100644 --- a/backend/tests/e2e/conftest.py +++ b/backend/tests/e2e/conftest.py @@ -2,18 +2,17 @@ import pytest_asyncio import redis.asyncio as redis -from app.core.database_context import Database from tests.helpers.cleanup import cleanup_db_and_redis @pytest_asyncio.fixture(autouse=True) -async def _cleanup(db: Database, redis_client: redis.Redis) -> AsyncGenerator[None, None]: +async def _cleanup(redis_client: redis.Redis) -> AsyncGenerator[None, None]: """Clean DB and Redis before each E2E test. Only pre-test cleanup - post-test cleanup causes event loop issues when SSE/streaming tests hold connections across loop boundaries. """ - await cleanup_db_and_redis(db, redis_client) + await cleanup_db_and_redis(redis_client) yield # No post-test cleanup to avoid "Event loop is closed" errors diff --git a/backend/tests/helpers/cleanup.py b/backend/tests/helpers/cleanup.py index 760b48da..78e70817 100644 --- a/backend/tests/helpers/cleanup.py +++ b/backend/tests/helpers/cleanup.py @@ -1,8 +1,8 @@ import redis.asyncio as redis -from app.core.database_context import Database +from app.db.docs import ALL_DOCUMENTS -async def cleanup_db_and_redis(db: Database, redis_client: redis.Redis) -> None: +async def cleanup_db_and_redis(redis_client: redis.Redis) -> None: """Clean DB and Redis before a test. Beanie is already initialized once during app lifespan (dishka_lifespan.py). @@ -13,9 +13,7 @@ async def cleanup_db_and_redis(db: Database, redis_client: redis.Redis) -> None: is safe and only affects that worker's database. See tests/conftest.py for REDIS_DB setup. 
""" - collections = await db.list_collection_names(filter={"type": "collection"}) - for name in collections: - if not name.startswith("system."): - await db[name].delete_many({}) + for doc_class in ALL_DOCUMENTS: + await doc_class.delete_all() await redis_client.flushdb() diff --git a/backend/tests/helpers/fakes/__init__.py b/backend/tests/helpers/fakes/__init__.py index c6b1eb30..ad95e394 100644 --- a/backend/tests/helpers/fakes/__init__.py +++ b/backend/tests/helpers/fakes/__init__.py @@ -1,11 +1,10 @@ """Fake implementations for external boundary clients used in tests.""" -from .providers import FakeBoundaryClientProvider, FakeDatabaseProvider, FakeSchemaRegistryProvider +from .providers import FakeBoundaryClientProvider, FakeSchemaRegistryProvider from .schema_registry import FakeSchemaRegistryManager __all__ = [ "FakeBoundaryClientProvider", - "FakeDatabaseProvider", "FakeSchemaRegistryManager", "FakeSchemaRegistryProvider", ] diff --git a/backend/tests/helpers/fakes/providers.py b/backend/tests/helpers/fakes/providers.py index 14718873..952c5128 100644 --- a/backend/tests/helpers/fakes/providers.py +++ b/backend/tests/helpers/fakes/providers.py @@ -1,18 +1,14 @@ """Fake providers for unit testing with DI container.""" import logging -from typing import Any import fakeredis.aioredis import redis.asyncio as redis from aiokafka import AIOKafkaProducer -from app.core.database_context import Database from app.events.schema.schema_registry import SchemaRegistryManager -from app.settings import Settings from dishka import Provider, Scope, provide from kubernetes import client as k8s_client from kubernetes import watch as k8s_watch -from mongomock_motor import AsyncMongoMockClient from tests.helpers.fakes.kafka import FakeAIOKafkaProducer from tests.helpers.fakes.kubernetes import ( @@ -60,19 +56,6 @@ def get_k8s_watch(self) -> k8s_watch.Watch: return FakeK8sWatch() -class FakeDatabaseProvider(Provider): - """Fake MongoDB database for unit testing using mongomock-motor.""" - - scope = Scope.APP - - @provide - def get_database(self, settings: Settings, logger: logging.Logger) -> Database: - logger.info(f"Using AsyncMongoMockClient for testing: {settings.DATABASE_NAME}") - client: AsyncMongoMockClient[dict[str, Any]] = AsyncMongoMockClient() - # mongomock_motor returns AsyncIOMotorDatabase which is API-compatible with AsyncDatabase - return client[settings.DATABASE_NAME] # type: ignore[return-value] - - class FakeSchemaRegistryProvider(Provider): """Fake Schema Registry provider - must be placed after EventProvider to override.""" diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 329ba48c..9d60ebe2 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -5,7 +5,6 @@ import pytest import pytest_asyncio import redis.asyncio as redis -from app.core.database_context import Database from app.core.metrics import DatabaseMetrics, EventMetrics from app.events.core import ConsumerConfig from app.events.schema.schema_registry import SchemaRegistryManager @@ -21,13 +20,13 @@ @pytest_asyncio.fixture(autouse=True) -async def _cleanup(db: Database, redis_client: redis.Redis) -> AsyncGenerator[None, None]: +async def _cleanup(redis_client: redis.Redis) -> AsyncGenerator[None, None]: """Clean DB and Redis before each integration test. Only pre-test cleanup - post-test cleanup causes event loop issues when SSE/streaming tests hold connections across loop boundaries. 
""" - await cleanup_db_and_redis(db, redis_client) + await cleanup_db_and_redis(redis_client) yield # No post-test cleanup to avoid "Event loop is closed" errors diff --git a/backend/tests/unit/conftest.py b/backend/tests/unit/conftest.py index 24e1bf1b..5f483653 100644 --- a/backend/tests/unit/conftest.py +++ b/backend/tests/unit/conftest.py @@ -4,7 +4,6 @@ import pytest import pytest_asyncio from aiokafka import AIOKafkaProducer -from app.core.database_context import Database from app.core.metrics import ( ConnectionMetrics, CoordinatorMetrics, @@ -47,8 +46,9 @@ from app.settings import Settings from beanie import init_beanie from dishka import AsyncContainer, make_async_container +from mongomock_motor import AsyncMongoMockClient -from tests.helpers.fakes import FakeBoundaryClientProvider, FakeDatabaseProvider, FakeSchemaRegistryProvider +from tests.helpers.fakes import FakeBoundaryClientProvider, FakeSchemaRegistryProvider from tests.helpers.fakes.kafka import FakeAIOKafkaProducer from tests.helpers.k8s_fakes import FakeApi, FakeV1Api, FakeWatch, make_k8s_clients @@ -60,18 +60,18 @@ async def unit_container(test_settings: Settings) -> AsyncGenerator[AsyncContain """DI container for unit tests with fake boundary clients. Provides: - - Fake Redis, Kafka, K8s, MongoDB (boundary clients) + - Fake Redis, Kafka, K8s (boundary clients) - Real metrics, repositories, services (internal) + - Real MongoDB via init_beanie """ container = make_async_container( SettingsProvider(), LoggingProvider(), FakeBoundaryClientProvider(), - FakeDatabaseProvider(), RedisServicesProvider(), MetricsProvider(), EventProvider(), - FakeSchemaRegistryProvider(), # Override real schema registry with fake + FakeSchemaRegistryProvider(), MessagingProvider(), CoreServicesProvider(), KafkaServicesProvider(), @@ -79,10 +79,14 @@ async def unit_container(test_settings: Settings) -> AsyncGenerator[AsyncContain context={Settings: test_settings}, ) - db = await container.get(Database) - await init_beanie(database=db, document_models=ALL_DOCUMENTS) + mongo_client: AsyncMongoMockClient[dict[str, object]] = AsyncMongoMockClient() + await init_beanie( + database=mongo_client[test_settings.DATABASE_NAME], # type: ignore[arg-type] + document_models=ALL_DOCUMENTS, + ) yield container + await container.close() diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index 97598539..2cb8b6b6 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -5,12 +5,12 @@ from datetime import datetime, timezone from app.core.container import create_dlq_processor_container -from app.core.database_context import Database from app.db.docs import ALL_DOCUMENTS from app.dlq import DLQMessage, RetryPolicy, RetryStrategy from app.dlq.manager import DLQManager from app.settings import Settings from beanie import init_beanie +from pymongo.asynchronous.mongo_client import AsyncMongoClient def _configure_retry_policies(manager: DLQManager, logger: logging.Logger) -> None: @@ -84,8 +84,10 @@ async def main(settings: Settings) -> None: logger = await container.get(logging.Logger) logger.info("Starting DLQ Processor with DI container...") - db = await container.get(Database) - await init_beanie(database=db, document_models=ALL_DOCUMENTS) + mongo_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( + settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 + ) + await init_beanie(database=mongo_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) manager = await 
container.get(DLQManager) @@ -104,6 +106,7 @@ def signal_handler() -> None: async with AsyncExitStack() as stack: stack.push_async_callback(container.close) + stack.push_async_callback(mongo_client.close) await stop_event.wait() diff --git a/backend/workers/run_event_replay.py b/backend/workers/run_event_replay.py index 51bfbaa1..8a905646 100644 --- a/backend/workers/run_event_replay.py +++ b/backend/workers/run_event_replay.py @@ -3,7 +3,6 @@ from contextlib import AsyncExitStack from app.core.container import create_event_replay_container -from app.core.database_context import Database from app.core.logging import setup_logger from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS @@ -11,6 +10,7 @@ from app.services.event_replay.replay_service import EventReplayService from app.settings import Settings from beanie import init_beanie +from pymongo.asynchronous.mongo_client import AsyncMongoClient async def cleanup_task(replay_service: EventReplayService, logger: logging.Logger, interval_hours: int = 6) -> None: @@ -31,8 +31,10 @@ async def run_replay_service(settings: Settings) -> None: logger = await container.get(logging.Logger) logger.info("Starting EventReplayService with DI container...") - db = await container.get(Database) - await init_beanie(database=db, document_models=ALL_DOCUMENTS) + mongo_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( + settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 + ) + await init_beanie(database=mongo_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) await container.get(UnifiedProducer) @@ -41,8 +43,8 @@ async def run_replay_service(settings: Settings) -> None: logger.info("Event replay service initialized") async with AsyncExitStack() as stack: - # Container close stops Kafka producer via DI provider stack.push_async_callback(container.close) + stack.push_async_callback(mongo_client.close) task = asyncio.create_task(cleanup_task(replay_service, logger)) diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index f35baab1..58e5ac6c 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -35,7 +35,6 @@ DeletePodCommandEvent, DomainEvent, ) -from app.events.core import UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.k8s_worker.worker_logic import K8sWorkerLogic @@ -105,15 +104,12 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: app_logger = await container.get(logging.Logger) app_logger.info("KubernetesWorker starting...") - # Resolve schema registry (initialization handled by provider) - schema_registry = await container.get(SchemaRegistryManager) - - # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) - await container.get(UnifiedProducer) - app_logger.info("Kafka producer ready") - - # Get worker logic and ensure daemonset (one-time initialization) + # Get worker logic - triggers full dependency chain: + # K8sWorkerLogic -> UnifiedProducer -> SchemaRegistryManager (init) logic = await container.get(K8sWorkerLogic) + + # Get schema registry for decoder (already initialized via chain above) + schema_registry = await container.get(SchemaRegistryManager) await logic.ensure_image_pre_puller_daemonset() # Decoder: Avro bytes → typed DomainEvent diff --git 
a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index db090e38..fb7c550e 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -12,11 +12,9 @@ import signal from contextlib import suppress -from app.core.database_context import Database from app.core.logging import setup_logger from app.core.providers import ( BoundaryClientProvider, - DatabaseProvider, EventProvider, LoggingProvider, MessagingProvider, @@ -27,12 +25,15 @@ SettingsProvider, ) from app.core.tracing import init_tracing +from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId from app.events.core import UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager from app.services.pod_monitor.monitor import PodMonitor from app.settings import Settings +from beanie import init_beanie from dishka import make_async_container +from pymongo.asynchronous.mongo_client import AsyncMongoClient async def run_pod_monitor(settings: Settings) -> None: @@ -42,7 +43,6 @@ async def run_pod_monitor(settings: Settings) -> None: LoggingProvider(), BoundaryClientProvider(), RedisServicesProvider(), - DatabaseProvider(), MetricsProvider(), EventProvider(), MessagingProvider(), @@ -54,9 +54,14 @@ async def run_pod_monitor(settings: Settings) -> None: logger = await container.get(logging.Logger) logger.info("Starting PodMonitor with DI container...") - # Resolve dependencies (initialization handled by providers) - await container.get(Database) # Triggers init_beanie via DatabaseProvider - await container.get(SchemaRegistryManager) # Triggers initialize_schemas + # Initialize MongoDB + Beanie + mongo_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( + settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 + ) + await init_beanie(database=mongo_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) + + # Resolve schema registry (initialization handled by provider) + await container.get(SchemaRegistryManager) # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) await container.get(UnifiedProducer) @@ -89,7 +94,7 @@ async def run_pod_monitor(settings: Settings) -> None: finally: logger.info("Initiating graceful shutdown...") - # Container close stops Kafka producer via DI provider + await mongo_client.close() await container.close() logger.info("PodMonitor shutdown complete") diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index 29ed0fe6..aa5ab748 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -16,11 +16,9 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from app.core.database_context import Database from app.core.logging import setup_logger from app.core.providers import ( BoundaryClientProvider, - DatabaseProvider, EventProvider, LoggingProvider, MessagingProvider, @@ -31,6 +29,7 @@ SettingsProvider, ) from app.core.tracing import init_tracing +from app.db.docs import ALL_DOCUMENTS from app.domain.enums.events import EventType from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId from app.domain.events.typed import ( @@ -44,10 +43,12 @@ from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.result_processor.processor_logic import ProcessorLogic from app.settings import Settings +from beanie import init_beanie from dishka import make_async_container from 
dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream from faststream.kafka import KafkaBroker +from pymongo.asynchronous.mongo_client import AsyncMongoClient def main() -> None: @@ -82,7 +83,6 @@ def main() -> None: LoggingProvider(), BoundaryClientProvider(), RedisServicesProvider(), - DatabaseProvider(), MetricsProvider(), EventProvider(), MessagingProvider(), @@ -111,9 +111,14 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: app_logger = await container.get(logging.Logger) app_logger.info("ResultProcessor starting...") - # Resolve dependencies (initialization handled by providers) - await container.get(Database) # Triggers init_beanie via DatabaseProvider - schema_registry = await container.get(SchemaRegistryManager) # Triggers initialize_schemas + # Initialize MongoDB + Beanie + mongo_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( + settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 + ) + await init_beanie(database=mongo_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) + + # Resolve schema registry (initialization handled by provider) + schema_registry = await container.get(SchemaRegistryManager) # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) await container.get(UnifiedProducer) @@ -164,7 +169,7 @@ async def handle_other(event: DomainEvent) -> None: yield finally: app_logger.info("ResultProcessor shutting down...") - # Container close stops Kafka producer via DI provider + await mongo_client.close() await container.close() app_logger.info("ResultProcessor shutdown complete") diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 662cda67..ddb5ed8c 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -17,11 +17,9 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from app.core.database_context import Database from app.core.logging import setup_logger from app.core.providers import ( BoundaryClientProvider, - DatabaseProvider, EventProvider, LoggingProvider, MessagingProvider, @@ -32,6 +30,7 @@ SettingsProvider, ) from app.core.tracing import init_tracing +from app.db.docs import ALL_DOCUMENTS from app.domain.enums.kafka import GroupId from app.domain.events.typed import DomainEvent from app.events.core import UnifiedProducer @@ -39,10 +38,12 @@ from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.saga.saga_logic import SagaLogic from app.settings import Settings +from beanie import init_beanie from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream from faststream.kafka import KafkaBroker +from pymongo.asynchronous.mongo_client import AsyncMongoClient def main() -> None: @@ -77,7 +78,6 @@ def main() -> None: LoggingProvider(), BoundaryClientProvider(), RedisServicesProvider(), - DatabaseProvider(), MetricsProvider(), EventProvider(), MessagingProvider(), @@ -105,9 +105,14 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: app_logger = await container.get(logging.Logger) app_logger.info("SagaOrchestrator starting...") - # Resolve dependencies (initialization handled by providers) - await container.get(Database) # Triggers init_beanie via DatabaseProvider - schema_registry = await container.get(SchemaRegistryManager) # Triggers initialize_schemas + # Initialize MongoDB + 
Beanie + mongo_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( + settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 + ) + await init_beanie(database=mongo_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) + + # Resolve schema registry (initialization handled by provider) + schema_registry = await container.get(SchemaRegistryManager) # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) await container.get(UnifiedProducer) @@ -120,6 +125,7 @@ async def lifespan(app: FastStream) -> AsyncIterator[None]: if not trigger_topics: app_logger.warning("No saga triggers configured, shutting down") yield + await mongo_client.close() await container.close() return @@ -165,7 +171,7 @@ async def handle_saga_event( yield finally: app_logger.info("SagaOrchestrator shutting down...") - # Container close stops Kafka producer via DI provider + await mongo_client.close() await container.close() app_logger.info("SagaOrchestrator shutdown complete") From 583c5ffe0573ac4a44a3e94da34fa4ce64fd9a5c Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Wed, 21 Jan 2026 22:20:22 +0100 Subject: [PATCH 12/21] more event-based stuff, removed all loops except 2 for notif service --- backend/app/core/container.py | 14 - backend/app/core/dishka_lifespan.py | 42 +++ backend/app/core/providers.py | 229 +++-------------- backend/app/dlq/manager.py | 68 +++-- backend/app/domain/enums/kafka.py | 1 - backend/app/events/event_store_consumer.py | 66 ++--- backend/app/infrastructure/kafka/topics.py | 8 - backend/app/services/event_bus.py | 240 ------------------ backend/app/services/notification_service.py | 30 +-- backend/app/services/pod_monitor/config.py | 1 - backend/app/services/pod_monitor/monitor.py | 14 +- backend/app/services/sse/redis_bus.py | 109 +++++++- .../services/sse/sse_connection_registry.py | 11 +- backend/app/services/user_settings_service.py | 9 +- .../tests/integration/core/test_container.py | 6 - .../test_admin_settings_repository.py | 9 +- .../services/admin/test_admin_user_service.py | 26 +- .../services/events/test_event_bus.py | 52 ---- .../unit/services/pod_monitor/test_monitor.py | 9 - backend/workers/dlq_processor.py | 181 ++++++++++--- backend/workers/run_pod_monitor.py | 32 +-- 21 files changed, 423 insertions(+), 734 deletions(-) delete mode 100644 backend/app/services/event_bus.py delete mode 100644 backend/tests/integration/services/events/test_event_bus.py diff --git a/backend/app/core/container.py b/backend/app/core/container.py index 6febb91f..62eb7329 100644 --- a/backend/app/core/container.py +++ b/backend/app/core/container.py @@ -145,17 +145,3 @@ def create_event_replay_container(settings: Settings) -> AsyncContainer: ) -def create_dlq_processor_container(settings: Settings) -> AsyncContainer: - """Create DI container for the DLQ processor worker.""" - return make_async_container( - SettingsProvider(), - LoggingProvider(), - BoundaryClientProvider(), - RedisServicesProvider(), - CoreServicesProvider(), - MetricsProvider(), - RepositoryProvider(), - MessagingProvider(), - EventProvider(), - context={Settings: settings}, - ) diff --git a/backend/app/core/dishka_lifespan.py b/backend/app/core/dishka_lifespan.py index b68c41d6..d0daff2d 100644 --- a/backend/app/core/dishka_lifespan.py +++ b/backend/app/core/dishka_lifespan.py @@ -1,3 +1,5 @@ +import asyncio +import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -6,6 +8,7 @@ from pymongo.asynchronous.mongo_client import 
AsyncMongoClient from app.db.docs import ALL_DOCUMENTS +from app.services.notification_service import NotificationService from app.settings import Settings @@ -13,13 +16,52 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: container = app.state.dishka_container settings = await container.get(Settings) + logger = await container.get(logging.Logger) client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 ) await init_beanie(database=client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) + # Start notification service background tasks + notification_service = await container.get(NotificationService) + + async def pending_notification_task() -> None: + """Process pending notifications every 5 seconds.""" + while True: + try: + await asyncio.sleep(5) + await notification_service.process_pending_batch() + except asyncio.CancelledError: + break + except Exception: + logger.exception("Error processing pending notifications") + + async def cleanup_notification_task() -> None: + """Cleanup old notifications every 24 hours.""" + while True: + try: + await asyncio.sleep(86400) # 24 hours + await notification_service.cleanup_old() + except asyncio.CancelledError: + break + except Exception: + logger.exception("Error cleaning up notifications") + + pending_task = asyncio.create_task(pending_notification_task()) + cleanup_task = asyncio.create_task(cleanup_notification_task()) + logger.info("NotificationService background tasks started") + yield + # Shutdown background tasks + pending_task.cancel() + cleanup_task.cancel() + try: + await asyncio.gather(pending_task, cleanup_task) + except asyncio.CancelledError: + pass + logger.info("NotificationService background tasks stopped") + await client.close() await container.close() diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 473e925e..0a73833b 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -1,9 +1,8 @@ -import asyncio import logging from collections.abc import AsyncIterable import redis.asyncio as redis -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from aiokafka import AIOKafkaProducer from dishka import Provider, Scope, from_context, provide from kubernetes import client as k8s_client from kubernetes import config as k8s_config @@ -43,16 +42,14 @@ from app.db.repositories.user_settings_repository import UserSettingsRepository from app.dlq.manager import DLQManager from app.dlq.models import RetryPolicy, RetryStrategy -from app.domain.enums.kafka import GroupId, KafkaTopic +from app.domain.enums.kafka import KafkaTopic from app.domain.saga.models import SagaConfig from app.events.core import ProducerMetrics, UnifiedProducer from app.events.event_store import EventStore, create_event_store -from app.events.event_store_consumer import EventStoreConsumer +from app.events.event_store_consumer import EventStoreService from app.events.schema.schema_registry import SchemaRegistryManager -from app.infrastructure.kafka.topics import get_all_topics from app.services.admin import AdminEventsService, AdminSettingsService, AdminUserService from app.services.auth_service import AuthService -from app.services.event_bus import EventBus, EventBusEvent from app.services.event_replay.replay_service import EventReplayService from app.services.event_service import EventService from app.services.execution_service import ExecutionService @@ -209,38 +206,22 @@ def get_kafka_producer( ) @provide - async 
def get_dlq_manager( + def get_dlq_manager( self, kafka_producer: AIOKafkaProducer, settings: Settings, schema_registry: SchemaRegistryManager, logger: logging.Logger, dlq_metrics: DLQMetrics, - ) -> AsyncIterable[DLQManager]: - """Provide DLQManager with DI-managed lifecycle. + ) -> DLQManager: + """Provide DLQManager instance. - Producer lifecycle managed by BoundaryClientProvider. This provider - manages the consumer and background tasks. + Message consumption handled by FastStream subscriber in dlq_processor worker. + Scheduled retries handled by timer in worker lifespan. + Producer lifecycle managed by BoundaryClientProvider. """ - topic_name = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DEAD_LETTER_QUEUE}" - consumer = AIOKafkaConsumer( - topic_name, - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{GroupId.DLQ_MANAGER}.{settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - auto_offset_reset="earliest", - client_id="dlq-manager-consumer", - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - await consumer.start() - - dlq_manager = DLQManager( + return DLQManager( settings=settings, - consumer=consumer, producer=kafka_producer, schema_registry=schema_registry, logger=logger, @@ -249,31 +230,6 @@ async def get_dlq_manager( default_retry_policy=RetryPolicy(topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF), ) - # Background task: process incoming DLQ messages - async def process_messages_loop() -> None: - async for msg in consumer: - await dlq_manager.process_consumer_message(msg) - - # Background task: periodic check for scheduled retries - async def monitor_loop() -> None: - while True: - await dlq_manager.check_scheduled_retries() - await asyncio.sleep(10) - - process_task = asyncio.create_task(process_messages_loop()) - monitor_task = asyncio.create_task(monitor_loop()) - logger.info("DLQ Manager started") - - yield dlq_manager - - # Cleanup - process_task.cancel() - monitor_task.cancel() - await asyncio.gather(process_task, monitor_task, return_exceptions=True) - - await consumer.stop() - logger.info("DLQ Manager stopped") - @provide def get_idempotency_config(self) -> IdempotencyConfig: return IdempotencyConfig() @@ -320,117 +276,23 @@ def get_event_store( ) @provide - async def get_event_store_consumer( + def get_event_store_service( self, event_store: EventStore, - schema_registry: SchemaRegistryManager, - settings: Settings, logger: logging.Logger, event_metrics: EventMetrics, - ) -> AsyncIterable[EventStoreConsumer]: - """Provide EventStoreConsumer with DI-managed lifecycle.""" - topics = get_all_topics() - topic_strings = [f"{settings.KAFKA_TOPIC_PREFIX}{topic}" for topic in topics] - - consumer = AIOKafkaConsumer( - *topic_strings, - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"{GroupId.EVENT_STORE_CONSUMER}.{settings.KAFKA_GROUP_SUFFIX}", - enable_auto_commit=False, - max_poll_records=100, - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - fetch_max_wait_ms=5000, - ) + ) -> EventStoreService: + """Provide EventStoreService for event archival. 
- await consumer.start() - logger.info(f"Event store consumer started for topics: {topic_strings}") - - event_store_consumer = EventStoreConsumer( + Pure storage service - no consumer, no loops. + FastStream subscribers call store_event() to archive events. + """ + return EventStoreService( event_store=event_store, - consumer=consumer, - schema_registry_manager=schema_registry, logger=logger, event_metrics=event_metrics, ) - async def batch_loop() -> None: - while True: - await event_store_consumer.process_batch() - - task = asyncio.create_task(batch_loop()) - - yield event_store_consumer - - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - await consumer.stop() - logger.info("Event store consumer stopped") - - @provide - async def get_event_bus( - self, - kafka_producer: AIOKafkaProducer, - settings: Settings, - logger: logging.Logger, - connection_metrics: ConnectionMetrics, - ) -> AsyncIterable[EventBus]: - """Provide EventBus with DI-managed lifecycle. - - Producer lifecycle managed by BoundaryClientProvider. This provider - manages the consumer and background listener task. - """ - topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" - consumer = AIOKafkaConsumer( - topic, - bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS, - group_id=f"event-bus-{settings.SERVICE_NAME}", - auto_offset_reset="latest", - enable_auto_commit=True, - client_id=f"event-bus-consumer-{settings.SERVICE_NAME}", - session_timeout_ms=settings.KAFKA_SESSION_TIMEOUT_MS, - heartbeat_interval_ms=settings.KAFKA_HEARTBEAT_INTERVAL_MS, - max_poll_interval_ms=settings.KAFKA_MAX_POLL_INTERVAL_MS, - request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, - ) - - await consumer.start() - - event_bus = EventBus( - producer=kafka_producer, - consumer=consumer, - settings=settings, - logger=logger, - connection_metrics=connection_metrics, - ) - - # Create background listener task - async def listener_loop() -> None: - while True: - await event_bus.process_kafka_message() - - listener_task = asyncio.create_task(listener_loop()) - logger.info("Event bus started with Kafka backing") - - yield event_bus - - # Cleanup - listener_task.cancel() - try: - await listener_task - except asyncio.CancelledError: - pass - - await consumer.stop() - logger.info("Event bus stopped") - - class MetricsProvider(Provider): """Provides all metrics instances via DI (no contextvars needed).""" @@ -632,22 +494,27 @@ async def get_user_settings_service( self, repository: UserSettingsRepository, kafka_event_service: KafkaEventService, - event_bus: EventBus, + sse_redis_bus: SSERedisBus, logger: logging.Logger, - ) -> UserSettingsService: - service = UserSettingsService(repository, kafka_event_service, logger, event_bus) + ) -> AsyncIterable[UserSettingsService]: + service = UserSettingsService(repository, kafka_event_service, logger, sse_redis_bus) # Subscribe to settings update events for cross-instance cache invalidation. - # EventBus filters out self-published messages, so this handler only - # runs for events from OTHER instances. - async def _handle_settings_update(evt: EventBusEvent) -> None: - uid = evt.payload.get("user_id") + # Redis pub/sub delivers messages to ALL subscribers including self, + # but cache invalidation is idempotent so that's fine. 
+ async def _handle_settings_update(data: dict[str, object]) -> None: + uid = data.get("user_id") if uid: await service.invalidate_cache(str(uid)) - await event_bus.subscribe("user.settings.updated*", _handle_settings_update) + subscription = await sse_redis_bus.subscribe_internal("user.settings.updated", _handle_settings_update) + await subscription.start() + logger.info("UserSettingsService cache invalidation subscription started") - return service + yield service + + await subscription.close() + logger.info("UserSettingsService cache invalidation subscription stopped") class AdminServicesProvider(Provider): @@ -671,46 +538,26 @@ def get_admin_settings_service( return AdminSettingsService(admin_settings_repository, logger) @provide - async def get_notification_service( + def get_notification_service( self, notification_repository: NotificationRepository, - event_bus: EventBus, sse_redis_bus: SSERedisBus, settings: Settings, logger: logging.Logger, notification_metrics: NotificationMetrics, - ) -> AsyncIterable[NotificationService]: - """Provide NotificationService with DI-managed background tasks.""" - service = NotificationService( + ) -> NotificationService: + """Provide NotificationService instance. + + Background tasks (pending batch processing, cleanup) managed by app lifespan. + """ + return NotificationService( notification_repository=notification_repository, - event_bus=event_bus, sse_bus=sse_redis_bus, settings=settings, logger=logger, notification_metrics=notification_metrics, ) - async def pending_loop() -> None: - while True: - await service.process_pending_batch() - await asyncio.sleep(5) - - async def cleanup_loop() -> None: - while True: - await asyncio.sleep(86400) # 24 hours - await service.cleanup_old() - - pending_task = asyncio.create_task(pending_loop()) - cleanup_task = asyncio.create_task(cleanup_loop()) - logger.info("NotificationService background tasks started") - - yield service - - pending_task.cancel() - cleanup_task.cancel() - await asyncio.gather(pending_task, cleanup_task, return_exceptions=True) - logger.info("NotificationService background tasks stopped") - @provide def get_grafana_alert_processor( self, diff --git a/backend/app/dlq/manager.py b/backend/app/dlq/manager.py index 3e0cd5c9..f49111d7 100644 --- a/backend/app/dlq/manager.py +++ b/backend/app/dlq/manager.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from typing import Callable -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer, ConsumerRecord +from aiokafka import AIOKafkaProducer from opentelemetry.trace import SpanKind from app.core.metrics import DLQMetrics @@ -34,13 +34,13 @@ class DLQManager: """Dead Letter Queue manager - pure logic class. - Lifecycle (start/stop consumer, background tasks) managed by DI provider. + Message consumption handled by FastStream subscriber. + Scheduled retries handled by timer in worker lifespan. 
""" def __init__( self, settings: Settings, - consumer: AIOKafkaConsumer, producer: AIOKafkaProducer, schema_registry: SchemaRegistryManager, logger: logging.Logger, @@ -58,7 +58,6 @@ def __init__( self.default_retry_policy = default_retry_policy or RetryPolicy( topic="default", strategy=RetryStrategy.EXPONENTIAL_BACKOFF ) - self.consumer: AIOKafkaConsumer = consumer self.producer: AIOKafkaProducer = producer # Topic-specific retry policies @@ -70,42 +69,35 @@ def __init__( self._dlq_events_topic = f"{settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.DLQ_EVENTS}" self._event_metadata = EventMetadata(service_name="dlq-manager", service_version="1.0.0") - def _kafka_msg_to_message(self, msg: ConsumerRecord[bytes, bytes]) -> DLQMessage: - """Parse Kafka ConsumerRecord into DLQMessage.""" - data = json.loads(msg.value) - headers = {k: v.decode() for k, v in (msg.headers or [])} - return DLQMessage(**data, dlq_offset=msg.offset, dlq_partition=msg.partition, headers=headers) + async def process_message(self, message: DLQMessage) -> None: + """Process a typed DLQ message. - async def process_consumer_message(self, msg: ConsumerRecord[bytes, bytes]) -> None: - """Process a single DLQ message. Called by DI provider's background task.""" - try: - start = asyncio.get_running_loop().time() - dlq_msg = self._kafka_msg_to_message(msg) - - # Record metrics - self.metrics.record_dlq_message_received(dlq_msg.original_topic, dlq_msg.event.event_type) - self.metrics.record_dlq_message_age((datetime.now(timezone.utc) - dlq_msg.failed_at).total_seconds()) - - # Process with tracing - ctx = extract_trace_context(dlq_msg.headers) - with get_tracer().start_as_current_span( - name="dlq.consume", - context=ctx, - kind=SpanKind.CONSUMER, - attributes={ - EventAttributes.KAFKA_TOPIC: self.dlq_topic, - EventAttributes.EVENT_TYPE: dlq_msg.event.event_type, - EventAttributes.EVENT_ID: dlq_msg.event.event_id, - }, - ): - await self._process_dlq_message(dlq_msg) - - # Commit and record duration - await self.consumer.commit() - self.metrics.record_dlq_processing_duration(asyncio.get_running_loop().time() - start, "process") + Called by FastStream subscriber handler. Commit handled by FastStream. 
- except Exception as e: - self.logger.error(f"Error processing DLQ message: {e}") + Args: + message: Typed DLQMessage (deserialized by FastStream/Avro) + """ + start = asyncio.get_running_loop().time() + + # Record metrics + self.metrics.record_dlq_message_received(message.original_topic, message.event.event_type) + self.metrics.record_dlq_message_age((datetime.now(timezone.utc) - message.failed_at).total_seconds()) + + # Process with tracing + ctx = extract_trace_context(message.headers) + with get_tracer().start_as_current_span( + name="dlq.consume", + context=ctx, + kind=SpanKind.CONSUMER, + attributes={ + EventAttributes.KAFKA_TOPIC: self.dlq_topic, + EventAttributes.EVENT_TYPE: message.event.event_type, + EventAttributes.EVENT_ID: message.event.event_id, + }, + ): + await self._process_dlq_message(message) + + self.metrics.record_dlq_processing_duration(asyncio.get_running_loop().time() - start, "process") async def _process_dlq_message(self, message: DLQMessage) -> None: # Apply filters diff --git a/backend/app/domain/enums/kafka.py b/backend/app/domain/enums/kafka.py index 97b5a5a8..6aad2503 100644 --- a/backend/app/domain/enums/kafka.py +++ b/backend/app/domain/enums/kafka.py @@ -50,7 +50,6 @@ class KafkaTopic(StringEnum): # Infrastructure topics DEAD_LETTER_QUEUE = "dead_letter_queue" DLQ_EVENTS = "dlq_events" - EVENT_BUS_STREAM = "event_bus_stream" WEBSOCKET_EVENTS = "websocket_events" diff --git a/backend/app/events/event_store_consumer.py b/backend/app/events/event_store_consumer.py index d7b0497a..b6a0c870 100644 --- a/backend/app/events/event_store_consumer.py +++ b/backend/app/events/event_store_consumer.py @@ -1,75 +1,52 @@ import logging -from aiokafka import AIOKafkaConsumer, ConsumerRecord, TopicPartition from opentelemetry.trace import SpanKind from app.core.metrics import EventMetrics from app.core.tracing.utils import trace_span -from app.domain.enums.kafka import GroupId from app.domain.events.typed import DomainEvent from app.events.event_store import EventStore -from app.events.schema.schema_registry import SchemaRegistryManager -class EventStoreConsumer: - """Consumes events from Kafka and stores them in MongoDB. +class EventStoreService: + """Stores domain events to MongoDB for audit/replay. - Pure logic class - lifecycle managed by DI provider. - Uses Kafka's native batching via getmany(). + Pure storage service - no consumer, no loops. + Called by FastStream subscribers to archive events. """ def __init__( self, event_store: EventStore, - consumer: AIOKafkaConsumer, - schema_registry_manager: SchemaRegistryManager, logger: logging.Logger, event_metrics: EventMetrics, - group_id: GroupId = GroupId.EVENT_STORE_CONSUMER, - batch_size: int = 100, - batch_timeout_ms: int = 5000, ): self.event_store = event_store - self.consumer = consumer - self.group_id = group_id - self.batch_size = batch_size - self.batch_timeout_ms = batch_timeout_ms self.logger = logger self.event_metrics = event_metrics - self.schema_registry_manager = schema_registry_manager - async def process_batch(self) -> None: - """Process a single batch of messages from Kafka. + async def store_event(self, event: DomainEvent, topic: str, consumer_group: str) -> bool: + """Store a single event. Called by FastStream handler. - Called repeatedly by DI provider's background task. + Returns True if stored, False if duplicate/failed. 
""" - batch_data: dict[TopicPartition, list[ConsumerRecord[bytes, bytes]]] = await self.consumer.getmany( - timeout_ms=self.batch_timeout_ms, - max_records=self.batch_size, - ) - - if not batch_data: - return + with trace_span( + name="event_store_service.store_event", + kind=SpanKind.CONSUMER, + attributes={"event.type": event.event_type, "event.id": event.event_id}, + ): + stored = await self.event_store.store_event(event) - events: list[DomainEvent] = [] - for tp, messages in batch_data.items(): - for msg in messages: - try: - event = await self.schema_registry_manager.deserialize_event(msg.value, msg.topic) - events.append(event) - self.event_metrics.record_kafka_message_consumed( - topic=msg.topic, - consumer_group=str(self.group_id), - ) - except Exception as e: - self.logger.error(f"Failed to deserialize message from {tp}: {e}", exc_info=True) + if stored: + self.event_metrics.record_kafka_message_consumed(topic=topic, consumer_group=consumer_group) + self.logger.debug(f"Stored event {event.event_id}") + else: + self.logger.debug(f"Duplicate event {event.event_id}, skipped") - if events: - await self._store_batch(events) - await self.consumer.commit() + return stored - async def _store_batch(self, events: list[DomainEvent]) -> None: - """Store a batch of events.""" + async def store_batch(self, events: list[DomainEvent]) -> dict[str, int]: + """Store a batch of events. For bulk operations.""" self.logger.info(f"Storing batch of {len(events)} events") with trace_span( @@ -84,3 +61,4 @@ async def _store_batch(self, events: list[DomainEvent]) -> None: f"stored={results['stored']}, duplicates={results['duplicates']}, " f"failed={results['failed']}" ) + return results diff --git a/backend/app/infrastructure/kafka/topics.py b/backend/app/infrastructure/kafka/topics.py index c82ed2c5..be5ae6d8 100644 --- a/backend/app/infrastructure/kafka/topics.py +++ b/backend/app/infrastructure/kafka/topics.py @@ -190,14 +190,6 @@ def get_topic_configs() -> dict[KafkaTopic, dict[str, Any]]: "compression.type": "gzip", }, }, - KafkaTopic.EVENT_BUS_STREAM: { - "num_partitions": 10, - "replication_factor": 1, - "config": { - "retention.ms": "86400000", # 1 day - "compression.type": "gzip", - }, - }, KafkaTopic.WEBSOCKET_EVENTS: { "num_partitions": 5, "replication_factor": 1, diff --git a/backend/app/services/event_bus.py b/backend/app/services/event_bus.py deleted file mode 100644 index a96e74de..00000000 --- a/backend/app/services/event_bus.py +++ /dev/null @@ -1,240 +0,0 @@ -import asyncio -import fnmatch -import json -import logging -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any, Callable -from uuid import uuid4 - -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer -from aiokafka.errors import KafkaError -from pydantic import BaseModel, ConfigDict - -from app.core.metrics import ConnectionMetrics -from app.domain.enums.kafka import KafkaTopic -from app.settings import Settings - - -class EventBusEvent(BaseModel): - """Represents an event on the event bus.""" - - model_config = ConfigDict(from_attributes=True) - - id: str - event_type: str - timestamp: datetime - payload: dict[str, Any] - - -@dataclass -class Subscription: - """Represents a single event subscription.""" - - id: str = field(default_factory=lambda: str(uuid4())) - pattern: str = "" - handler: Callable[[EventBusEvent], Any] = field(default=lambda _: None) - - -class EventBus: - """Distributed event bus for cross-instance communication via Kafka. 
- - Pure logic class - lifecycle managed by DI provider. - - Publishers send events to Kafka. Subscribers receive events from OTHER instances - only - self-published messages are filtered out. This design means: - - Publishers should update their own state directly before calling publish() - - Handlers only run for events from other instances (cache invalidation, etc.) - - Supports pattern-based subscriptions using wildcards: - - execution.* - matches all execution events - - execution.123.* - matches all events for execution 123 - - *.completed - matches all completed events - """ - - def __init__( - self, - producer: AIOKafkaProducer, - consumer: AIOKafkaConsumer, - settings: Settings, - logger: logging.Logger, - connection_metrics: ConnectionMetrics, - ) -> None: - self.producer = producer - self.consumer = consumer - self.logger = logger - self.settings = settings - self.metrics = connection_metrics - self._subscriptions: dict[str, Subscription] = {} # id -> Subscription - self._pattern_index: dict[str, set[str]] = {} # pattern -> set of subscription ids - self._lock = asyncio.Lock() - self._topic = f"{self.settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" - self._instance_id = str(uuid4()) # Unique ID for filtering self-published messages - - async def process_kafka_message(self) -> None: - """Process a single message from Kafka consumer. - - Called by DI provider's background task. Filters out self-published messages. - """ - try: - msg = await asyncio.wait_for(self.consumer.getone(), timeout=0.1) - - # Skip messages from this instance - publisher handles its own state - headers = dict(msg.headers) if msg.headers else {} - source = headers.get("source_instance", b"").decode("utf-8") - if source == self._instance_id: - return - - event_dict = json.loads(msg.value.decode("utf-8")) - event = EventBusEvent.model_validate(event_dict) - await self._distribute_event(event.event_type, event) - - except asyncio.TimeoutError: - pass - except KafkaError as e: - self.logger.error(f"Consumer error: {e}") - except Exception as e: - self.logger.error(f"Error processing Kafka message: {e}") - - async def publish(self, event_type: str, data: dict[str, Any]) -> None: - """ - Publish an event to Kafka for cross-instance distribution. - - Local handlers receive events only from OTHER instances via the Kafka listener. - Publishers should update their own state directly before calling publish(). - - Args: - event_type: Event type (e.g., "execution.123.started") - data: Event data payload - """ - event = self._create_event(event_type, data) - try: - value = event.model_dump_json().encode("utf-8") - key = event_type.encode("utf-8") if event_type else None - headers = [("source_instance", self._instance_id.encode("utf-8"))] - - await self.producer.send_and_wait( - topic=self._topic, - value=value, - key=key, - headers=headers, - ) - except Exception as e: - self.logger.error(f"Failed to publish to Kafka: {e}") - - def _create_event(self, event_type: str, data: dict[str, Any]) -> EventBusEvent: - """Create a standardized event object.""" - return EventBusEvent( - id=str(uuid4()), - event_type=event_type, - timestamp=datetime.now(timezone.utc), - payload=data, - ) - - async def subscribe(self, pattern: str, handler: Callable[[EventBusEvent], Any]) -> str: - """ - Subscribe to events matching a pattern. 
- - Args: - pattern: Event pattern with wildcards (e.g., "execution.*") - handler: Async function to handle matching events - - Returns: - Subscription ID for later unsubscribe - """ - subscription = Subscription(pattern=pattern, handler=handler) - - async with self._lock: - # Store subscription - self._subscriptions[subscription.id] = subscription - - # Update pattern index - if pattern not in self._pattern_index: - self._pattern_index[pattern] = set() - self._pattern_index[pattern].add(subscription.id) - - # Update metrics - self._update_metrics(pattern) - self.metrics.increment_event_bus_subscriptions() - - self.logger.debug(f"Created subscription {subscription.id} for pattern: {pattern}") - return subscription.id - - async def unsubscribe(self, pattern: str, handler: Callable[[EventBusEvent], Any]) -> None: - """Unsubscribe a specific handler from a pattern.""" - async with self._lock: - # Find subscription with matching pattern and handler - for sub_id, subscription in list(self._subscriptions.items()): - if subscription.pattern == pattern and subscription.handler == handler: - await self._remove_subscription(sub_id) - return - - self.logger.warning(f"No subscription found for pattern {pattern} with given handler") - - async def _remove_subscription(self, subscription_id: str) -> None: - """Remove a subscription by ID (must be called within lock).""" - if subscription_id not in self._subscriptions: - self.logger.warning(f"Subscription {subscription_id} not found") - return - - subscription = self._subscriptions[subscription_id] - pattern = subscription.pattern - - # Remove from subscriptions - del self._subscriptions[subscription_id] - - # Update pattern index - if pattern in self._pattern_index: - self._pattern_index[pattern].discard(subscription_id) - if not self._pattern_index[pattern]: - del self._pattern_index[pattern] - - # Update metrics - self._update_metrics(pattern) - self.metrics.decrement_event_bus_subscriptions() - - self.logger.debug(f"Removed subscription {subscription_id} for pattern: {pattern}") - - async def _distribute_event(self, event_type: str, event: EventBusEvent) -> None: - """Distribute event to all matching local subscribers.""" - # Find matching subscriptions - matching_handlers = await self._find_matching_handlers(event_type) - - if not matching_handlers: - return - - # Execute all handlers concurrently - results = await asyncio.gather( - *(self._invoke_handler(handler, event) for handler in matching_handlers), return_exceptions=True - ) - - # Log any errors - for _i, result in enumerate(results): - if isinstance(result, Exception): - self.logger.error(f"Handler failed for event {event_type}: {result}") - - async def _find_matching_handlers(self, event_type: str) -> list[Callable[[EventBusEvent], Any]]: - """Find all handlers matching the event type.""" - async with self._lock: - handlers: list[Callable[[EventBusEvent], Any]] = [] - for pattern, sub_ids in self._pattern_index.items(): - if fnmatch.fnmatch(event_type, pattern): - handlers.extend( - self._subscriptions[sub_id].handler for sub_id in sub_ids if sub_id in self._subscriptions - ) - return handlers - - async def _invoke_handler(self, handler: Callable[[EventBusEvent], Any], event: EventBusEvent) -> None: - """Invoke a single handler, handling both sync and async.""" - if asyncio.iscoroutinefunction(handler): - await handler(event) - else: - await asyncio.to_thread(handler, event) - - def _update_metrics(self, pattern: str) -> None: - """Update metrics for a pattern (must be called within 
lock).""" - if self.metrics: - count = len(self._pattern_index.get(pattern, set())) - self.metrics.update_event_bus_subscribers(count, pattern) - - diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 45863b75..78173757 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -27,7 +27,6 @@ NotificationValidationError, ) from app.schemas_pydantic.sse import RedisNotificationMessage -from app.services.event_bus import EventBus from app.services.sse.redis_bus import SSERedisBus from app.settings import Settings @@ -98,14 +97,12 @@ class NotificationService: def __init__( self, notification_repository: NotificationRepository, - event_bus: EventBus, sse_bus: SSERedisBus, settings: Settings, logger: logging.Logger, notification_metrics: NotificationMetrics, ) -> None: self.repository = notification_repository - self.event_bus = event_bus self.metrics = notification_metrics self.settings = settings self.sse_bus = sse_bus @@ -211,17 +208,6 @@ async def create_notification( # Save to database notification = await self.repository.create_notification(create_data) - # Publish event - await self.event_bus.publish( - "notifications.created", - { - "notification_id": str(notification.notification_id), - "user_id": user_id, - "severity": str(severity), - "tags": notification.tags, - }, - ) - await self._deliver_notification(notification) return notification @@ -446,12 +432,7 @@ async def mark_as_read(self, user_id: str, notification_id: str) -> bool: """Mark notification as read.""" success = await self.repository.mark_as_read(notification_id, user_id) - if success: - await self.event_bus.publish( - "notifications.read", - {"notification_id": str(notification_id), "user_id": user_id, "read_at": datetime.now(UTC).isoformat()}, - ) - else: + if not success: raise NotificationNotFoundError(notification_id) return True @@ -527,14 +508,7 @@ async def update_subscription( async def mark_all_as_read(self, user_id: str) -> int: """Mark all notifications as read for a user.""" - count = await self.repository.mark_all_as_read(user_id) - - if count > 0: - await self.event_bus.publish( - "notifications.all_read", {"user_id": user_id, "count": count, "read_at": datetime.now(UTC).isoformat()} - ) - - return count + return await self.repository.mark_all_as_read(user_id) async def get_subscriptions(self, user_id: str) -> dict[NotificationChannel, DomainNotificationSubscription]: """Get all notification subscriptions for a user.""" diff --git a/backend/app/services/pod_monitor/config.py b/backend/app/services/pod_monitor/config.py index 44159037..eaf0bc65 100644 --- a/backend/app/services/pod_monitor/config.py +++ b/backend/app/services/pod_monitor/config.py @@ -26,7 +26,6 @@ class PodMonitorConfig: field_selector: str | None = None watch_timeout_seconds: int = 300 # 5 minutes watch_reconnect_delay: int = 5 - max_reconnect_attempts: int = 10 # Monitoring settings enable_metrics: bool = True diff --git a/backend/app/services/pod_monitor/monitor.py b/backend/app/services/pod_monitor/monitor.py index 6cd6ad36..31b67e27 100644 --- a/backend/app/services/pod_monitor/monitor.py +++ b/backend/app/services/pod_monitor/monitor.py @@ -274,22 +274,12 @@ async def _publish_event(self, event: DomainEvent, pod: k8s_client.V1Pod) -> Non self.logger.error(f"Error publishing event: {e}", exc_info=True) async def _backoff(self) -> None: - """Handle watch errors with exponential backoff.""" + """Handle watch errors with 
exponential backoff (capped, infinite retry).""" self._reconnect_attempts += 1 - if self._reconnect_attempts > self.config.max_reconnect_attempts: - self.logger.error( - f"Max reconnect attempts ({self.config.max_reconnect_attempts}) exceeded" - ) - raise RuntimeError("Max reconnect attempts exceeded") - - # Calculate exponential backoff backoff = min(self.config.watch_reconnect_delay * (2 ** (self._reconnect_attempts - 1)), MAX_BACKOFF_SECONDS) - self.logger.info( - f"Reconnecting watch in {backoff}s " - f"(attempt {self._reconnect_attempts}/{self.config.max_reconnect_attempts})" - ) + self.logger.info(f"Reconnecting watch in {backoff}s (attempt {self._reconnect_attempts})") self._metrics.increment_pod_monitor_watch_reconnects() await asyncio.sleep(backoff) diff --git a/backend/app/services/sse/redis_bus.py b/backend/app/services/sse/redis_bus.py index 3be68c2c..a089c20a 100644 --- a/backend/app/services/sse/redis_bus.py +++ b/backend/app/services/sse/redis_bus.py @@ -1,7 +1,9 @@ from __future__ import annotations +import asyncio +import json import logging -from typing import Type, TypeVar +from collections.abc import Awaitable, Callable import redis.asyncio as redis from pydantic import BaseModel @@ -9,8 +11,6 @@ from app.domain.events.typed import DomainEvent from app.schemas_pydantic.sse import RedisNotificationMessage, RedisSSEMessage -T = TypeVar("T", bound=BaseModel) - class SSERedisSubscription: """Subscription wrapper for Redis pubsub with typed message parsing.""" @@ -20,7 +20,7 @@ def __init__(self, pubsub: redis.client.PubSub, channel: str, logger: logging.Lo self._channel = channel self.logger = logger - async def get(self, model: Type[T]) -> T | None: + async def get[T: BaseModel](self, model: type[T]) -> T | None: """Get next typed message from the subscription.""" msg = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) if not msg or msg.get("type") != "message": @@ -42,7 +42,13 @@ async def close(self) -> None: class SSERedisBus: - """Redis-backed pub/sub bus for SSE event fan-out across workers.""" + """Redis-backed pub/sub bus for SSE event fan-out and internal messaging across workers. + + Supports: + - SSE execution event streaming (publish_event, open_subscription) + - SSE notification streaming (publish_notification, open_notification_subscription) + - Generic internal pub/sub for cross-instance coordination (publish_internal, subscribe_internal) + """ def __init__( self, @@ -50,11 +56,13 @@ def __init__( logger: logging.Logger, exec_prefix: str = "sse:exec:", notif_prefix: str = "sse:notif:", + internal_prefix: str = "internal:", ) -> None: self._redis = redis_client self.logger = logger self._exec_prefix = exec_prefix self._notif_prefix = notif_prefix + self._internal_prefix = internal_prefix def _exec_channel(self, execution_id: str) -> str: return f"{self._exec_prefix}{execution_id}" @@ -87,3 +95,94 @@ async def open_notification_subscription(self, user_id: str) -> SSERedisSubscrip await pubsub.subscribe(channel) await pubsub.get_message(timeout=1.0) return SSERedisSubscription(pubsub, channel, self.logger) + + # --- Internal Pub/Sub for Cross-Instance Coordination --- + + def _internal_channel(self, topic: str) -> str: + return f"{self._internal_prefix}{topic}" + + async def publish_internal(self, topic: str, data: dict[str, object]) -> None: + """Publish an internal event to Redis for cross-instance coordination. + + Unlike Kafka EventBus, this is fire-and-forget with no persistence. + Use for cache invalidation, real-time sync, etc. 
+ + Args: + topic: Event topic (e.g., "user.settings.updated") + data: Event payload + """ + channel = self._internal_channel(topic) + message = json.dumps(data) + await self._redis.publish(channel, message) + self.logger.debug(f"Published internal event to {channel}") + + async def subscribe_internal( + self, + topic: str, + handler: Callable[[dict[str, object]], Awaitable[None]], + ) -> InternalSubscription: + """Subscribe to internal events on a topic. + + Returns a subscription object that must be started and eventually closed. + + Args: + topic: Event topic (e.g., "user.settings.updated") + handler: Async callback for each message + + Returns: + Subscription object with start() and close() methods + """ + pubsub = self._redis.pubsub() + channel = self._internal_channel(topic) + return InternalSubscription(pubsub, channel, handler, self.logger) + + +class InternalSubscription: + """Manages an internal pub/sub subscription with background listener.""" + + def __init__( + self, + pubsub: redis.client.PubSub, + channel: str, + handler: Callable[[dict[str, object]], Awaitable[None]], + logger: logging.Logger, + ) -> None: + self._pubsub = pubsub + self._channel = channel + self._handler = handler + self._logger = logger + self._task: asyncio.Task[None] | None = None + + async def start(self) -> None: + """Start listening for messages.""" + await self._pubsub.subscribe(self._channel) + await self._pubsub.get_message(timeout=1.0) # Consume subscribe confirmation + + async def listener() -> None: + while True: + try: + msg = await self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5) + if msg and msg.get("type") == "message": + data = json.loads(msg["data"]) + await self._handler(data) + except asyncio.CancelledError: + break + except Exception as e: + self._logger.error(f"Error processing internal message on {self._channel}: {e}") + + self._task = asyncio.create_task(listener()) + self._logger.debug(f"Started internal subscription on {self._channel}") + + async def close(self) -> None: + """Stop listening and cleanup.""" + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + try: + await self._pubsub.unsubscribe(self._channel) + finally: + await self._pubsub.aclose() # type: ignore[no-untyped-call] + self._logger.debug(f"Closed internal subscription on {self._channel}") diff --git a/backend/app/services/sse/sse_connection_registry.py b/backend/app/services/sse/sse_connection_registry.py index 575d13dc..4bb7c06c 100644 --- a/backend/app/services/sse/sse_connection_registry.py +++ b/backend/app/services/sse/sse_connection_registry.py @@ -40,10 +40,13 @@ async def register_connection(self, execution_id: str, connection_id: str) -> No async def unregister_connection(self, execution_id: str, connection_id: str) -> None: """Unregister an SSE connection.""" async with self._lock: - if execution_id in self._active_connections: - self._active_connections[execution_id].discard(connection_id) - if not self._active_connections[execution_id]: - del self._active_connections[execution_id] + connections = self._active_connections.get(execution_id) + if connections is None or connection_id not in connections: + return + + connections.remove(connection_id) + if not connections: + del self._active_connections[execution_id] self.logger.debug("Unregistered SSE connection", extra={"connection_id": connection_id}) self.metrics.decrement_sse_connections("executions") diff --git a/backend/app/services/user_settings_service.py 
b/backend/app/services/user_settings_service.py index 44d69f87..d784f82f 100644 --- a/backend/app/services/user_settings_service.py +++ b/backend/app/services/user_settings_service.py @@ -16,8 +16,8 @@ DomainUserSettingsChangedEvent, DomainUserSettingsUpdate, ) -from app.services.event_bus import EventBus from app.services.kafka_event_service import KafkaEventService +from app.services.sse.redis_bus import SSERedisBus _settings_adapter = TypeAdapter(DomainUserSettings) _update_adapter = TypeAdapter(DomainUserSettingsUpdate) @@ -29,12 +29,12 @@ def __init__( repository: UserSettingsRepository, event_service: KafkaEventService, logger: logging.Logger, - event_bus: EventBus, + sse_redis_bus: SSERedisBus, ) -> None: self.repository = repository self.event_service = event_service self.logger = logger - self._event_bus = event_bus + self._sse_bus = sse_redis_bus self._cache_ttl = timedelta(minutes=5) self._max_cache_size = 1000 self._cache: TTLCache[str, DomainUserSettings] = TTLCache( @@ -95,7 +95,8 @@ async def update_user_settings( changes_json = _update_adapter.dump_python(updates, exclude_none=True, mode="json") await self._publish_settings_event(user_id, changes_json, reason) - await self._event_bus.publish("user.settings.updated", {"user_id": user_id}) + # Notify other instances to invalidate their caches + await self._sse_bus.publish_internal("user.settings.updated", {"user_id": user_id}) self._add_to_cache(user_id, new_settings) if (await self.repository.count_events_since_snapshot(user_id)) >= 10: diff --git a/backend/tests/integration/core/test_container.py b/backend/tests/integration/core/test_container.py index 85ef5122..965961d1 100644 --- a/backend/tests/integration/core/test_container.py +++ b/backend/tests/integration/core/test_container.py @@ -1,5 +1,4 @@ import pytest -from app.core.database_context import Database from app.services.event_service import EventService from dishka import AsyncContainer @@ -8,12 +7,7 @@ @pytest.mark.asyncio async def test_container_resolves_services(app_container: AsyncContainer, scope: AsyncContainer) -> None: - # Container is the real Dishka container assert isinstance(app_container, AsyncContainer) - # Can resolve core dependencies from DI - db: Database = await scope.get(Database) - assert db.name and isinstance(db.name, str) - svc: EventService = await scope.get(EventService) assert isinstance(svc, EventService) diff --git a/backend/tests/integration/db/repositories/test_admin_settings_repository.py b/backend/tests/integration/db/repositories/test_admin_settings_repository.py index 1f61ce95..db461a5f 100644 --- a/backend/tests/integration/db/repositories/test_admin_settings_repository.py +++ b/backend/tests/integration/db/repositories/test_admin_settings_repository.py @@ -1,5 +1,5 @@ import pytest -from app.core.database_context import Database +from app.db.docs.admin_settings import AuditLogDocument from app.db.repositories.admin.admin_settings_repository import AdminSettingsRepository from app.domain.admin import SystemSettings from dishka import AsyncContainer @@ -26,12 +26,11 @@ async def test_get_system_settings_existing(repo: AdminSettingsRepository) -> No @pytest.mark.asyncio -async def test_update_and_reset_settings(repo: AdminSettingsRepository, db: Database) -> None: +async def test_update_and_reset_settings(repo: AdminSettingsRepository) -> None: s = SystemSettings() updated = await repo.update_system_settings(s, updated_by="admin", user_id="u1") assert isinstance(updated, SystemSettings) - # verify audit log entry exists - assert 
await db.get_collection("audit_log").count_documents({}) >= 1 + assert await AuditLogDocument.count() >= 1 reset = await repo.reset_system_settings("admin", "u1") assert isinstance(reset, SystemSettings) - assert await db.get_collection("audit_log").count_documents({}) >= 2 + assert await AuditLogDocument.count() >= 2 diff --git a/backend/tests/integration/services/admin/test_admin_user_service.py b/backend/tests/integration/services/admin/test_admin_user_service.py index b9ea3d98..73a0749a 100644 --- a/backend/tests/integration/services/admin/test_admin_user_service.py +++ b/backend/tests/integration/services/admin/test_admin_user_service.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone import pytest -from app.core.database_context import Database +from app.db.docs.user import UserDocument from app.domain.enums.user import UserRole from app.services.admin import AdminUserService from dishka import AsyncContainer @@ -12,18 +12,18 @@ @pytest.mark.asyncio async def test_get_user_overview_basic(scope: AsyncContainer) -> None: svc: AdminUserService = await scope.get(AdminUserService) - db: Database = await scope.get(Database) - await db.get_collection("users").insert_one({ - "user_id": "u1", - "username": "bob", - "email": "b@b.com", - "role": UserRole.USER, - "is_active": True, - "is_superuser": False, - "hashed_password": "h", - "created_at": datetime.now(timezone.utc), - "updated_at": datetime.now(timezone.utc), - }) + user = UserDocument( + user_id="u1", + username="bob", + email="b@b.com", + role=UserRole.USER, + is_active=True, + is_superuser=False, + hashed_password="h", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + await user.insert() overview = await svc.get_user_overview("u1", hours=1) assert overview.user.username == "bob" diff --git a/backend/tests/integration/services/events/test_event_bus.py b/backend/tests/integration/services/events/test_event_bus.py deleted file mode 100644 index 0a0ef543..00000000 --- a/backend/tests/integration/services/events/test_event_bus.py +++ /dev/null @@ -1,52 +0,0 @@ -import asyncio -from datetime import datetime, timezone -from uuid import uuid4 - -import pytest -from aiokafka import AIOKafkaProducer -from app.domain.enums.kafka import KafkaTopic -from app.services.event_bus import EventBus, EventBusEvent -from app.settings import Settings -from dishka import AsyncContainer - -pytestmark = pytest.mark.integration - - -@pytest.mark.asyncio -async def test_event_bus_publish_subscribe(scope: AsyncContainer, test_settings: Settings) -> None: - """Test EventBus receives events from other instances (cross-instance communication).""" - bus: EventBus = await scope.get(EventBus) - - # Future resolves when handler receives the event - no polling needed - received_future: asyncio.Future[EventBusEvent] = asyncio.get_running_loop().create_future() - - async def handler(event: EventBusEvent) -> None: - if not received_future.done(): - received_future.set_result(event) - - await bus.subscribe("test.*", handler) - - # Simulate message from another instance by producing directly to Kafka - event = EventBusEvent( - id=str(uuid4()), - event_type="test.created", - timestamp=datetime.now(timezone.utc), - payload={"x": 1}, - ) - - topic = f"{test_settings.KAFKA_TOPIC_PREFIX}{KafkaTopic.EVENT_BUS_STREAM}" - producer = AIOKafkaProducer(bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS) - await producer.start() - try: - await producer.send_and_wait( - topic=topic, - value=event.model_dump_json().encode("utf-8"), - 
key=b"test.created", - headers=[("source_instance", b"other-instance")], - ) - finally: - await producer.stop() - - # Await the future directly - true async, no polling - received = await asyncio.wait_for(received_future, timeout=10.0) - assert received.event_type == "test.created" diff --git a/backend/tests/unit/services/pod_monitor/test_monitor.py b/backend/tests/unit/services/pod_monitor/test_monitor.py index 2399c779..ec151cc9 100644 --- a/backend/tests/unit/services/pod_monitor/test_monitor.py +++ b/backend/tests/unit/services/pod_monitor/test_monitor.py @@ -56,15 +56,6 @@ async def test_process_raw_event_invalid_and_backoff(pod_monitor: PodMonitor) -> assert pod_monitor._reconnect_attempts >= 2 -@pytest.mark.asyncio -async def test_backoff_max_attempts(pod_monitor: PodMonitor) -> None: - pod_monitor.config.max_reconnect_attempts = 2 - pod_monitor._reconnect_attempts = 2 - - with pytest.raises(RuntimeError, match="Max reconnect attempts exceeded"): - await pod_monitor._backoff() - - @pytest.mark.asyncio async def test_watch_loop_with_cancellation(pod_monitor: PodMonitor) -> None: pod_monitor.config.enable_state_reconciliation = False diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index 2cb8b6b6..ad68f8b5 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -1,19 +1,48 @@ +""" +DLQ Processor Worker using FastStream. + +Processes Dead Letter Queue messages with: +- FastStream subscriber for message consumption (push-based, not polling) +- Timer task in lifespan for scheduled retries +- Dishka DI for dependencies +""" + import asyncio +import json import logging -import signal -from contextlib import AsyncExitStack +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from datetime import datetime, timezone -from app.core.container import create_dlq_processor_container +from app.core.logging import setup_logger +from app.core.providers import ( + BoundaryClientProvider, + EventProvider, + LoggingProvider, + MessagingProvider, + MetricsProvider, + RedisServicesProvider, + RepositoryProvider, + SettingsProvider, +) +from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS from app.dlq import DLQMessage, RetryPolicy, RetryStrategy from app.dlq.manager import DLQManager +from app.domain.enums.kafka import CONSUMER_GROUP_SUBSCRIPTIONS, GroupId +from app.events.core import UnifiedProducer +from app.events.schema.schema_registry import SchemaRegistryManager from app.settings import Settings from beanie import init_beanie +from dishka import make_async_container +from dishka.integrations.faststream import FromDishka, setup_dishka +from faststream import FastStream +from faststream.kafka import KafkaBroker from pymongo.asynchronous.mongo_client import AsyncMongoClient def _configure_retry_policies(manager: DLQManager, logger: logging.Logger) -> None: + """Configure topic-specific retry policies.""" manager.set_retry_policy( "execution-requests", RetryPolicy( @@ -57,6 +86,7 @@ def _configure_retry_policies(manager: DLQManager, logger: logging.Logger) -> No def _configure_filters(manager: DLQManager, testing: bool, logger: logging.Logger) -> None: + """Configure message filters.""" if not testing: def filter_test_events(message: DLQMessage) -> bool: @@ -73,42 +103,127 @@ def filter_old_messages(message: DLQMessage) -> bool: manager.add_filter(filter_old_messages) -async def main(settings: Settings) -> None: - """Run the DLQ processor. 
+def main() -> None: + """Entry point for DLQ processor worker. - DLQ lifecycle events (received, retried, discarded) are emitted to the - dlq_events Kafka topic for external observability. Logging is handled - internally by the DLQ manager. + FastStream handles: + - Signal handling (SIGINT/SIGTERM) + - Consumer loop + - Graceful shutdown """ - container = create_dlq_processor_container(settings) - logger = await container.get(logging.Logger) - logger.info("Starting DLQ Processor with DI container...") - - mongo_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( - settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 + settings = Settings() + + logger = setup_logger(settings.LOG_LEVEL) + logger.info("Starting DLQ Processor (FastStream)...") + + if settings.ENABLE_TRACING: + init_tracing( + service_name=GroupId.DLQ_PROCESSOR, + settings=settings, + logger=logger, + service_version=settings.TRACING_SERVICE_VERSION, + enable_console_exporter=False, + sampling_rate=settings.TRACING_SAMPLING_RATE, + ) + + # Create DI container + container = make_async_container( + SettingsProvider(), + LoggingProvider(), + BoundaryClientProvider(), + RedisServicesProvider(), + MetricsProvider(), + EventProvider(), + MessagingProvider(), + RepositoryProvider(), + context={Settings: settings}, ) - await init_beanie(database=mongo_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) - - manager = await container.get(DLQManager) - _configure_retry_policies(manager, logger) - _configure_filters(manager, testing=settings.TESTING, logger=logger) + # Build topic and group ID from config + topics = [f"{settings.KAFKA_TOPIC_PREFIX}{t}" for t in CONSUMER_GROUP_SUBSCRIPTIONS[GroupId.DLQ_PROCESSOR]] + group_id = f"{GroupId.DLQ_PROCESSOR}.{settings.KAFKA_GROUP_SUFFIX}" - stop_event = asyncio.Event() - loop = asyncio.get_running_loop() - - def signal_handler() -> None: - logger.info("Received signal, initiating shutdown...") - stop_event.set() - - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, signal_handler) + broker = KafkaBroker( + settings.KAFKA_BOOTSTRAP_SERVERS, + request_timeout_ms=settings.KAFKA_REQUEST_TIMEOUT_MS, + ) - async with AsyncExitStack() as stack: - stack.push_async_callback(container.close) - stack.push_async_callback(mongo_client.close) - await stop_event.wait() + @asynccontextmanager + async def lifespan(app: FastStream) -> AsyncIterator[None]: + """Initialize infrastructure and start scheduled retry timer.""" + app_logger = await container.get(logging.Logger) + app_logger.info("DLQ Processor starting...") + + # Initialize MongoDB + Beanie + mongo_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( + settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 + ) + await init_beanie(database=mongo_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) + + # Resolve schema registry (initialization handled by provider) + await container.get(SchemaRegistryManager) + + # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) + await container.get(UnifiedProducer) + app_logger.info("Kafka producer ready") + + # Get DLQ manager and configure policies + manager = await container.get(DLQManager) + _configure_retry_policies(manager, app_logger) + _configure_filters(manager, testing=settings.TESTING, logger=app_logger) + app_logger.info("DLQ Manager configured") + + # Decoder: JSON bytes → typed DLQMessage + def decode_dlq_json(body: bytes) -> DLQMessage: + data = json.loads(body) + return 
DLQMessage.model_validate(data) + + # Register subscriber for DLQ messages + @broker.subscriber( + *topics, + group_id=group_id, + auto_commit=False, + decoder=decode_dlq_json, + ) + async def handle_dlq_message( + message: DLQMessage, + dlq_manager: FromDishka[DLQManager], + ) -> None: + """Handle incoming DLQ messages - invoked by FastStream when message arrives.""" + await dlq_manager.process_message(message) + + # Background task: periodic check for scheduled retries + async def retry_checker() -> None: + while True: + try: + await asyncio.sleep(10) + await manager.check_scheduled_retries() + except asyncio.CancelledError: + break + except Exception: + app_logger.exception("Error checking scheduled retries") + + retry_task = asyncio.create_task(retry_checker()) + app_logger.info("DLQ Processor ready, starting event processing...") + + try: + yield + finally: + app_logger.info("DLQ Processor shutting down...") + retry_task.cancel() + try: + await retry_task + except asyncio.CancelledError: + pass + await mongo_client.close() + await container.close() + app_logger.info("DLQ Processor shutdown complete") + + app = FastStream(broker, lifespan=lifespan) + setup_dishka(container=container, app=app, auto_inject=True) + + asyncio.run(app.run()) if __name__ == "__main__": - asyncio.run(main(Settings())) + main() diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index fb7c550e..bde50517 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -1,16 +1,13 @@ """ -Pod Monitor Worker (Simplified). +Pod Monitor Worker. -Note: Unlike other workers, PodMonitor watches Kubernetes pods directly +Unlike other workers, PodMonitor watches Kubernetes pods directly (not consuming Kafka messages), so FastStream's subscriber pattern doesn't apply. - -This version uses a minimal signal handling approach. 
""" import asyncio import logging import signal -from contextlib import suppress from app.core.logging import setup_logger from app.core.providers import ( @@ -54,44 +51,27 @@ async def run_pod_monitor(settings: Settings) -> None: logger = await container.get(logging.Logger) logger.info("Starting PodMonitor with DI container...") - # Initialize MongoDB + Beanie mongo_client: AsyncMongoClient[dict[str, object]] = AsyncMongoClient( settings.MONGODB_URL, tz_aware=True, serverSelectionTimeoutMS=5000 ) await init_beanie(database=mongo_client[settings.DATABASE_NAME], document_models=ALL_DOCUMENTS) - # Resolve schema registry (initialization handled by provider) await container.get(SchemaRegistryManager) - - # Resolve Kafka producer (lifecycle managed by DI - BoundaryClientProvider starts it) await container.get(UnifiedProducer) logger.info("Kafka producer ready") monitor = await container.get(PodMonitor) - # Signal handling with minimal boilerplate - shutdown = asyncio.Event() + # Signal cancels current task - monitor.run() handles CancelledError gracefully + task = asyncio.current_task() loop = asyncio.get_running_loop() for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, shutdown.set) + loop.add_signal_handler(sig, task.cancel) # type: ignore[union-attr] logger.info("PodMonitor initialized, starting run...") try: - # Run monitor until shutdown - monitor_task = asyncio.create_task(monitor.run()) - shutdown_task = asyncio.create_task(shutdown.wait()) - - done, pending = await asyncio.wait( - [monitor_task, shutdown_task], - return_when=asyncio.FIRST_COMPLETED, - ) - - for task in pending: - task.cancel() - with suppress(asyncio.CancelledError): - await task - + await monitor.run() finally: logger.info("Initiating graceful shutdown...") await mongo_client.close() From 62c375c3b4b7dfa5c9d0293f73ca7baed6f5251d Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Wed, 21 Jan 2026 23:01:09 +0100 Subject: [PATCH 13/21] tests fixes --- backend/app/core/providers.py | 27 ++----- backend/app/services/user_settings_service.py | 77 +++++++++---------- .../tests/integration/dlq/test_dlq_manager.py | 53 ++++++------- .../test_user_settings_service.py | 5 +- 4 files changed, 71 insertions(+), 91 deletions(-) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 0a73833b..0aed5f60 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -490,31 +490,18 @@ class UserServicesProvider(Provider): scope = Scope.APP @provide - async def get_user_settings_service( + def get_user_settings_service( self, repository: UserSettingsRepository, kafka_event_service: KafkaEventService, - sse_redis_bus: SSERedisBus, + redis_client: redis.Redis, logger: logging.Logger, - ) -> AsyncIterable[UserSettingsService]: - service = UserSettingsService(repository, kafka_event_service, logger, sse_redis_bus) - - # Subscribe to settings update events for cross-instance cache invalidation. - # Redis pub/sub delivers messages to ALL subscribers including self, - # but cache invalidation is idempotent so that's fine. - async def _handle_settings_update(data: dict[str, object]) -> None: - uid = data.get("user_id") - if uid: - await service.invalidate_cache(str(uid)) - - subscription = await sse_redis_bus.subscribe_internal("user.settings.updated", _handle_settings_update) - await subscription.start() - logger.info("UserSettingsService cache invalidation subscription started") + ) -> UserSettingsService: + """Provide UserSettingsService with Redis-backed cache. 
- yield service - - await subscription.close() - logger.info("UserSettingsService cache invalidation subscription stopped") + No pub/sub subscription needed - Redis cache is shared across all instances. + """ + return UserSettingsService(repository, kafka_event_service, logger, redis_client) class AdminServicesProvider(Provider): diff --git a/backend/app/services/user_settings_service.py b/backend/app/services/user_settings_service.py index d784f82f..ddecd63b 100644 --- a/backend/app/services/user_settings_service.py +++ b/backend/app/services/user_settings_service.py @@ -1,8 +1,8 @@ import logging -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone from typing import Any -from cachetools import TTLCache +import redis.asyncio as redis from pydantic import TypeAdapter from app.db.repositories.user_settings_repository import UserSettingsRepository @@ -17,42 +17,48 @@ DomainUserSettingsUpdate, ) from app.services.kafka_event_service import KafkaEventService -from app.services.sse.redis_bus import SSERedisBus _settings_adapter = TypeAdapter(DomainUserSettings) _update_adapter = TypeAdapter(DomainUserSettingsUpdate) class UserSettingsService: + """User settings service with Redis-backed cache. + + Uses Redis directly as cache (single source of truth across all instances). + No pub/sub invalidation needed - all instances read/write same Redis keys. + """ + + _CACHE_KEY_PREFIX = "user_settings:" + _CACHE_TTL_SECONDS = 300 # 5 minutes + def __init__( self, repository: UserSettingsRepository, event_service: KafkaEventService, logger: logging.Logger, - sse_redis_bus: SSERedisBus, + redis_client: redis.Redis, ) -> None: self.repository = repository self.event_service = event_service self.logger = logger - self._sse_bus = sse_redis_bus - self._cache_ttl = timedelta(minutes=5) - self._max_cache_size = 1000 - self._cache: TTLCache[str, DomainUserSettings] = TTLCache( - maxsize=self._max_cache_size, - ttl=self._cache_ttl.total_seconds(), - ) + self._redis = redis_client self.logger.info( "UserSettingsService initialized", - extra={"cache_ttl_seconds": self._cache_ttl.total_seconds(), "max_cache_size": self._max_cache_size}, + extra={"cache_ttl_seconds": self._CACHE_TTL_SECONDS, "cache_backend": "redis"}, ) + def _cache_key(self, user_id: str) -> str: + return f"{self._CACHE_KEY_PREFIX}{user_id}" + async def get_user_settings(self, user_id: str) -> DomainUserSettings: - """Get settings with cache; rebuild and cache on miss.""" - if user_id in self._cache: - cached = self._cache[user_id] - self.logger.debug(f"Settings cache hit for user {user_id}", extra={"cache_size": len(self._cache)}) - return cached + """Get settings with Redis cache; rebuild and cache on miss.""" + cache_key = self._cache_key(user_id) + cached = await self._redis.get(cache_key) + if cached: + self.logger.debug(f"Settings cache hit for user {user_id}") + return DomainUserSettings.model_validate_json(cached) return await self.get_user_settings_fresh(user_id) @@ -72,13 +78,13 @@ async def get_user_settings_fresh(self, user_id: str) -> DomainUserSettings: for event in events: settings = self._apply_event(settings, event) - self._add_to_cache(user_id, settings) + await self._set_cache(user_id, settings) return settings async def update_user_settings( self, user_id: str, updates: DomainUserSettingsUpdate, reason: str | None = None ) -> DomainUserSettings: - """Upsert provided fields into current settings, publish minimal event, and cache.""" + """Upsert provided fields into current settings, publish 
event, and update Redis cache.""" current = await self.get_user_settings(user_id) changes = _update_adapter.dump_python(updates, exclude_none=True) @@ -95,10 +101,9 @@ async def update_user_settings( changes_json = _update_adapter.dump_python(updates, exclude_none=True, mode="json") await self._publish_settings_event(user_id, changes_json, reason) - # Notify other instances to invalidate their caches - await self._sse_bus.publish_internal("user.settings.updated", {"user_id": user_id}) + # Update Redis cache directly - all instances see same cache + await self._set_cache(user_id, new_settings) - self._add_to_cache(user_id, new_settings) if (await self.repository.count_events_since_snapshot(user_id)) >= 10: await self.repository.create_snapshot(new_settings) return new_settings @@ -178,7 +183,7 @@ async def restore_settings_to_point(self, user_id: str, timestamp: datetime) -> settings = self._apply_event(settings, event) await self.repository.create_snapshot(settings) - self._add_to_cache(user_id, settings) + await self._set_cache(user_id, settings) await self.event_service.publish_event( event_type=EventType.USER_SETTINGS_UPDATED, @@ -210,22 +215,16 @@ def _apply_event(self, settings: DomainUserSettings, event: DomainUserSettingsCh async def invalidate_cache(self, user_id: str) -> None: """Invalidate cached settings for a user.""" - if self._cache.pop(user_id, None) is not None: - self.logger.debug(f"Invalidated cache for user {user_id}", extra={"cache_size": len(self._cache)}) - - def _add_to_cache(self, user_id: str, settings: DomainUserSettings) -> None: - """Add settings to TTL+LRU cache.""" - self._cache[user_id] = settings - self.logger.debug(f"Cached settings for user {user_id}", extra={"cache_size": len(self._cache)}) - - def get_cache_stats(self) -> dict[str, Any]: - """Get cache statistics for monitoring.""" - return { - "cache_size": len(self._cache), - "max_cache_size": self._max_cache_size, - "expired_entries": 0, - "cache_ttl_seconds": self._cache_ttl.total_seconds(), - } + cache_key = self._cache_key(user_id) + deleted = await self._redis.delete(cache_key) + if deleted: + self.logger.debug(f"Invalidated cache for user {user_id}") + + async def _set_cache(self, user_id: str, settings: DomainUserSettings) -> None: + """Set settings in Redis cache with TTL.""" + cache_key = self._cache_key(user_id) + await self._redis.setex(cache_key, self._CACHE_TTL_SECONDS, settings.model_dump_json()) + self.logger.debug(f"Cached settings for user {user_id}") async def reset_user_settings(self, user_id: str) -> None: """Reset user settings by deleting all data and cache.""" diff --git a/backend/tests/integration/dlq/test_dlq_manager.py b/backend/tests/integration/dlq/test_dlq_manager.py index 8ee6029f..ec67013c 100644 --- a/backend/tests/integration/dlq/test_dlq_manager.py +++ b/backend/tests/integration/dlq/test_dlq_manager.py @@ -1,12 +1,12 @@ import asyncio -import json import logging import uuid from datetime import datetime, timezone import pytest -from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from aiokafka import AIOKafkaConsumer from app.dlq.manager import DLQManager +from app.dlq.models import DLQMessage from app.domain.enums.events import EventType from app.domain.enums.kafka import KafkaTopic from app.domain.events.typed import DLQMessageReceivedEvent @@ -26,13 +26,28 @@ @pytest.mark.asyncio async def test_dlq_manager_persists_and_emits_event(scope: AsyncContainer, test_settings: Settings) -> None: - """Test that DLQ manager persists messages and emits 
DLQMessageReceivedEvent.""" + """Test that DLQ manager persists messages and emits DLQMessageReceivedEvent. + + Note: DLQManager is now a simple service (not a DI-started consumer). + Message consumption is handled by the FastStream worker (workers/dlq_processor.py). + This test exercises DLQManager.process_message() directly. + """ schema_registry = await scope.get(SchemaRegistryManager) - await scope.get(DLQManager) # Ensure DI starts the manager + dlq_manager = await scope.get(DLQManager) prefix = test_settings.KAFKA_TOPIC_PREFIX ev = make_execution_requested_event(execution_id=f"exec-dlq-persist-{uuid.uuid4().hex[:8]}") + # Create a typed DLQMessage (as FastStream would deserialize it) + dlq_message = DLQMessage( + event=ev, + original_topic=f"{prefix}{KafkaTopic.EXECUTION_EVENTS}", + error="handler failed", + retry_count=0, + failed_at=datetime.now(timezone.utc), + producer_id="tests", + ) + # Future resolves when DLQMessageReceivedEvent is consumed received_future: asyncio.Future[DLQMessageReceivedEvent] = asyncio.get_running_loop().create_future() @@ -42,7 +57,7 @@ async def test_dlq_manager_persists_and_emits_event(scope: AsyncContainer, test_ dlq_events_topic, bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS, group_id=f"test-dlq-events.{uuid.uuid4().hex[:6]}", - auto_offset_reset="earliest", + auto_offset_reset="latest", enable_auto_commit=True, ) @@ -61,33 +76,15 @@ async def consume_dlq_events() -> None: except Exception as e: _test_logger.debug(f"Error deserializing DLQ event: {e}") - payload = { - "event": ev.model_dump(mode="json"), - "original_topic": f"{prefix}{str(KafkaTopic.EXECUTION_EVENTS)}", - "error": "handler failed", - "retry_count": 0, - "failed_at": datetime.now(timezone.utc).isoformat(), - "producer_id": "tests", - } - - # Produce to DLQ topic BEFORE starting consumers (auto_offset_reset="earliest") - producer = AIOKafkaProducer(bootstrap_servers=test_settings.KAFKA_BOOTSTRAP_SERVERS) - await producer.start() - try: - await producer.send_and_wait( - topic=f"{prefix}{str(KafkaTopic.DEAD_LETTER_QUEUE)}", - key=ev.event_id.encode(), - value=json.dumps(payload).encode(), - ) - finally: - await producer.stop() - - # Start consumer for DLQ events + # Start consumer BEFORE processing (auto_offset_reset="latest") await consumer.start() consume_task = asyncio.create_task(consume_dlq_events()) try: - # Manager is already started by DI - just wait for the event + # Process message directly via DLQManager (simulating FastStream handler) + await dlq_manager.process_message(dlq_message) + + # Wait for the emitted event received = await asyncio.wait_for(received_future, timeout=15.0) assert received.dlq_event_id == ev.event_id assert received.event_type == EventType.DLQ_MESSAGE_RECEIVED diff --git a/backend/tests/integration/services/user_settings/test_user_settings_service.py b/backend/tests/integration/services/user_settings/test_user_settings_service.py index 1acb9d2e..302b3aa3 100644 --- a/backend/tests/integration/services/user_settings/test_user_settings_service.py +++ b/backend/tests/integration/services/user_settings/test_user_settings_service.py @@ -36,11 +36,8 @@ async def test_get_update_and_history(scope: AsyncContainer) -> None: # Restore to current point (no-op but tests snapshot + event publish path) _ = await svc.restore_settings_to_point(user_id, datetime.now(timezone.utc)) - # Update wrappers + cache stats + # Update wrappers await svc.update_theme(user_id, Theme.DARK) await svc.update_notification_settings(user_id, DomainNotificationSettings()) await 
svc.update_editor_settings(user_id, DomainEditorSettings(tab_size=2)) await svc.update_custom_setting(user_id, "k", "v") - stats = svc.get_cache_stats() - # Cache size may be 0 due to event bus self-invalidation race condition - assert "cache_size" in stats and stats["cache_size"] >= 0 From afdc67223ec55c7576e086541fc4cd7b64a57f5d Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Wed, 21 Jan 2026 23:19:59 +0100 Subject: [PATCH 14/21] tests fixes --- backend/app/events/schema/schema_registry.py | 8 +++++--- backend/tests/integration/dlq/test_dlq_manager.py | 2 ++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/backend/app/events/schema/schema_registry.py b/backend/app/events/schema/schema_registry.py index e5d62c1a..077c4991 100644 --- a/backend/app/events/schema/schema_registry.py +++ b/backend/app/events/schema/schema_registry.py @@ -95,9 +95,11 @@ async def serialize_event(self, event: DomainEvent) -> bytes: payload: dict[str, Any] = event.model_dump(mode="python", by_alias=False, exclude_unset=False) payload.pop("event_type", None) - # Convert datetime to microseconds for Avro logical type - if "timestamp" in payload and payload["timestamp"] is not None: - payload["timestamp"] = int(payload["timestamp"].timestamp() * 1_000_000) + # Convert all datetime fields to milliseconds for Avro timestamp-millis logical type + from datetime import datetime + for key, value in payload.items(): + if isinstance(value, datetime): + payload[key] = int(value.timestamp() * 1_000) return await self._serializer.encode_record_with_schema(subject, avro_schema_obj, payload) diff --git a/backend/tests/integration/dlq/test_dlq_manager.py b/backend/tests/integration/dlq/test_dlq_manager.py index ec67013c..0343e363 100644 --- a/backend/tests/integration/dlq/test_dlq_manager.py +++ b/backend/tests/integration/dlq/test_dlq_manager.py @@ -78,6 +78,8 @@ async def consume_dlq_events() -> None: # Start consumer BEFORE processing (auto_offset_reset="latest") await consumer.start() + # Small delay to ensure consumer is fully subscribed and ready + await asyncio.sleep(0.5) consume_task = asyncio.create_task(consume_dlq_events()) try: From 6a3a161b4625ea58f21592fdb16b44891aac5b86 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Wed, 21 Jan 2026 23:47:28 +0100 Subject: [PATCH 15/21] added missing providers --- backend/workers/dlq_processor.py | 2 ++ backend/workers/run_pod_monitor.py | 4 ++++ backend/workers/run_result_processor.py | 2 ++ backend/workers/run_saga_orchestrator.py | 2 ++ 4 files changed, 10 insertions(+) diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index ad68f8b5..c3558186 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -17,6 +17,7 @@ from app.core.logging import setup_logger from app.core.providers import ( BoundaryClientProvider, + CoreServicesProvider, EventProvider, LoggingProvider, MessagingProvider, @@ -130,6 +131,7 @@ def main() -> None: container = make_async_container( SettingsProvider(), LoggingProvider(), + CoreServicesProvider(), BoundaryClientProvider(), RedisServicesProvider(), MetricsProvider(), diff --git a/backend/workers/run_pod_monitor.py b/backend/workers/run_pod_monitor.py index bde50517..f8db9609 100644 --- a/backend/workers/run_pod_monitor.py +++ b/backend/workers/run_pod_monitor.py @@ -12,7 +12,9 @@ from app.core.logging import setup_logger from app.core.providers import ( BoundaryClientProvider, + CoreServicesProvider, EventProvider, + KafkaServicesProvider, LoggingProvider, MessagingProvider, 
MetricsProvider, @@ -38,11 +40,13 @@ async def run_pod_monitor(settings: Settings) -> None: container = make_async_container( SettingsProvider(), LoggingProvider(), + CoreServicesProvider(), BoundaryClientProvider(), RedisServicesProvider(), MetricsProvider(), EventProvider(), MessagingProvider(), + KafkaServicesProvider(), RepositoryProvider(), PodMonitorProvider(), context={Settings: settings}, diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index aa5ab748..dbb5f585 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -19,6 +19,7 @@ from app.core.logging import setup_logger from app.core.providers import ( BoundaryClientProvider, + CoreServicesProvider, EventProvider, LoggingProvider, MessagingProvider, @@ -81,6 +82,7 @@ def main() -> None: container = make_async_container( SettingsProvider(), LoggingProvider(), + CoreServicesProvider(), BoundaryClientProvider(), RedisServicesProvider(), MetricsProvider(), diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index ddb5ed8c..00a92bcc 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -20,6 +20,7 @@ from app.core.logging import setup_logger from app.core.providers import ( BoundaryClientProvider, + CoreServicesProvider, EventProvider, LoggingProvider, MessagingProvider, @@ -76,6 +77,7 @@ def main() -> None: container = make_async_container( SettingsProvider(), LoggingProvider(), + CoreServicesProvider(), BoundaryClientProvider(), RedisServicesProvider(), MetricsProvider(), From 594b4ad6a5974cdd1a56d1d6f82c687c85d05415 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 22 Jan 2026 00:14:12 +0100 Subject: [PATCH 16/21] lifespan fixes --- backend/workers/dlq_processor.py | 2 +- backend/workers/run_k8s_worker.py | 2 +- backend/workers/run_result_processor.py | 2 +- backend/workers/run_saga_orchestrator.py | 2 +- backend/workers/run_sse_bridge.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index c3558186..0fc89275 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -151,7 +151,7 @@ def main() -> None: ) @asynccontextmanager - async def lifespan(app: FastStream) -> AsyncIterator[None]: + async def lifespan() -> AsyncIterator[None]: """Initialize infrastructure and start scheduled retry timer.""" app_logger = await container.get(logging.Logger) app_logger.info("DLQ Processor starting...") diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 58e5ac6c..70799274 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -99,7 +99,7 @@ def main() -> None: # Create lifespan for infrastructure initialization @asynccontextmanager - async def lifespan(app: FastStream) -> AsyncIterator[None]: + async def lifespan() -> AsyncIterator[None]: """Initialize infrastructure before app starts.""" app_logger = await container.get(logging.Logger) app_logger.info("KubernetesWorker starting...") diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index dbb5f585..2338f186 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -108,7 +108,7 @@ def main() -> None: # Create lifespan for infrastructure initialization @asynccontextmanager - async def lifespan(app: FastStream) -> AsyncIterator[None]: + 
async def lifespan() -> AsyncIterator[None]: """Initialize infrastructure before app starts.""" app_logger = await container.get(logging.Logger) app_logger.info("ResultProcessor starting...") diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 00a92bcc..f856b6f2 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -102,7 +102,7 @@ def main() -> None: # Create lifespan for infrastructure initialization @asynccontextmanager - async def lifespan(app: FastStream) -> AsyncIterator[None]: + async def lifespan() -> AsyncIterator[None]: """Initialize infrastructure before app starts.""" app_logger = await container.get(logging.Logger) app_logger.info("SagaOrchestrator starting...") diff --git a/backend/workers/run_sse_bridge.py b/backend/workers/run_sse_bridge.py index 2a650aa7..a43f4317 100644 --- a/backend/workers/run_sse_bridge.py +++ b/backend/workers/run_sse_bridge.py @@ -90,7 +90,7 @@ def main() -> None: ) @asynccontextmanager - async def lifespan(app: FastStream) -> AsyncIterator[None]: + async def lifespan() -> AsyncIterator[None]: app_logger = await container.get(logging.Logger) app_logger.info("SSE Bridge starting...") From b599e7137b6435786eafed519f455a11aa9f6ec3 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 22 Jan 2026 00:38:02 +0100 Subject: [PATCH 17/21] faststream decode func fix --- backend/workers/dlq_processor.py | 8 +++++--- backend/workers/run_k8s_worker.py | 8 +++++--- backend/workers/run_result_processor.py | 8 +++++--- backend/workers/run_saga_orchestrator.py | 8 +++++--- backend/workers/run_sse_bridge.py | 8 +++++--- 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index 0fc89275..272fa3b7 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -13,6 +13,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from datetime import datetime, timezone +from typing import Any from app.core.logging import setup_logger from app.core.providers import ( @@ -39,6 +40,7 @@ from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream from faststream.kafka import KafkaBroker +from faststream.message import StreamMessage from pymongo.asynchronous.mongo_client import AsyncMongoClient @@ -175,9 +177,9 @@ async def lifespan() -> AsyncIterator[None]: _configure_filters(manager, testing=settings.TESTING, logger=app_logger) app_logger.info("DLQ Manager configured") - # Decoder: JSON bytes → typed DLQMessage - def decode_dlq_json(body: bytes) -> DLQMessage: - data = json.loads(body) + # Decoder: JSON message → typed DLQMessage + def decode_dlq_json(msg: StreamMessage[Any]) -> DLQMessage: + data = json.loads(msg.body) return DLQMessage.model_validate(data) # Register subscriber for DLQ messages diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 70799274..7d02259c 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -15,6 +15,7 @@ import logging from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from typing import Any from app.core.logging import setup_logger from app.core.providers import ( @@ -43,6 +44,7 @@ from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream from faststream.kafka import KafkaBroker +from faststream.message import StreamMessage 
def main() -> None: @@ -112,9 +114,9 @@ async def lifespan() -> AsyncIterator[None]: schema_registry = await container.get(SchemaRegistryManager) await logic.ensure_image_pre_puller_daemonset() - # Decoder: Avro bytes → typed DomainEvent - async def decode_avro(body: bytes) -> DomainEvent: - return await schema_registry.deserialize_event(body, "k8s_worker") + # Decoder: Avro message → typed DomainEvent + async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: + return await schema_registry.deserialize_event(msg.body, "k8s_worker") # Create subscriber with Avro decoder subscriber = broker.subscriber( diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index 2338f186..ab471f74 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -15,6 +15,7 @@ import logging from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from typing import Any from app.core.logging import setup_logger from app.core.providers import ( @@ -49,6 +50,7 @@ from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream from faststream.kafka import KafkaBroker +from faststream.message import StreamMessage from pymongo.asynchronous.mongo_client import AsyncMongoClient @@ -126,9 +128,9 @@ async def lifespan() -> AsyncIterator[None]: await container.get(UnifiedProducer) app_logger.info("Kafka producer ready") - # Decoder: Avro bytes → typed DomainEvent - async def decode_avro(body: bytes) -> DomainEvent: - return await schema_registry.deserialize_event(body, "result_processor") + # Decoder: Avro message → typed DomainEvent + async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: + return await schema_registry.deserialize_event(msg.body, "result_processor") # Create subscriber with Avro decoder subscriber = broker.subscriber( diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index f856b6f2..aac1a16e 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -16,6 +16,7 @@ import time from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from typing import Any from app.core.logging import setup_logger from app.core.providers import ( @@ -44,6 +45,7 @@ from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream from faststream.kafka import KafkaBroker +from faststream.message import StreamMessage from pymongo.asynchronous.mongo_client import AsyncMongoClient @@ -135,9 +137,9 @@ async def lifespan() -> AsyncIterator[None]: topics = [f"{settings.KAFKA_TOPIC_PREFIX}{t}" for t in trigger_topics] group_id = f"{GroupId.SAGA_ORCHESTRATOR}.{settings.KAFKA_GROUP_SUFFIX}" - # Decoder: Avro bytes → typed DomainEvent - async def decode_avro(body: bytes) -> DomainEvent: - return await schema_registry.deserialize_event(body, "saga_orchestrator") + # Decoder: Avro message → typed DomainEvent + async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: + return await schema_registry.deserialize_event(msg.body, "saga_orchestrator") # Register handler dynamically after determining topics # Saga orchestrator uses single handler - routing is internal to SagaLogic diff --git a/backend/workers/run_sse_bridge.py b/backend/workers/run_sse_bridge.py index a43f4317..f155aa09 100644 --- a/backend/workers/run_sse_bridge.py +++ b/backend/workers/run_sse_bridge.py @@ -2,6 +2,7 @@ import logging from collections.abc 
import AsyncIterator from contextlib import asynccontextmanager +from typing import Any import redis.asyncio as redis from app.core.logging import setup_logger @@ -24,6 +25,7 @@ from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream from faststream.kafka import KafkaBroker +from faststream.message import StreamMessage class SSEBridgeProvider(Provider): @@ -97,9 +99,9 @@ async def lifespan() -> AsyncIterator[None]: # Resolve schema registry (initialization handled by provider) schema_registry = await container.get(SchemaRegistryManager) - # Decoder: Avro bytes → typed DomainEvent - async def decode_avro(body: bytes) -> DomainEvent: - return await schema_registry.deserialize_event(body, "sse_bridge") + # Decoder: Avro message → typed DomainEvent + async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: + return await schema_registry.deserialize_event(msg.body, "sse_bridge") # Single handler for all SSE-relevant events # No filter needed - we check event_type in handler since route_event handles all types From fd3c18c88d42998abae6c9e9150e99d32db25230 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 22 Jan 2026 00:56:31 +0100 Subject: [PATCH 18/21] removed coordinator block - now load is handled inside kafka/redis --- .github/workflows/stack-tests.yml | 3 -- backend/workers/Dockerfile.coordinator | 8 --- docker-compose.yaml | 26 --------- docs/components/workers/coordinator.md | 74 -------------------------- mkdocs.yml | 1 - 5 files changed, 112 deletions(-) delete mode 100644 backend/workers/Dockerfile.coordinator delete mode 100644 docs/components/workers/coordinator.md diff --git a/.github/workflows/stack-tests.yml b/.github/workflows/stack-tests.yml index 8711ea9f..1afa388c 100644 --- a/.github/workflows/stack-tests.yml +++ b/.github/workflows/stack-tests.yml @@ -139,7 +139,6 @@ jobs: - name: Build all images run: | docker build -t integr8scode-backend:latest --build-context base=docker-image://integr8scode-base:latest -f ./backend/Dockerfile ./backend - docker build -t integr8scode-coordinator:latest -f backend/workers/Dockerfile.coordinator --build-context base=docker-image://integr8scode-base:latest ./backend docker build -t integr8scode-k8s-worker:latest -f backend/workers/Dockerfile.k8s_worker --build-context base=docker-image://integr8scode-base:latest ./backend docker build -t integr8scode-pod-monitor:latest -f backend/workers/Dockerfile.pod_monitor --build-context base=docker-image://integr8scode-base:latest ./backend docker build -t integr8scode-result-processor:latest -f backend/workers/Dockerfile.result_processor --build-context base=docker-image://integr8scode-base:latest ./backend @@ -169,7 +168,6 @@ jobs: run: | docker save \ integr8scode-backend:latest \ - integr8scode-coordinator:latest \ integr8scode-k8s-worker:latest \ integr8scode-pod-monitor:latest \ integr8scode-result-processor:latest \ @@ -319,7 +317,6 @@ jobs: docker compose logs > logs/docker-compose.log 2>&1 docker compose logs backend > logs/backend.log 2>&1 docker compose logs kafka > logs/kafka.log 2>&1 - docker compose logs coordinator > logs/coordinator.log 2>&1 || true docker compose logs k8s-worker > logs/k8s-worker.log 2>&1 || true kubectl get events --sort-by='.metadata.creationTimestamp' -A > logs/k8s-events.log 2>&1 || true diff --git a/backend/workers/Dockerfile.coordinator b/backend/workers/Dockerfile.coordinator deleted file mode 100644 index ae97091b..00000000 --- a/backend/workers/Dockerfile.coordinator +++ /dev/null @@ -1,8 +0,0 @@ -# 
Coordinator worker -FROM base - -# Copy application code -COPY . . - -# Run the coordinator service -CMD ["python", "-m", "workers.run_coordinator"] diff --git a/docker-compose.yaml b/docker-compose.yaml index ed5a3379..86b75f95 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -409,32 +409,6 @@ services: restart: "no" # Run once and exit # Event-driven workers - coordinator: - build: - context: ./backend - dockerfile: workers/Dockerfile.coordinator - additional_contexts: - base: service:base - container_name: coordinator - depends_on: - base: - condition: service_completed_successfully - kafka-init: - condition: service_completed_successfully - mongo: - condition: service_started - env_file: - - ./backend/.env - environment: - - TRACING_SERVICE_NAME=execution-coordinator - - KAFKA_CONSUMER_GROUP_ID=execution-coordinator - volumes: - - ./backend/app:/app/app:ro - - ./backend/workers:/app/workers:ro - networks: - - app-network - restart: unless-stopped - k8s-worker: build: context: ./backend diff --git a/docs/components/workers/coordinator.md b/docs/components/workers/coordinator.md deleted file mode 100644 index 1020c3b6..00000000 --- a/docs/components/workers/coordinator.md +++ /dev/null @@ -1,74 +0,0 @@ -# Coordinator - -The coordinator owns admission and queuing policy for executions. It decides which executions can proceed based on -available resources and enforces per-user limits to prevent any single user from monopolizing the system. - -```mermaid -graph LR - Kafka[(Kafka)] --> Coord[Coordinator] - Coord --> Queue[Priority Queue] - Coord --> Resources[Resource Pool] - Coord --> Accepted[Accepted Events] - Accepted --> Kafka -``` - -## How it works - -When an `ExecutionRequested` event arrives, the coordinator checks: - -1. Is the queue full? (max 10,000 pending) -2. Has this user exceeded their limit? (max 100 concurrent) -3. Are there enough CPU and memory resources? - -If all checks pass, the coordinator allocates resources and publishes `ExecutionAccepted`. Otherwise, the request -is either queued for later or rejected. - -The coordinator runs a background scheduling loop that continuously pulls from the priority queue and attempts to -schedule pending executions as resources become available. - -## Priority queue - -Executions are processed in priority order. Lower numeric values are processed first: - -```python ---8<-- "backend/app/services/coordinator/queue_manager.py:14:19" -``` - -When resources are unavailable, executions are requeued with reduced priority to prevent starvation. 
- -## Resource management - -The coordinator tracks a pool of CPU and memory resources: - -| Parameter | Default | Description | -|---------------------------|---------|----------------------------| -| `total_cpu_cores` | 32 | Total CPU pool | -| `total_memory_mb` | 65,536 | Total memory pool (64GB) | -| `overcommit_factor` | 1.2 | Allow 20% overcommit | -| `max_queue_size` | 10,000 | Maximum pending executions | -| `max_executions_per_user` | 100 | Per-user limit | -| `stale_timeout_seconds` | 3,600 | Stale execution timeout | - -## Topics - -- **Consumes**: `execution_events` (requested, completed, failed, cancelled) -- **Produces**: `execution_events` (accepted) - -## Key files - -| File | Purpose | -|--------------------------------------------------------------------------------------------------------------------------------|-------------------------------| -| [`run_coordinator.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/workers/run_coordinator.py) | Entry point | -| [`coordinator.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/services/coordinator/coordinator.py) | Main coordinator service | -| [`queue_manager.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/services/coordinator/queue_manager.py) | Priority queue implementation | -| [`resource_manager.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/services/coordinator/resource_manager.py) | Resource pool and allocation | - -## Deployment - -```yaml -coordinator: - build: - dockerfile: workers/Dockerfile.coordinator -``` - -Usually runs as a single replica. Leader election via Redis is available if scaling is needed. diff --git a/mkdocs.yml b/mkdocs.yml index 3b29cd08..4b32ca58 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -140,7 +140,6 @@ nav: - K8s Worker: components/workers/k8s_worker.md - Pod Monitor: components/workers/pod_monitor.md - Result Processor: components/workers/result_processor.md - - Coordinator: components/workers/coordinator.md - Event Replay: components/workers/event_replay.md - DLQ Processor: components/workers/dlq_processor.md - Saga: From ef7c044248722104186e58e76b29969b2deaae01 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 22 Jan 2026 01:25:05 +0100 Subject: [PATCH 19/21] updated docs, updated AckPolicy for workers, updated log level for schmea registry - from now warning max --- backend/workers/dlq_processor.py | 3 +- backend/workers/run_k8s_worker.py | 3 +- backend/workers/run_result_processor.py | 3 +- backend/workers/run_saga_orchestrator.py | 3 +- backend/workers/run_sse_bridge.py | 3 +- docker-compose.yaml | 1 + docs/architecture/execution-queue.md | 122 ------------------ docs/architecture/kafka-topic-architecture.md | 16 +-- docs/architecture/lifecycle.md | 6 +- docs/architecture/services-overview.md | 10 +- docs/components/sse/execution-sse-flow.md | 4 +- docs/components/workers/index.md | 86 +++++++++++- docs/index.md | 3 +- docs/operations/deployment.md | 2 +- docs/operations/metrics-reference.md | 23 +--- docs/operations/tracing.md | 2 +- mkdocs.yml | 1 - 17 files changed, 112 insertions(+), 179 deletions(-) delete mode 100644 docs/architecture/execution-queue.md diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index 272fa3b7..93bff458 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -39,6 +39,7 @@ from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream 
+from faststream.broker.message import AckPolicy from faststream.kafka import KafkaBroker from faststream.message import StreamMessage from pymongo.asynchronous.mongo_client import AsyncMongoClient @@ -186,7 +187,7 @@ def decode_dlq_json(msg: StreamMessage[Any]) -> DLQMessage: @broker.subscriber( *topics, group_id=group_id, - auto_commit=False, + ack_policy=AckPolicy.ACK, decoder=decode_dlq_json, ) async def handle_dlq_message( diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 7d02259c..84e1da3a 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -43,6 +43,7 @@ from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream +from faststream.broker.message import AckPolicy from faststream.kafka import KafkaBroker from faststream.message import StreamMessage @@ -122,7 +123,7 @@ async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: subscriber = broker.subscriber( *topics, group_id=group_id, - auto_commit=False, + ack_policy=AckPolicy.ACK, decoder=decode_avro, ) diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index ab471f74..d1be0842 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -49,6 +49,7 @@ from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream +from faststream.broker.message import AckPolicy from faststream.kafka import KafkaBroker from faststream.message import StreamMessage from pymongo.asynchronous.mongo_client import AsyncMongoClient @@ -136,7 +137,7 @@ async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: subscriber = broker.subscriber( *topics, group_id=group_id, - auto_commit=False, + ack_policy=AckPolicy.ACK, decoder=decode_avro, ) diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index aac1a16e..5429e227 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -44,6 +44,7 @@ from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream +from faststream.broker.message import AckPolicy from faststream.kafka import KafkaBroker from faststream.message import StreamMessage from pymongo.asynchronous.mongo_client import AsyncMongoClient @@ -146,7 +147,7 @@ async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: @broker.subscriber( *topics, group_id=group_id, - auto_commit=False, + ack_policy=AckPolicy.ACK, decoder=decode_avro, ) async def handle_saga_event( diff --git a/backend/workers/run_sse_bridge.py b/backend/workers/run_sse_bridge.py index f155aa09..48e23887 100644 --- a/backend/workers/run_sse_bridge.py +++ b/backend/workers/run_sse_bridge.py @@ -24,6 +24,7 @@ from dishka import Provider, Scope, make_async_container, provide from dishka.integrations.faststream import FromDishka, setup_dishka from faststream import FastStream +from faststream.broker.message import AckPolicy from faststream.kafka import KafkaBroker from faststream.message import StreamMessage @@ -108,7 +109,7 @@ async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: @broker.subscriber( *topics, group_id=group_id, - auto_commit=True, # SSE bridge is idempotent (Redis pubsub) + ack_policy=AckPolicy.ACK_FIRST, # SSE bridge is idempotent (Redis pubsub) decoder=decode_avro, ) 
async def handle_sse_event( diff --git a/docker-compose.yaml b/docker-compose.yaml index 86b75f95..c63bf097 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -339,6 +339,7 @@ services: SCHEMA_REGISTRY_HOST_NAME: schema-registry SCHEMA_REGISTRY_KAFKASTORE_BOOTSTRAP_SERVERS: kafka:29092 SCHEMA_REGISTRY_LISTENERS: http://0.0.0.0:8081 + SCHEMA_REGISTRY_LOG4J_ROOT_LOGLEVEL: WARN networks: - app-network healthcheck: diff --git a/docs/architecture/execution-queue.md b/docs/architecture/execution-queue.md deleted file mode 100644 index 876ff6df..00000000 --- a/docs/architecture/execution-queue.md +++ /dev/null @@ -1,122 +0,0 @@ -# Execution Queue - -The ExecutionCoordinator manages a priority queue for script executions, allocating CPU and memory resources before -spawning pods. It consumes `ExecutionRequested` events, validates resource availability, and emits commands to the -Kubernetes worker via the saga system. Per-user limits and stale timeout handling prevent queue abuse. - -## Architecture - -```mermaid -flowchart TB - subgraph Kafka - REQ[ExecutionRequested Event] --> COORD[ExecutionCoordinator] - COORD --> CMD[CreatePodCommand] - RESULT[Completed/Failed Events] --> COORD - end - - subgraph Coordinator - COORD --> QUEUE[QueueManager] - COORD --> RESOURCES[ResourceManager] - QUEUE --> HEAP[(Priority Heap)] - RESOURCES --> POOL[(Resource Pool)] - end - - subgraph Scheduling Loop - LOOP[Get Next Execution] --> CHECK{Resources Available?} - CHECK -->|Yes| ALLOCATE[Allocate Resources] - CHECK -->|No| REQUEUE[Requeue Execution] - ALLOCATE --> PUBLISH[Publish CreatePodCommand] - end -``` - -## Queue Priority - -Executions enter the queue with one of five priority levels. Lower numeric values are processed first: - -```python ---8<-- "backend/app/services/coordinator/queue_manager.py:14:19" -``` - -The queue uses Python's `heapq` module, which efficiently maintains the priority ordering. When resources are -unavailable, executions are requeued with reduced priority to prevent starvation of lower-priority work. - -## Per-User Limits - -The queue enforces per-user execution limits to prevent a single user from monopolizing resources: - -```python ---8<-- "backend/app/services/coordinator/queue_manager.py:42:54" -``` - -When a user exceeds their limit, new execution requests are rejected with an error message indicating the limit has been -reached. - -## Stale Timeout - -Executions that sit in the queue too long (default 1 hour) are automatically removed by a background cleanup task. This -prevents abandoned requests from consuming queue space indefinitely: - -```python ---8<-- "backend/app/services/coordinator/queue_manager.py:243:267" -``` - -## Resource Allocation - -The ResourceManager tracks a pool of CPU, memory, and GPU resources. Each execution requests an allocation based on -language defaults or explicit requirements: - -```python ---8<-- "backend/app/services/coordinator/resource_manager.py:121:130" -``` - -The pool maintains minimum reserve thresholds to ensure the system remains responsive even under heavy load. Allocations -that would exceed the safe threshold are rejected, and the execution is requeued for later processing. 
- -```python ---8<-- "backend/app/services/coordinator/resource_manager.py:135:148" -``` - -## Scheduling Loop - -The coordinator runs a background scheduling loop that continuously pulls executions from the queue and attempts to -schedule them: - -```python ---8<-- "backend/app/services/coordinator/coordinator.py:307:323" -``` - -A semaphore limits concurrent scheduling operations to prevent overwhelming the system during bursts of incoming -requests. - -## Event Flow - -The coordinator handles several event types: - -1. **ExecutionRequested** - Adds execution to queue, publishes `ExecutionAccepted` -2. **ExecutionCancelled** - Removes from queue, releases resources if allocated -3. **ExecutionCompleted** - Releases allocated resources -4. **ExecutionFailed** - Releases allocated resources - -When scheduling succeeds, the coordinator publishes a `CreatePodCommand` to the saga topic, triggering pod creation by -the Kubernetes worker. - -## Configuration - -| Parameter | Default | Description | -|-------------------------------|---------|--------------------------------------| -| `max_queue_size` | 10000 | Maximum executions in queue | -| `max_executions_per_user` | 100 | Per-user queue limit | -| `stale_timeout_seconds` | 3600 | When to discard old executions | -| `max_concurrent_scheduling` | 10 | Parallel scheduling operations | -| `scheduling_interval_seconds` | 0.5 | Polling interval when queue is empty | -| `total_cpu_cores` | 32.0 | Total CPU pool | -| `total_memory_mb` | 65536 | Total memory pool (64GB) | -| `overcommit_factor` | 1.2 | Allow 20% resource overcommit | - -## Key Files - -| File | Purpose | -|--------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------| -| [`services/coordinator/coordinator.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/services/coordinator/coordinator.py) | Main coordinator service | -| [`services/coordinator/queue_manager.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/services/coordinator/queue_manager.py) | Priority queue implementation | -| [`services/coordinator/resource_manager.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/services/coordinator/resource_manager.py) | Resource pool and allocation | diff --git a/docs/architecture/kafka-topic-architecture.md b/docs/architecture/kafka-topic-architecture.md index ac948c3c..065e81c1 100644 --- a/docs/architecture/kafka-topic-architecture.md +++ b/docs/architecture/kafka-topic-architecture.md @@ -15,21 +15,21 @@ The system uses *two separate Kafka topics* for execution flow: `execution_event Multiple services consume this topic: SSE streams updates to users, projection service maintains read-optimized views, saga orchestrator manages workflows, monitoring tracks health. These consumers care about *completeness and ordering* because they're building a comprehensive picture of system state. -**execution_tasks** is a *work queue*. It contains only events representing actual work to be done — executions that have been validated, authorized, rate-limited, and scheduled. When the coordinator publishes to `execution_tasks`, it's saying "this needs to be done now" rather than "this happened." The Kubernetes worker, the *sole consumer* of this topic, just needs to know what pods to create. +**execution_tasks** is a *work queue*. 
It contains only events representing actual work to be done — executions that have been validated, authorized, rate-limited, and scheduled. When the saga orchestrator publishes to `execution_tasks`, it's saying "this needs to be done now" rather than "this happened." The Kubernetes worker, the *sole consumer* of this topic, just needs to know what pods to create. ## Request flow When a user submits code, the API creates an `ExecutionRequestedEvent` and publishes it to `execution_events`. This acknowledges the request and makes it part of the permanent record. -The coordinator subscribes to `execution_events` and begins validation: +The saga orchestrator subscribes to `execution_events` and begins processing: -- Has the user exceeded their rate limit? +- Has the execution passed rate limiting (handled at API level)? - Are sufficient resources available? -- Should this execution be prioritized or queued? +- What state is the execution workflow in? -Some requests get rejected immediately. Others sit in a priority queue waiting for resources. Still others get cancelled before starting. +Some requests get rejected at the API level. Others proceed through the saga workflow. The saga orchestrator tracks state transitions and issues commands. -Only when the coordinator determines an execution is *ready to proceed* does it republish to `execution_tasks`. This represents a state transition — the event has moved from being a request to being *scheduled work*. +Only when the saga orchestrator determines an execution is *ready to proceed* does it publish to `execution_tasks`. This represents a state transition — the event has moved from being a request to being *scheduled work*. The Kubernetes worker then consumes from `execution_tasks`, creates resources (ConfigMaps, Pods), and publishes a `PodCreatedEvent` back to `execution_events`. It doesn't need to know about rate limits or queuing — all that complexity has been handled upstream. @@ -62,11 +62,11 @@ Monitoring becomes more precise. Different SLAs for different stages: ## Failure handling -If the Kubernetes worker crashes, `execution_tasks` accumulates messages but the rest of the system continues normally. Users can submit executions, the coordinator validates and queues them, other services process `execution_events`. When the worker recovers, it picks up where it left off. +If the Kubernetes worker crashes, `execution_tasks` accumulates messages but the rest of the system continues normally. Users can submit executions, the saga orchestrator processes events, other services consume `execution_events`. When the worker recovers, it picks up where it left off. In a single-topic architecture, a slow worker would cause backpressure affecting *all* consumers. SSE might delay updates. Projections might fall behind. The entire system degrades because one component can't keep up. -The coordinator acts as a *shock absorber* between user requests and pod creation. It can implement queuing, prioritization, and resource management without affecting upstream producers or downstream workers. During cluster capacity issues, the coordinator holds executions in its internal queue while still acknowledging receipt. +The saga orchestrator acts as a *shock absorber* between user requests and pod creation. It manages workflow state and resource coordination without affecting upstream producers or downstream workers. Kafka provides natural backpressure and queuing, while Redis-backed rate limiting handles capacity management at the API level. 
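To make the record/work-queue split described above concrete, here is a minimal sketch of the hand-off, assuming a FastStream consumer. The topic names follow this doc, but the readiness check and the event shape are simplified placeholders, not the project's real saga logic.

```python
from faststream import FastStream
from faststream.kafka import KafkaBroker

broker = KafkaBroker("localhost:9092")
app = FastStream(broker)


def is_ready(event: dict) -> bool:
    # Placeholder for the orchestrator's checks (resources available, workflow state).
    return True


@broker.subscriber("execution_events", group_id="orchestrator-sketch")
async def on_lifecycle_event(event: dict) -> None:
    # Every lifecycle event is recorded on execution_events; only validated,
    # schedulable work is republished to the work queue the Kubernetes worker consumes.
    if event.get("event_type") == "execution_requested" and is_ready(event):
        await broker.publish(event, topic="execution_tasks")
```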
## Extensibility diff --git a/docs/architecture/lifecycle.md b/docs/architecture/lifecycle.md index 21b3aef9..7ba2550b 100644 --- a/docs/architecture/lifecycle.md +++ b/docs/architecture/lifecycle.md @@ -8,11 +8,11 @@ The pattern that actually fits Python and asyncio is the language's own RAII: as ## What changed -Services with long-running background work now implement the async context manager protocol. Coordinator, KubernetesWorker, PodMonitor, SSE Kafka→Redis bridge, EventStoreConsumer, ResultProcessor, DLQManager, EventBus, and the Kafka producer all expose `__aenter__`/`__aexit__` that call `start`/`stop`. +Services with long-running background work now implement the async context manager protocol. KubernetesWorker, PodMonitor, SSE Kafka→Redis bridge, EventStoreConsumer, ResultProcessor, DLQManager, EventBus, and the Kafka producer all expose `__aenter__`/`__aexit__` that call `start`/`stop`. DI providers return unstarted instances for these services. The FastAPI lifespan acquires them and uses an `AsyncExitStack` to start/stop them in a single place. That removed scattered start/stop logic from providers and made shutdown order explicit. -Worker entrypoints (coordinator, k8s-worker, pod-monitor, event-replay, result-processor, dlq-processor) use `AsyncExitStack` as well. No more `if 'x' in locals()` cleanups or nested with statements. Each runner acquires the services it needs, enters them in the stack, and blocks. When it's time to exit, everything stops in reverse order. +Worker entrypoints (saga-orchestrator, k8s-worker, pod-monitor, event-replay, result-processor, dlq-processor, sse-bridge) use `AsyncExitStack` as well. No more `if 'x' in locals()` cleanups or nested with statements. Each runner acquires the services it needs, enters them in the stack, and blocks. When it's time to exit, everything stops in reverse order. ## Why this is better @@ -29,7 +29,7 @@ Use an `AsyncExitStack` at the call site: ```python async with AsyncExitStack() as stack: await stack.enter_async_context(producer) - await stack.enter_async_context(coordinator) + await stack.enter_async_context(consumer) # add more services as needed await asyncio.Event().wait() ``` diff --git a/docs/architecture/services-overview.md b/docs/architecture/services-overview.md index bce59980..8b5620ff 100644 --- a/docs/architecture/services-overview.md +++ b/docs/architecture/services-overview.md @@ -4,7 +4,7 @@ This document explains what lives under `backend/app/services/`, what each servi ## High-level architecture -The API (FastAPI) receives user requests for auth, execute, events, scripts, and settings. The Coordinator accepts validated execution requests and enqueues them to Kafka with metadata and idempotency guards. The Saga Orchestrator drives stateful execution via events and publishes commands to the K8s Worker. The K8s Worker builds and creates per-execution pods and supporting ConfigMaps with network isolation enforced at cluster level via Cilium policy. Pod Monitor watches K8s and translates pod phases and logs into domain events. Result Processor consumes completion/failure/timeout events, updates DB, and cleans resources. SSE Router fans execution events out to connected clients. DLQ Processor and Event Replay support reliability and investigations. +The API (FastAPI) receives user requests for auth, execute, events, scripts, and settings. The Saga Orchestrator drives stateful execution via events and publishes commands to the K8s Worker. 
The K8s Worker builds and creates per-execution pods and supporting ConfigMaps with network isolation enforced at cluster level via Cilium policy. Pod Monitor watches K8s and translates pod phases and logs into domain events. Result Processor consumes completion/failure/timeout events, updates DB, and cleans resources. SSE Router fans execution events out to connected clients. DLQ Processor and Event Replay support reliability and investigations. ## Event streams @@ -12,8 +12,6 @@ EXECUTION_EVENTS carries lifecycle updates like queued, started, running, and ca ## Execution pipeline services -The coordinator/ module contains QueueManager which maintains an in-memory view of pending executions with priorities, aging, and backpressure. It doesn't own metrics for queue depth (that's centralized in coordinator metrics) and doesn't publish commands directly, instead emitting events for the Saga Orchestrator to process. This provides fairness, limits, and stale-job cleanup in one place while preventing double publications. - The saga/ module has ExecutionSaga which encodes the multi-step execution flow from receiving a request through creating a pod command, observing pod outcomes, and committing the result. The Saga Orchestrator subscribes to EXECUTION events, reconstructs sagas, and issues SAGA_COMMANDS to the worker with goals of idempotency across restarts, clean compensation on failure, and avoiding duplicate side-effects. The k8s_worker/ module runs worker.py, a long-running service that consumes SAGA_COMMANDS and creates per-execution resources including ConfigMaps for script and entrypoint, and Pod manifests with hardened security context. It no longer creates per-execution NetworkPolicies since network isolation is managed by a static Cilium policy in the target namespace, and it refuses to run in the default namespace to avoid policy gaps. The pod_builder.py produces ConfigMaps and V1Pod specs with non-root user, read-only root FS, all capabilities dropped, seccomp RuntimeDefault, DNS disabled, and no service links or tokens. @@ -60,8 +58,6 @@ The Result Processor persists terminal execution outcomes, updates metrics, and The Pod Monitor observes K8s pod state and translates to domain events. It watches CoreV1 Pod events and publishes EXECUTION_EVENTS for running, container started, logs tail, etc., adding useful metadata and best-effort failure analysis. -The Coordinator owns the admission/queuing policy, sets priorities, and gates starts based on capacity. It interacts with ExecutionService (API) and Saga Orchestrator (events), ensuring queue depth metrics reflect only user requests and avoiding negative values via single ownership of the counter. - The Event Replay worker re-emits stored events to debug or rebuild projections, taking DB/event store and filters as inputs and outputting replayed events on regular topics with provenance markers. The DLQ Processor drains and retries dead-lettered messages with backoff and visibility, taking DLQ topic and retry policies as inputs and outputting successful re-publishes or parked messages with audit trail. See [Dead Letter Queue](../components/dead-letter-queue.md) for more on DLQ handling. @@ -70,11 +66,11 @@ The DLQ Processor drains and retries dead-lettered messages with backoff and vis The worker refuses to run in the default namespace. Use the setup script to apply the Cilium policy in a dedicated namespace and run the worker there. Apply `backend/k8s/policies/executor-deny-all-cnp.yaml` or use `scripts/setup_k8s.sh `. 
All executor pods are labeled `app=integr8s, component=executor` and are covered by the static deny-all policy. See [Security Policies](../security/policies.md) for details on network isolation. -Sagas and consumers use content-hash keys by default to avoid duplicates on restarts. Coordinator centralizes queue depth metrics, Result Processor normalizes error types, and Rate Limit service emits rich diagnostics even when disabled. +Sagas and consumers use content-hash keys by default to avoid duplicates on restarts. Result Processor normalizes error types, and Rate Limit service emits rich diagnostics even when disabled. ## Common flows -The main execution flow goes: User → API → Coordinator → Saga Orchestrator → K8s Worker → Pod → Pod Monitor → Result Processor. See [Lifecycle](lifecycle.md) for the full execution state machine. +The main execution flow goes: User → API → Saga Orchestrator → K8s Worker → Pod → Pod Monitor → Result Processor. See [Lifecycle](lifecycle.md) for the full execution state machine. For executing a script, a POST to `/api/v1/execute` triggers validation and enqueues EXECUTION_REQUESTED. The Saga issues CreatePodCommandEvent, the Worker creates ConfigMap and Pod, Pod Monitor emits running/progress events, and Result Processor persists the outcome and triggers cleanup on completion, failure, or timeout. diff --git a/docs/components/sse/execution-sse-flow.md b/docs/components/sse/execution-sse-flow.md index a81f057d..3cced307 100644 --- a/docs/components/sse/execution-sse-flow.md +++ b/docs/components/sse/execution-sse-flow.md @@ -2,13 +2,13 @@ The system uses an event-driven pipeline to run code and stream progress to the browser with Server-Sent Events. The Editor subscribes to a per-execution SSE stream and renders updates as the execution advances. When the result is ready, the stream delivers a `result_stored` event that already carries the final payload, so the browser can render immediately. -The flow starts when the API receives a request to execute a script. It writes an execution record in MongoDB and publishes an `execution_requested` event to Kafka. The coordinator and saga orchestrator pick it up, allocate resources, and send a command to the Kubernetes worker. The worker creates the pod and emits `pod_created`/`pod_running`. The pod monitor watches Kubernetes and emits pod and terminal execution events. The result processor listens for terminal events, writes the final state into the executions collection, and publishes a `result_stored` event. +The flow starts when the API receives a request to execute a script. It writes an execution record in MongoDB and publishes an `execution_requested` event to Kafka. The saga orchestrator picks it up, validates resources, and sends a command to the Kubernetes worker. The worker creates the pod and emits `pod_created`/`pod_running`. The pod monitor watches Kubernetes and emits pod and terminal execution events. The result processor listens for terminal events, writes the final state into the executions collection, and publishes a `result_stored` event. The SSE router maintains a small pool of Kafka consumers and routes only the events that belong to a given execution into Redis, which backs the browser's SSE connection. Progress events like `pod_running` arrive quickly and render status changes without polling. The stream closes when `result_stored` is delivered — emitted only after the result processor has written the final output to MongoDB, with the event containing the final result payload. 
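As a rough illustration of the per-execution fan-out described above, the sketch below publishes an event onto a per-execution Redis channel that would back the browser's SSE connection. The channel naming scheme and payload fields are assumptions made for illustration, not the project's actual schema.

```python
import json

import redis.asyncio as redis


async def route_to_sse(event: dict, r: redis.Redis) -> None:
    execution_id = event.get("execution_id")
    if not execution_id:
        return  # event is not tied to a specific execution, nothing to stream
    # One pub/sub channel per execution backs that browser's SSE connection.
    await r.publish(f"sse:execution:{execution_id}", json.dumps(event))
```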
Using `result_stored` as the terminal signal removes artificial waiting. Earlier iterations ended the SSE stream on `execution_completed`/`failed`/`timeout` and slept on the server to "give Mongo time" to commit. That pause is unnecessary once the stream ends only after the result processor confirms persistence. -This approach preserves clean attribution and ordering. The coordinator enriches pod creation commands with user information so pods are labeled correctly. The pod monitor converts Kubernetes phases into domain events. Timeout classification is deterministic: any pod finishing with `reason=DeadlineExceeded` results in an `execution_timeout` event. The result processor is the single writer of terminal state, so the UI never races the database — when the browser sees `result_stored`, the result is already present. +This approach preserves clean attribution and ordering. The saga orchestrator enriches pod creation commands with user information so pods are labeled correctly. The pod monitor converts Kubernetes phases into domain events. Timeout classification is deterministic: any pod finishing with `reason=DeadlineExceeded` results in an `execution_timeout` event. The result processor is the single writer of terminal state, so the UI never races the database — when the browser sees `result_stored`, the result is already present. ## Related docs diff --git a/docs/components/workers/index.md b/docs/components/workers/index.md index fe14d663..2dafd272 100644 --- a/docs/components/workers/index.md +++ b/docs/components/workers/index.md @@ -10,8 +10,6 @@ MongoDB and Redis provide shared state where needed. ```mermaid graph LR API[Backend API] -->|execution_requested| Kafka[(Kafka)] - Kafka --> Coord[Coordinator] - Coord -->|execution_accepted| Kafka Kafka --> Saga[Saga Orchestrator] Saga -->|create_pod_command| Kafka Kafka --> K8s[K8s Worker] @@ -21,22 +19,97 @@ graph LR PodMon -->|execution_completed| Kafka Kafka --> Result[Result Processor] Result --> Mongo[(MongoDB)] + Kafka --> SSE[SSE Bridge] + SSE --> Redis[(Redis)] ``` ## The workers | Worker | What it does | Entry point | |-------------------------------------------|-----------------------------------------------------------|----------------------------| -| [Coordinator](coordinator.md) | Admits executions, manages the queue, allocates resources | `run_coordinator.py` | | [Saga Orchestrator](saga_orchestrator.md) | Drives the execution state machine, issues pod commands | `run_saga_orchestrator.py` | | [K8s Worker](k8s_worker.md) | Creates ConfigMaps and Pods with security hardening | `run_k8s_worker.py` | | [Pod Monitor](pod_monitor.md) | Watches pods, translates K8s events to domain events | `run_pod_monitor.py` | | [Result Processor](result_processor.md) | Persists execution results, cleans up resources | `run_result_processor.py` | +| SSE Bridge | Routes execution events from Kafka to Redis for SSE | `run_sse_bridge.py` | | [Event Replay](event_replay.md) | Re-emits historical events for debugging | `run_event_replay.py` | | [DLQ Processor](dlq_processor.md) | Retries failed messages from the dead letter queue | `dlq_processor.py` | All entry points live in [`backend/workers/`](https://github.com/HardMax71/Integr8sCode/tree/main/backend/workers). +## FastStream framework + +All Kafka-consuming workers are built on [FastStream](https://faststream.ag2.ai/), an asynchronous Python framework for +building event-driven microservices. 
FastStream provides: + +- **Declarative subscribers** — define handlers with `@broker.subscriber()` decorators +- **Automatic serialization** — Avro/JSON encoding handled transparently via custom decoders +- **Dependency injection** — integrated with [Dishka](https://dishka.readthedocs.io/) for clean DI +- **Lifespan management** — startup/shutdown hooks for resource initialization +- **Acknowledgement policies** — fine-grained control over message processing guarantees + +### Message acknowledgement policies + +FastStream's `AckPolicy` controls when Kafka offsets are committed, determining message delivery semantics. Each policy +offers different trade-offs between throughput, reliability, and complexity: + +| Policy | On Success | On Error | Delivery Guarantee | Use Case | +|---------------------|---------------------------|-----------------------------|--------------------|---------------------------------------------| +| `ACK_FIRST` | Commit before processing | Already committed | At most once | High throughput, idempotent operations | +| `ACK` | Commit after processing | Commit anyway (no retry) | At least once | Reliable processing without retry needs | +| `NACK_ON_ERROR` | Commit after processing | Seek back for redelivery | At least once | Auto-retry on transient failures | +| `REJECT_ON_ERROR` | Commit after processing | Commit (discard message) | At least once | Permanent failures, no retry desired | +| `MANUAL` | User calls `msg.ack()` | User calls `msg.nack()` | User-controlled | Complex conditional acknowledgement | + +!!! note "Kafka-specific behavior" + Unlike RabbitMQ or NATS, Kafka lacks native message rejection. `REJECT_ON_ERROR` behaves identically to `ACK` — + the offset is committed regardless, and the message won't be redelivered. Use a Dead Letter Queue for failed + messages that need investigation. + +### Worker acknowledgement configuration + +Each worker uses the policy best suited to its reliability requirements: + +| Worker | Policy | Rationale | +|---------------------|--------------|----------------------------------------------------------------------------------------| +| Saga Orchestrator | `ACK` | Saga state is persisted to MongoDB; duplicate processing is idempotent | +| K8s Worker | `ACK` | Pod creation is idempotent (same execution ID); uses idempotency middleware | +| Result Processor | `ACK` | Result persistence is idempotent; duplicate writes are safe | +| SSE Bridge | `ACK_FIRST` | Redis pubsub is fire-and-forget; missing an event causes a client retry, not data loss | +| DLQ Processor | `ACK` | DLQ messages are tracked in MongoDB; safe to acknowledge after processing | + +Example subscriber configuration: + +```python +from faststream.broker.message import AckPolicy + +@broker.subscriber( + *topics, + group_id="saga-orchestrator", + ack_policy=AckPolicy.ACK, + decoder=decode_avro, +) +async def handle_event(event: DomainEvent) -> None: + await saga_logic.process(event) +``` + +### Exception-based flow control + +FastStream supports interrupting message processing at any call stack level: + +```python +from faststream.exceptions import AckMessage, NackMessage + +async def deep_processing_function(): + if should_skip: + raise AckMessage() # Acknowledge and stop processing + if should_retry: + raise NackMessage() # Reject for redelivery (NATS/RabbitMQ only) +``` + +For Kafka, `NackMessage` has no effect since Kafka doesn't support message-level rejection. Use `NACK_ON_ERROR` policy +instead for automatic retry behavior. 
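For completeness, a hedged sketch of the `NACK_ON_ERROR` recommendation above. The topic, group id, and handler body are illustrative; only the subscriber and ack-policy usage mirrors the pattern used by these workers, and the import path follows the correction made later in this series.

```python
from faststream import AckPolicy, FastStream  # import path as corrected later in this series
from faststream.kafka import KafkaBroker

broker = KafkaBroker("localhost:9092")
app = FastStream(broker)


@broker.subscriber(
    "execution_events",
    group_id="retrying-consumer-sketch",
    ack_policy=AckPolicy.NACK_ON_ERROR,  # on exception, seek back so Kafka redelivers
)
async def handle(event: dict) -> None:
    # Raising here triggers redelivery instead of a silent commit.
    await flaky_side_effect(event)


async def flaky_side_effect(event: dict) -> None:
    ...  # placeholder for processing that may fail transiently
```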
+ ## Running locally Docker Compose starts everything: @@ -49,11 +122,10 @@ For debugging a specific worker, run it directly: ```bash cd backend -python -m workers.run_coordinator +python -m workers.run_saga_orchestrator ``` ## Scaling -Most workers can run as single replicas. The stateful ones (Coordinator, Saga Orchestrator) use event sourcing to -recover after restarts. The stateless ones (K8s Worker, Pod Monitor, Result Processor) can scale horizontally if -throughput becomes an issue. +Most workers can run as single replicas. The Saga Orchestrator uses event sourcing to recover after restarts. +The stateless ones (K8s Worker, Pod Monitor, Result Processor) can scale horizontally if throughput becomes an issue. diff --git a/docs/index.md b/docs/index.md index c11db4ed..ac85f6fc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -60,8 +60,7 @@ Svelte frontend → FastAPI backend (MongoDB, Kafka, Redis) → Kubernetes pods flowchart TB User:::userStyle --> Frontend Frontend --> API - API --> Coordinator - Coordinator --> Saga + API --> Saga[Saga Orchestrator] Saga --> K8sWorker[K8s Worker] K8sWorker --> Pod K8sWorker --> ResultProcessor[Result Processor] diff --git a/docs/operations/deployment.md b/docs/operations/deployment.md index 12dd323c..158896f0 100644 --- a/docs/operations/deployment.md +++ b/docs/operations/deployment.md @@ -284,7 +284,7 @@ Check pod status and logs using standard kubectl commands. ```bash kubectl get pods -n integr8scode kubectl logs -n integr8scode -l app.kubernetes.io/component=backend -kubectl logs -n integr8scode -l app.kubernetes.io/component=coordinator +kubectl logs -n integr8scode -l app.kubernetes.io/component=saga-orchestrator ``` The deploy script's `status` command shows both Docker Compose and Kubernetes status in one view. diff --git a/docs/operations/metrics-reference.md b/docs/operations/metrics-reference.md index ab0f58fe..82e534f6 100644 --- a/docs/operations/metrics-reference.md +++ b/docs/operations/metrics-reference.md @@ -31,22 +31,6 @@ Track script execution performance and resource usage. | `execution.queue.depth` | UpDownCounter | - | Queued executions | | `execution.queue.wait_time` | Histogram | lang_and_version | Queue wait time (seconds) | -### Coordinator Metrics - -Track scheduling and resource allocation. - -| Metric | Type | Labels | Description | -|------------------------------------------|---------------|---------------------|---------------------------| -| `coordinator.processing.time` | Histogram | - | Event processing time | -| `coordinator.scheduling.duration` | Histogram | - | Scheduling time | -| `coordinator.executions.active` | UpDownCounter | - | Active managed executions | -| `coordinator.queue.wait_time` | Histogram | priority, queue | Queue wait by priority | -| `coordinator.executions.scheduled.total` | Counter | status | Scheduled executions | -| `coordinator.rate_limited.total` | Counter | limit_type, user_id | Rate limited requests | -| `coordinator.resource.allocations.total` | Counter | resource_type | Resource allocations | -| `coordinator.resource.utilization` | UpDownCounter | resource_type | Current utilization | -| `coordinator.scheduling.decisions.total` | Counter | decision, reason | Scheduling decisions | - ### Rate Limit Metrics Track rate-limiting behavior. 
@@ -161,8 +145,7 @@ avg_over_time(execution_queue_depth[1h]) | File | Purpose | |------------------------------------------------------------------------------------------------------------------------------|--------------------------------------| -| [`core/metrics/base.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/core/metrics/base.py) | Base metrics class and configuration | -| [`core/metrics/execution.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/core/metrics/execution.py) | Execution metrics | -| [`core/metrics/coordinator.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/core/metrics/coordinator.py) | Coordinator metrics | -| [`core/metrics/rate_limit.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/core/metrics/rate_limit.py) | Rate limit metrics | +| [`core/metrics/base.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/core/metrics/base.py) | Base metrics class and configuration | +| [`core/metrics/execution.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/core/metrics/execution.py) | Execution metrics | +| [`core/metrics/rate_limit.py`](https://github.com/HardMax71/Integr8sCode/blob/main/backend/app/core/metrics/rate_limit.py) | Rate limit metrics | | [`core/metrics/`](https://github.com/HardMax71/Integr8sCode/tree/main/backend/app/core/metrics) | All metrics modules | diff --git a/docs/operations/tracing.md b/docs/operations/tracing.md index a8050d6f..6566bb2f 100644 --- a/docs/operations/tracing.md +++ b/docs/operations/tracing.md @@ -79,7 +79,7 @@ If the message fails and lands in the DLQ, the DLQ manager still sees the origin When an endpoint is slow, open the request span and look at the child spans. You'll see if the time is in rate-limit checks, Mongo, Kafka publishing, or downstream webhooks. When a message fails, open the `dlq.consume` span for the message and follow the links back to the original request and the producer that created it. When you want to understand load, browse traces by endpoint or topic — the spans include batch sizes, message sizes (from Kafka instrumentation), and DB timings, so you can quickly spot hot spots without adding print statements. -For local development, point the app at a Jaeger all-in-one or an OpenTelemetry Collector that forwards to Jaeger. With the docker-compose setup, Jaeger typically exposes a UI at `http://localhost:16686`. Open it, select the service (for example integr8scode-backend, dlq-processor, event-replay, or execution-coordinator), and find traces. You should see the HTTP server spans, kafka.consume spans on workers, MongoDB spans, replay and saga spans, and notification outbound calls under the same trace. +For local development, point the app at a Jaeger all-in-one or an OpenTelemetry Collector that forwards to Jaeger. With the docker-compose setup, Jaeger typically exposes a UI at `http://localhost:16686`. Open it, select the service (for example integr8scode-backend, dlq-processor, event-replay, or saga-orchestrator), and find traces. You should see the HTTP server spans, kafka.consume spans on workers, MongoDB spans, replay and saga spans, and notification outbound calls under the same trace. Tracing is sampled. If you don't set an endpoint for the OTLP exporter the SDK drops spans after local processing; if you do set one (e.g. an OTel Collector, Tempo, or Jaeger) you get full traces in your backend. The sampler can be ratio-based or adaptive; both are supported. 
If you don't care about traces in a particular environment, set `OTEL_SDK_DISABLED=true` or set the sampling rate to 0. diff --git a/mkdocs.yml b/mkdocs.yml index 4b32ca58..e2983bcd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -111,7 +111,6 @@ nav: - Authentication: architecture/authentication.md - Rate Limiting: architecture/rate-limiting.md - Idempotency: architecture/idempotency.md - - Execution Queue: architecture/execution-queue.md - Runtime Registry: architecture/runtime-registry.md - Middleware: architecture/middleware.md - Domain Exceptions: architecture/domain-exceptions.md From 5a9fdc020ba4e9e08f6d232eb1ab6e4b98b1fde1 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 22 Jan 2026 01:42:12 +0100 Subject: [PATCH 20/21] dumb error fix --- backend/workers/dlq_processor.py | 3 +-- backend/workers/run_k8s_worker.py | 3 +-- backend/workers/run_result_processor.py | 3 +-- backend/workers/run_saga_orchestrator.py | 3 +-- backend/workers/run_sse_bridge.py | 3 +-- docs/components/workers/index.md | 2 +- 6 files changed, 6 insertions(+), 11 deletions(-) diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index 93bff458..b614a071 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -38,8 +38,7 @@ from beanie import init_beanie from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka -from faststream import FastStream -from faststream.broker.message import AckPolicy +from faststream import AckPolicy, FastStream from faststream.kafka import KafkaBroker from faststream.message import StreamMessage from pymongo.asynchronous.mongo_client import AsyncMongoClient diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 84e1da3a..08d24e99 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -42,8 +42,7 @@ from app.settings import Settings from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka -from faststream import FastStream -from faststream.broker.message import AckPolicy +from faststream import AckPolicy, FastStream from faststream.kafka import KafkaBroker from faststream.message import StreamMessage diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index d1be0842..71682f3e 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -48,8 +48,7 @@ from beanie import init_beanie from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka -from faststream import FastStream -from faststream.broker.message import AckPolicy +from faststream import AckPolicy, FastStream from faststream.kafka import KafkaBroker from faststream.message import StreamMessage from pymongo.asynchronous.mongo_client import AsyncMongoClient diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 5429e227..76f36fe8 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -43,8 +43,7 @@ from beanie import init_beanie from dishka import make_async_container from dishka.integrations.faststream import FromDishka, setup_dishka -from faststream import FastStream -from faststream.broker.message import AckPolicy +from faststream import AckPolicy, FastStream from faststream.kafka import KafkaBroker from faststream.message import StreamMessage from 
pymongo.asynchronous.mongo_client import AsyncMongoClient diff --git a/backend/workers/run_sse_bridge.py b/backend/workers/run_sse_bridge.py index 48e23887..f38b9c0a 100644 --- a/backend/workers/run_sse_bridge.py +++ b/backend/workers/run_sse_bridge.py @@ -23,8 +23,7 @@ from app.settings import Settings from dishka import Provider, Scope, make_async_container, provide from dishka.integrations.faststream import FromDishka, setup_dishka -from faststream import FastStream -from faststream.broker.message import AckPolicy +from faststream import AckPolicy, FastStream from faststream.kafka import KafkaBroker from faststream.message import StreamMessage diff --git a/docs/components/workers/index.md b/docs/components/workers/index.md index 2dafd272..40d258dc 100644 --- a/docs/components/workers/index.md +++ b/docs/components/workers/index.md @@ -81,7 +81,7 @@ Each worker uses the policy best suited to its reliability requirements: Example subscriber configuration: ```python -from faststream.broker.message import AckPolicy +from faststream import AckPolicy @broker.subscriber( *topics, From 4b104db904a84d25a437e8137fe6ba2a46369b52 Mon Sep 17 00:00:00 2001 From: HardMax71 Date: Thu, 22 Jan 2026 21:03:20 +0100 Subject: [PATCH 21/21] filter by event type --- backend/app/domain/enums/kafka.py | 2 +- backend/app/events/core/producer.py | 7 +++- backend/workers/dlq_processor.py | 12 +++++-- backend/workers/run_k8s_worker.py | 6 ++-- backend/workers/run_result_processor.py | 10 +++--- backend/workers/run_saga_orchestrator.py | 43 +++++++++++++----------- backend/workers/run_sse_bridge.py | 16 ++++++--- docker-compose.yaml | 2 ++ 8 files changed, 62 insertions(+), 36 deletions(-) diff --git a/backend/app/domain/enums/kafka.py b/backend/app/domain/enums/kafka.py index 6aad2503..f7b81228 100644 --- a/backend/app/domain/enums/kafka.py +++ b/backend/app/domain/enums/kafka.py @@ -77,7 +77,7 @@ class GroupId(StringEnum): KafkaTopic.POD_STATUS_UPDATES, }, GroupId.RESULT_PROCESSOR: { - KafkaTopic.EXECUTION_EVENTS, # Listens for COMPLETED/FAILED/TIMEOUT, publishes to EXECUTION_RESULTS + KafkaTopic.EXECUTION_EVENTS, }, GroupId.SAGA_ORCHESTRATOR: { # Orchestrator is triggered by domain events, specifically EXECUTION_REQUESTED, diff --git a/backend/app/events/core/producer.py b/backend/app/events/core/producer.py index d4c7a432..3847cf61 100644 --- a/backend/app/events/core/producer.py +++ b/backend/app/events/core/producer.py @@ -49,8 +49,13 @@ async def produce( serialized_value = await self._schema_registry.serialize_event(event_to_produce) topic = f"{self._topic_prefix}{EVENT_TYPE_TO_TOPIC[event_to_produce.event_type]}" + # Always include event_type header for routing, merge with any additional headers + all_headers = {"event_type": str(event_to_produce.event_type)} + if headers: + all_headers.update(headers) + # Convert headers to list of tuples format - header_list = [(k, v.encode()) for k, v in headers.items()] if headers else None + header_list = [(k, v.encode()) for k, v in all_headers.items()] await self._producer.send_and_wait( topic=topic, diff --git a/backend/workers/dlq_processor.py b/backend/workers/dlq_processor.py index b614a071..581e3b62 100644 --- a/backend/workers/dlq_processor.py +++ b/backend/workers/dlq_processor.py @@ -182,13 +182,16 @@ def decode_dlq_json(msg: StreamMessage[Any]) -> DLQMessage: data = json.loads(msg.body) return DLQMessage.model_validate(data) - # Register subscriber for DLQ messages - @broker.subscriber( + # Create subscriber with JSON decoder (two-step pattern) + subscriber = 
broker.subscriber( *topics, group_id=group_id, ack_policy=AckPolicy.ACK, decoder=decode_dlq_json, ) + + # DLQ messages have "original_topic" header set by producer + @subscriber(filter=lambda msg: msg.headers.get("original_topic") is not None) async def handle_dlq_message( message: DLQMessage, dlq_manager: FromDishka[DLQManager], @@ -196,6 +199,11 @@ async def handle_dlq_message( """Handle incoming DLQ messages - invoked by FastStream when message arrives.""" await dlq_manager.process_message(message) + # Default handler for any other messages (shouldn't happen, but prevents message loss) + @subscriber() + async def handle_other(message: DLQMessage) -> None: + pass + # Background task: periodic check for scheduled retries async def retry_checker() -> None: while True: diff --git a/backend/workers/run_k8s_worker.py b/backend/workers/run_k8s_worker.py index 08d24e99..ddbe42a4 100644 --- a/backend/workers/run_k8s_worker.py +++ b/backend/workers/run_k8s_worker.py @@ -126,15 +126,15 @@ async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: decoder=decode_avro, ) - # Route by event_type header (producer sets this, Kafka stores as bytes) - @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.CREATE_POD_COMMAND.encode()) + # Route by event_type header (FastStream decodes headers as strings, not bytes) + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.CREATE_POD_COMMAND) async def handle_create_pod_command( event: CreatePodCommandEvent, worker_logic: FromDishka[K8sWorkerLogic], ) -> None: await worker_logic.handle_create_pod_command(event) - @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.DELETE_POD_COMMAND.encode()) + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.DELETE_POD_COMMAND) async def handle_delete_pod_command( event: DeletePodCommandEvent, worker_logic: FromDishka[K8sWorkerLogic], diff --git a/backend/workers/run_result_processor.py b/backend/workers/run_result_processor.py index 71682f3e..cc659f21 100644 --- a/backend/workers/run_result_processor.py +++ b/backend/workers/run_result_processor.py @@ -140,29 +140,29 @@ async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: decoder=decode_avro, ) - # Route by event_type header (producer sets this, Kafka stores as bytes) - @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_COMPLETED.encode()) + # Route by event_type header + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_COMPLETED) async def handle_completed( event: ExecutionCompletedEvent, logic: FromDishka[ProcessorLogic], ) -> None: await logic._handle_completed(event) - @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_FAILED.encode()) + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_FAILED) async def handle_failed( event: ExecutionFailedEvent, logic: FromDishka[ProcessorLogic], ) -> None: await logic._handle_failed(event) - @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_TIMEOUT.encode()) + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_TIMEOUT) async def handle_timeout( event: ExecutionTimeoutEvent, logic: FromDishka[ProcessorLogic], ) -> None: await logic._handle_timeout(event) - # Default handler for unmatched events (prevents message loss) + # Default handler for unmatched events (preserves ordering, acks non-terminal events) @subscriber() async def 
handle_other(event: DomainEvent) -> None: pass diff --git a/backend/workers/run_saga_orchestrator.py b/backend/workers/run_saga_orchestrator.py index 76f36fe8..0fa54dbf 100644 --- a/backend/workers/run_saga_orchestrator.py +++ b/backend/workers/run_saga_orchestrator.py @@ -33,11 +33,11 @@ ) from app.core.tracing import init_tracing from app.db.docs import ALL_DOCUMENTS +from app.domain.enums.events import EventType from app.domain.enums.kafka import GroupId from app.domain.events.typed import DomainEvent from app.events.core import UnifiedProducer from app.events.schema.schema_registry import SchemaRegistryManager -from app.services.idempotency.faststream_middleware import IdempotencyMiddleware from app.services.saga.saga_logic import SagaLogic from app.settings import Settings from beanie import init_beanie @@ -141,33 +141,36 @@ async def lifespan() -> AsyncIterator[None]: async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: return await schema_registry.deserialize_event(msg.body, "saga_orchestrator") - # Register handler dynamically after determining topics - # Saga orchestrator uses single handler - routing is internal to SagaLogic - @broker.subscriber( + # Create subscriber with Avro decoder (two-step pattern like result_processor) + subscriber = broker.subscriber( *topics, group_id=group_id, ack_policy=AckPolicy.ACK, decoder=decode_avro, ) - async def handle_saga_event( - event: DomainEvent, - saga_logic: FromDishka[SagaLogic], - ) -> None: - """ - Handle saga trigger events. - Dependencies are automatically injected via Dishka. - Routing is handled internally by SagaLogic based on saga configuration. - """ - # Handle the event through saga logic (internal routing) - await saga_logic.handle_event(event) - - # Opportunistic timeout check (replaces background loop) + # Helper for opportunistic timeout check + async def _maybe_check_timeouts(saga_logic: SagaLogic) -> None: now = time.monotonic() if now - timeout_check_state["last_check"] >= timeout_check_state["interval"]: await saga_logic.check_timeouts_once() timeout_check_state["last_check"] = now + # Route by event_type header (FastStream decodes headers as strings, not bytes) + @subscriber(filter=lambda msg: msg.headers.get("event_type") == EventType.EXECUTION_REQUESTED) + async def handle_execution_requested( + event: DomainEvent, + saga_logic: FromDishka[SagaLogic], + ) -> None: + """Handle execution_requested events that trigger ExecutionSaga.""" + await saga_logic.handle_event(event) + await _maybe_check_timeouts(saga_logic) + + # Default handler for other events on subscribed topics (execution_accepted, etc.) + @subscriber() + async def handle_other(event: DomainEvent) -> None: + pass + app_logger.info(f"Subscribing to topics: {topics}") app_logger.info("Infrastructure initialized, starting event processing...") @@ -185,8 +188,10 @@ async def handle_saga_event( # Setup Dishka integration for automatic DI in handlers setup_dishka(container=container, app=app, auto_inject=True) - # Add idempotency middleware (appends to end = most inner, runs after Dishka) - broker.add_middleware(IdempotencyMiddleware) + # NOTE: IdempotencyMiddleware disabled for saga-orchestrator. + # The saga pattern provides its own idempotency via saga_id, and the middleware + # was causing duplicate detection issues with FastStream's filter evaluation timing. + # broker.add_middleware(IdempotencyMiddleware) # Run! 
FastStream handles signal handling, consumer loops, graceful shutdown asyncio.run(app.run()) diff --git a/backend/workers/run_sse_bridge.py b/backend/workers/run_sse_bridge.py index f38b9c0a..319e5321 100644 --- a/backend/workers/run_sse_bridge.py +++ b/backend/workers/run_sse_bridge.py @@ -103,21 +103,27 @@ async def lifespan() -> AsyncIterator[None]: async def decode_avro(msg: StreamMessage[Any]) -> DomainEvent: return await schema_registry.deserialize_event(msg.body, "sse_bridge") - # Single handler for all SSE-relevant events - # No filter needed - we check event_type in handler since route_event handles all types - @broker.subscriber( + # Create subscriber with Avro decoder (two-step pattern) + subscriber = broker.subscriber( *topics, group_id=group_id, ack_policy=AckPolicy.ACK_FIRST, # SSE bridge is idempotent (Redis pubsub) decoder=decode_avro, ) + + # Filter for SSE-relevant events (FastStream decodes headers as strings, not bytes) + @subscriber(filter=lambda msg: msg.headers.get("event_type", "") in SSE_RELEVANT_EVENTS) async def handle_sse_event( event: DomainEvent, router: FromDishka[SSEEventRouter], ) -> None: """Route domain events to Redis for SSE delivery.""" - if event.event_type in SSE_RELEVANT_EVENTS: - await router.route_event(event) + await router.route_event(event) + + # Default handler for non-SSE events (prevents message loss) + @subscriber() + async def handle_other(event: DomainEvent) -> None: + pass app_logger.info(f"Subscribing to topics: {topics}") app_logger.info("SSE Bridge ready") diff --git a/docker-compose.yaml b/docker-compose.yaml index c63bf097..5dc6f963 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -458,6 +458,7 @@ services: - TRACING_SERVICE_NAME=pod-monitor - KAFKA_CONSUMER_GROUP_ID=pod-monitor - KUBECONFIG=/app/kubeconfig.yaml + - LOG_LEVEL=WARNING volumes: - ./backend/app:/app/app:ro - ./backend/workers:/app/workers:ro @@ -565,6 +566,7 @@ services: environment: - TRACING_SERVICE_NAME=event-replay - KAFKA_CONSUMER_GROUP_ID=event-replay + - LOG_LEVEL=WARNING volumes: - ./backend/app:/app/app:ro - ./backend/workers:/app/workers:ro
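Taken together, the producer and subscriber changes in this last patch amount to the routing pattern sketched below. The broker address, topic, and event-type strings are illustrative stand-ins; the two-step subscriber, the string header comparison, and the catch-all handler mirror the diffs above.

```python
from faststream import AckPolicy, FastStream
from faststream.kafka import KafkaBroker

broker = KafkaBroker("localhost:9092")
app = FastStream(broker)

# Two-step pattern: create the subscriber once, then attach filtered handlers.
subscriber = broker.subscriber(
    "execution_events",
    group_id="routing-sketch",
    ack_policy=AckPolicy.ACK,
)


# FastStream exposes Kafka headers as strings, so compare against the string value.
@subscriber(filter=lambda msg: msg.headers.get("event_type") == "execution_completed")
async def handle_completed(event: dict) -> None:
    ...  # terminal-event handling


@subscriber()  # catch-all so unmatched events are still acknowledged
async def handle_other(event: dict) -> None:
    pass


async def publish_with_routing_header(event: dict) -> None:
    # Mirrors the producer change: always stamp the routing header.
    await broker.publish(
        event,
        topic="execution_events",
        headers={"event_type": "execution_completed"},
    )
```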