diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py index df89eccf..0d0f251b 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py @@ -1,8 +1,12 @@ +import asyncio +import contextlib import inspect +import sys import aiorun import grpc +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION from pynumaflow.proto.reducer import reduce_pb2_grpc @@ -15,6 +19,7 @@ REDUCE_STREAM_SOCK_PATH, REDUCE_STREAM_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.reducestreamer._dtypes import ( @@ -23,7 +28,7 @@ ReduceStreamer, ) -from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server +from pynumaflow.shared.server import NumaflowServer, check_instance def get_handler( @@ -156,6 +161,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncReduceStreamServicer(self.reduce_stream_handler) + self._error: BaseException | None = None def start(self): """ @@ -166,6 +172,9 @@ def start(self): "Starting Async Reduce Stream Server", ) aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -178,15 +187,42 @@ async def aexec(self): # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + reduce_pb2_grpc.add_ReduceServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Reducestreamer] - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Reduce Stream Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py index 3bec4a8a..569ad8c9 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py @@ -3,7 +3,7 @@ from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING +from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING, _LOGGER from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc from pynumaflow.reducestreamer._dtypes import ( Datum, @@ -12,7 +12,7 @@ ReduceRequest, ) from pynumaflow.reducestreamer.servicer.task_manager import TaskManager -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -47,6 +47,12 @@ def __init__( ): # The Reduce handler can be a function or a builder class instance. self.__reduce_handler: ReduceStreamAsyncCallable | _ReduceStreamBuilderClass = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def ReduceFn( self, @@ -94,20 +100,50 @@ async def ReduceFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window EOF response or Window result response # back to the client else: yield msg + except GeneratorExit: + # ReduceFn is an async generator (it yields messages). When Numaflow closes a + # window, gRPC calls .aclose() on this generator, throwing GeneratorExit at + # the yield point. This is normal stream lifecycle — return cleanly. + return + except asyncio.CancelledError: + # SIGTERM: aiorun cancelled all tasks. Signal the server to stop so + # Server.__del__ doesn't try to schedule on a closed event loop. + if self._shutdown_event is not None: + self._shutdown_event.set() + return except BaseException as e: - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return # Wait for the process_input_stream task to finish for a clean exit try: await producer + except asyncio.CancelledError: + if self._shutdown_event is not None: + self._shutdown_event.set() + return except BaseException as e: - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py index 519c043b..2e436178 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py @@ -195,6 +195,9 @@ async def __invoke_reduce( new_instance = self.__reduce_handler.create() try: _ = await new_instance(keys, request_iterator, output, md) + except asyncio.CancelledError: + _LOGGER.info("ReduceStream __invoke_reduce cancelled, returning cleanly") + return # If there is an error in the reduce operation, log and # then send the error to the result queue except BaseException as err: @@ -217,6 +220,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[reduce_pb2. # append the task data to the existing task # if the task does not exist, create a new task await self.send_datum_to_task(request) + except asyncio.CancelledError: + _LOGGER.info("ReduceStream process_input_stream cancelled, returning cleanly") + return # If there is an error in the reduce operation, log and # then send the error to the result queue except BaseException as e: @@ -261,6 +267,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[reduce_pb2. # Once all tasks are completed, senf EOF the global result queue await self.global_result_queue.put(STREAM_EOF) + except asyncio.CancelledError: + _LOGGER.info("ReduceStream post-processing cancelled, returning cleanly") + return except BaseException as e: err_msg = f"Reduce Streaming Error: {repr(e)}" _LOGGER.critical(err_msg, exc_info=True) diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py index 801c7f90..eaaea013 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py @@ -3,8 +3,7 @@ import threading import unittest from collections.abc import AsyncIterable -from unittest.mock import patch - +from unittest.mock import MagicMock import grpc from grpc.aio._server import Server @@ -18,13 +17,14 @@ Metadata, ) from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc +from pynumaflow.reducestreamer.servicer.async_servicer import AsyncReduceStreamServicer +from pynumaflow.reducestreamer.servicer.task_manager import TaskManager from pynumaflow.shared.asynciter import NonBlockingIterator from tests.testing_utils import ( mock_message, mock_interval_window_start, mock_interval_window_end, get_time_args, - mock_terminate_on_stop, ) LOGGER = setup_logging(__name__) @@ -128,7 +128,6 @@ def NewAsyncReduceStreamer(): return udfs -@patch("psutil.Process.kill", mock_terminate_on_stop) async def start_server(udfs): server = grpc.aio.server() reduce_pb2_grpc.add_ReduceServicer_to_server(udfs, server) @@ -141,8 +140,6 @@ async def start_server(udfs): await server.wait_for_termination() -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) class TestAsyncReduceStreamerErr(unittest.TestCase): @classmethod def setUpClass(cls) -> None: @@ -172,9 +169,6 @@ def tearDownClass(cls) -> None: except BaseException as e: LOGGER.error(e) - # TODO: Check why terminating even after mocking - # We are mocking the terminate function from the psutil to not exit the program during testing - @patch("psutil.Process.kill", mock_terminate_on_stop) def test_reduce(self) -> None: stub = self.__stub() request, metadata = start_request(multiple_window=False) @@ -191,8 +185,6 @@ def test_reduce(self) -> None: return self.fail("Expected an exception.") - # TODO: Check why terminating even after mocking - @patch("psutil.Process.kill", mock_terminate_on_stop) def test_reduce_window_len(self) -> None: stub = self.__stub() request, metadata = start_request(multiple_window=True) @@ -228,6 +220,132 @@ def __stub(self): return reduce_pb2_grpc.ReduceStub(_channel) +async def _emit_one_handler(keys, datums, output, md): + """Handler that emits one message eagerly, then blocks reading remaining datums.""" + await output.put(Message(b"result", keys=keys)) + async for _ in datums: + pass + + +def test_cancelled_error_in_consumer_loop(): + """athrow(CancelledError) at the yield point exercises the except CancelledError branch.""" + servicer = AsyncReduceStreamServicer(_emit_one_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + request, _ = start_request(multiple_window=False) + + async def _run(): + async def requests(): + yield request + await asyncio.sleep(999) + + gen = servicer.ReduceFn(requests(), MagicMock()) + # Drive the pipeline until the handler's message is yielded. + await gen.__anext__() + # Simulate task cancellation (e.g. SIGTERM) at the yield point. + try: + await gen.athrow(asyncio.CancelledError()) + except StopAsyncIteration: + pass + + asyncio.run(_run()) + assert shutdown_event.is_set() + assert servicer._error is None + + +def test_base_exception_in_consumer_loop(): + """athrow(ValueError) at the yield point exercises the except BaseException branch.""" + servicer = AsyncReduceStreamServicer(_emit_one_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + request, _ = start_request(multiple_window=False) + + async def _run(): + async def requests(): + yield request + await asyncio.sleep(999) + + ctx = MagicMock() + gen = servicer.ReduceFn(requests(), ctx) + await gen.__anext__() + try: + await gen.athrow(ValueError("boom")) + except StopAsyncIteration: + pass + return ctx + + ctx = asyncio.run(_run()) + assert shutdown_event.is_set() + assert isinstance(servicer._error, ValueError) + ctx.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL) + + +_original_process_input_stream = TaskManager.process_input_stream + + +def test_cancelled_error_awaiting_producer(): + """CancelledError from the producer task after it finishes its real work.""" + servicer = AsyncReduceStreamServicer(_emit_one_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + request, _ = start_request(multiple_window=False) + + async def raise_after_real_work(self, request_iterator): + await _original_process_input_stream(self, request_iterator) + raise asyncio.CancelledError() + + TaskManager.process_input_stream = raise_after_real_work + try: + + async def _run(): + async def requests(): + yield request + + gen = servicer.ReduceFn(requests(), MagicMock()) + async for _ in gen: + pass + + asyncio.run(_run()) + finally: + TaskManager.process_input_stream = _original_process_input_stream + + assert shutdown_event.is_set() + assert servicer._error is None + + +def test_base_exception_awaiting_producer(): + """BaseException from the producer task after it finishes its real work.""" + servicer = AsyncReduceStreamServicer(_emit_one_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + request, _ = start_request(multiple_window=False) + + async def raise_after_real_work(self, request_iterator): + await _original_process_input_stream(self, request_iterator) + raise RuntimeError("producer boom") + + TaskManager.process_input_stream = raise_after_real_work + try: + + async def _run(): + async def requests(): + yield request + + ctx = MagicMock() + gen = servicer.ReduceFn(requests(), ctx) + async for _ in gen: + pass + return ctx + + ctx = asyncio.run(_run()) + finally: + TaskManager.process_input_stream = _original_process_input_stream + + assert shutdown_event.is_set() + assert isinstance(servicer._error, RuntimeError) + ctx.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main()