From 2c2886b3f1ee5e17e3e398cceff09a141f952348 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 16 Mar 2026 15:27:46 +0530 Subject: [PATCH 1/3] graceful shutdown for all UDFs Signed-off-by: Sreekanth --- .../pynumaflow/accumulator/async_server.py | 61 +++++++-- .../accumulator/servicer/async_servicer.py | 31 ++++- .../pynumaflow/batchmapper/async_server.py | 69 +++++++--- .../batchmapper/servicer/async_servicer.py | 16 ++- .../mapper/_servicer/_async_servicer.py | 23 +++- .../mapper/_servicer/_sync_servicer.py | 44 +++++-- .../pynumaflow/mapper/async_server.py | 71 +++++++--- .../pynumaflow/mapper/multiproc_server.py | 14 ++ .../pynumaflow/mapper/sync_server.py | 7 + .../pynumaflow/mapstreamer/async_server.py | 69 +++++++--- .../mapstreamer/servicer/async_servicer.py | 25 +++- .../pynumaflow/reducer/async_server.py | 65 ++++++++-- .../reducer/servicer/async_servicer.py | 29 ++++- .../reducer/servicer/task_manager.py | 15 +-- .../pynumaflow/pynumaflow/shared/server.py | 49 +------ .../pynumaflow/pynumaflow/sideinput/server.py | 6 + .../pynumaflow/sideinput/servicer/servicer.py | 10 +- .../pynumaflow/sourcer/async_server.py | 67 ++++++++-- .../sourcer/servicer/async_servicer.py | 56 ++++++-- .../sourcetransformer/async_server.py | 71 +++++++--- .../sourcetransformer/multiproc_server.py | 14 ++ .../pynumaflow/sourcetransformer/server.py | 6 + .../servicer/_async_servicer.py | 23 +++- .../sourcetransformer/servicer/_servicer.py | 42 ++++-- .../tests/map/test_sync_map_shutdown.py | 121 ++++++++++++++++++ 25 files changed, 795 insertions(+), 209 deletions(-) create mode 100644 packages/pynumaflow/tests/map/test_sync_map_shutdown.py diff --git a/packages/pynumaflow/pynumaflow/accumulator/async_server.py b/packages/pynumaflow/pynumaflow/accumulator/async_server.py index 200e4422..6a2f165d 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/async_server.py +++ b/packages/pynumaflow/pynumaflow/accumulator/async_server.py @@ -1,9 +1,13 @@ +import asyncio +import contextlib import inspect +import sys import aiorun import grpc from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION from pynumaflow.proto.accumulator import accumulator_pb2_grpc @@ -15,6 +19,7 @@ MAX_NUM_THREADS, ACCUMULATOR_SOCK_PATH, ACCUMULATOR_SERVER_INFO_FILE_PATH, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.accumulator._dtypes import ( @@ -23,7 +28,7 @@ Accumulator, ) -from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server +from pynumaflow.shared.server import NumaflowServer, check_instance def get_handler( @@ -157,6 +162,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncAccumulatorServicer(self.accumulator_handler) + self._error: BaseException | None = None def start(self): """ @@ -167,6 +173,9 @@ def start(self): "Starting Async Accumulator Server", ) aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -176,18 +185,52 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + accumulator_pb2_grpc.add_AccumulatorServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Accumulator] - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py index 16be7911..cd35962d 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py @@ -3,7 +3,7 @@ from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING +from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc from pynumaflow.accumulator._dtypes import ( Datum, @@ -13,7 +13,7 @@ KeyedWindow, ) from pynumaflow.accumulator.servicer.task_manager import TaskManager -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -57,6 +57,12 @@ def __init__( ): # The accumulator handler can be a function or a builder class instance. self.__accumulator_handler: AccumulatorAsyncCallable | _AccumulatorBuilderClass = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def AccumulateFn( self, @@ -104,20 +110,35 @@ async def AccumulateFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window EOF response or Window result response # back to the client else: yield msg except BaseException as e: - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return # Wait for the process_input_stream task to finish for a clean exit try: await producer except BaseException as e: - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py index 1078e012..4fa3221b 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py @@ -1,3 +1,7 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc @@ -8,9 +12,11 @@ BATCH_MAP_SOCK_PATH, MAP_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.batchmapper._dtypes import BatchMapCallable from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -19,7 +25,7 @@ ContainerType, ) from pynumaflow.proto.mapper import map_pb2_grpc -from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.shared.server import NumaflowServer class BatchMapAsyncServer(NumaflowServer): @@ -92,6 +98,7 @@ async def handler( ] self.servicer = AsyncBatchMapServicer(handler=self.batch_mapper_instance) + self._error: BaseException | None = None def start(self): """ @@ -99,6 +106,9 @@ def start(self): to the aexec so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -108,25 +118,54 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - map_pb2_grpc.add_MapServicer_to_server( - self.servicer, - server, - ) - _LOGGER.info("Starting Batch Map Server") + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) + serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper] # Add the MAP_MODE metadata to the server info for the correct map mode serv_info.metadata[MAP_MODE_KEY] = MapMode.BatchMap - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py index 523a4ad4..b6d866d3 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py @@ -7,7 +7,7 @@ from pynumaflow.batchmapper._dtypes import BatchMapCallable, BatchMapError from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING @@ -26,6 +26,12 @@ def __init__( ): self.background_tasks = set() self.__batch_map_handler: BatchMapCallable = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def MapFn( self, @@ -97,8 +103,12 @@ async def MapFn( await req_queue.put(datum) except BaseException as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() return async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py index 90a55b7b..0cbf18f2 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py @@ -8,7 +8,7 @@ from pynumaflow.mapper._dtypes import MapAsyncCallable, Datum, MapError, Message, Messages from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -25,6 +25,12 @@ def __init__( ): self.background_tasks = set() self.__map_handler: MapAsyncCallable = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def MapFn( self, @@ -57,7 +63,12 @@ async def MapFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window response back to the client else: @@ -65,8 +76,12 @@ async def MapFn( # wait for the producer task to complete await producer except BaseException as e: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def _process_inputs( diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py index cb757e3c..17895992 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py @@ -2,8 +2,9 @@ from concurrent.futures import ThreadPoolExecutor from collections.abc import Iterator +import grpc from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER, ERR_UDF_EXCEPTION_STRING @@ -26,6 +27,10 @@ def __init__(self, handler: MapSyncCallable, multiproc: bool = False): self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) + # Graceful shutdown: when set, a watcher thread in _run_server() calls + # server.stop() instead of hard-killing the process via psutil. + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def MapFn( self, @@ -36,6 +41,7 @@ def MapFn( Applies a function to each datum element. The pascal case function name comes from the proto map_pb2_grpc.py file. """ + result_queue = None try: # The first message to be received should be a valid handshake req = next(request_iterator) @@ -57,10 +63,19 @@ def MapFn( for res in result_queue.read_iterator(): # if error handler accordingly if isinstance(res, BaseException): - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc - ) + if isinstance(res, grpc.RpcError): + # Client disconnected mid-stream — the reader thread + # surfaced the error via the queue. Not a UDF fault. + _LOGGER.warning("gRPC stream closed, shutting down the server.") + result_queue.close() + self.shutdown_event.set() + return + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}" + update_context_err(context, res, err_msg) + # Unblock the reader thread if it is waiting on queue.put() + result_queue.close() + self.error = res + self.shutdown_event.set() return # return the result yield res @@ -69,12 +84,23 @@ def MapFn( reader_thread.join() self.executor.shutdown(cancel_futures=True) + except grpc.RpcError: + # Client disconnected — not a UDF error, but we still need to + # shut down the server so the process can exit cleanly. + _LOGGER.warning("gRPC stream closed, shutting down the server.") + if result_queue is not None: + result_queue.close() + self.shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc - ) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + if result_queue is not None: + result_queue.close() + self.error = err + self.shutdown_event.set() return def _process_requests( diff --git a/packages/pynumaflow/pynumaflow/mapper/async_server.py b/packages/pynumaflow/pynumaflow/mapper/async_server.py index 5bba75d7..7cddfc57 100644 --- a/packages/pynumaflow/pynumaflow/mapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/async_server.py @@ -1,3 +1,7 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc @@ -7,7 +11,10 @@ MAP_SOCK_PATH, MAP_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + _LOGGER, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -18,10 +25,7 @@ from pynumaflow.mapper._dtypes import MapAsyncCallable from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer from pynumaflow.proto.mapper import map_pb2_grpc -from pynumaflow.shared.server import ( - NumaflowServer, - start_async_server, -) +from pynumaflow.shared.server import NumaflowServer class MapAsyncServer(NumaflowServer): @@ -92,6 +96,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncMapServicer(handler=mapper_instance) + self._error: BaseException | None = None def start(self) -> None: """ @@ -99,32 +104,66 @@ def start(self) -> None: so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self) -> None: """ Starts the Async gRPC server on the given UNIX socket with given max threads. """ - # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context + server = grpc.aio.server(options=self._server_options) + server.add_insecure_port(self.sock_path) - server_new = grpc.aio.server(options=self._server_options) - server_new.add_insecure_port(self.sock_path) - map_pb2_grpc.add_MapServicer_to_server(self.servicer, server_new) + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper] # Add the MAP_MODE metadata to the server info for the correct map mode serv_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap - # Start the async server - await start_async_server( - server_async=server_new, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py index 5d68a96b..de08f075 100644 --- a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py @@ -1,4 +1,8 @@ +import multiprocessing +import sys + from pynumaflow._constants import ( + _LOGGER, NUM_THREADS_DEFAULT, MAX_MESSAGE_SIZE, MAP_SOCK_PATH, @@ -104,6 +108,11 @@ def handler(self, keys: list[str], datum: Datum) -> Messages: self._process_count = min(server_count, 2 * _PROCESS_COUNT) self.servicer = SyncMapServicer(handler=mapper_instance, multiproc=True) + # Shared event across all worker processes for coordinated shutdown. + # When any worker's servicer sets this event, all workers' watcher + # threads trigger server.stop() for a graceful coordinated exit. + self._shutdown_event = multiprocessing.Event() + def start(self) -> None: """ Starts the N grpc servers gRPC serves on the with given max threads. @@ -129,4 +138,9 @@ def start(self) -> None: server_options=self._server_options, udf_type=UDFType.Map, server_info=server_info, + shutdown_event=self._shutdown_event, ) + + if self._shutdown_event.is_set(): + _LOGGER.critical("Server exiting due to worker error") + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/mapper/sync_server.py b/packages/pynumaflow/pynumaflow/mapper/sync_server.py index 9c2431b6..a96ceb7f 100644 --- a/packages/pynumaflow/pynumaflow/mapper/sync_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/sync_server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -112,4 +114,9 @@ def start(self) -> None: server_options=self._server_options, udf_type=UDFType.Map, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py index 187c720d..b6b0fb23 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py @@ -1,6 +1,11 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -18,11 +23,12 @@ _LOGGER, MAP_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.mapstreamer._dtypes import MapStreamCallable -from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.shared.server import NumaflowServer class MapStreamAsyncServer(NumaflowServer): @@ -111,6 +117,7 @@ async def map_stream_handler(_: list[str], datum: Datum) -> AsyncIterable[Messag ] self.servicer = AsyncMapStreamServicer(handler=self.map_stream_instance) + self._error: BaseException | None = None def start(self): """ @@ -118,6 +125,9 @@ def start(self): to the aexec so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -127,25 +137,54 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - map_pb2_grpc.add_MapServicer_to_server( - self.servicer, - server, - ) - _LOGGER.info("Starting Map Stream Server") + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) + serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper] # Add the MAP_MODE metadata to the server info for the correct map mode serv_info.metadata[MAP_MODE_KEY] = MapMode.StreamMap - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py index f5a9a999..c5aa3545 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py @@ -8,7 +8,7 @@ from pynumaflow.mapstreamer import Datum from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2 -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -22,6 +22,12 @@ class AsyncMapStreamServicer(map_pb2_grpc.MapServicer): def __init__(self, handler: MapStreamCallable): self.__map_stream_handler: MapStreamCallable = handler self._background_tasks: set[asyncio.Task] = set() + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def MapFn( self, @@ -51,7 +57,12 @@ async def MapFn( # Consume results as they arrive and stream them to the client async for msg in global_result_queue.read_iterator(): if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return else: # msg is a map_pb2.MapResponse, already formed @@ -61,8 +72,12 @@ async def MapFn( await producer except BaseException as e: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def _process_inputs( @@ -124,7 +139,7 @@ async def _invoke_map_stream( except BaseException as err: _LOGGER.critical("MapFn handler error", exc_info=True) # Surface handler error to the main producer; - # it will call handle_async_error and end the RPC + # it will set the shutdown event and end the RPC await result_queue.put(err) async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/reducer/async_server.py b/packages/pynumaflow/pynumaflow/reducer/async_server.py index aee4d355..ff5f9d8e 100644 --- a/packages/pynumaflow/pynumaflow/reducer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducer/async_server.py @@ -1,8 +1,12 @@ +import asyncio +import contextlib import inspect +import sys import aiorun import grpc +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType from pynumaflow.proto.reducer import reduce_pb2_grpc @@ -15,6 +19,7 @@ _LOGGER, REDUCE_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.reducer._dtypes import ( @@ -23,7 +28,7 @@ Reducer, ) -from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server +from pynumaflow.shared.server import NumaflowServer, check_instance def get_handler( @@ -143,6 +148,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncReduceServicer(self.reducer_handler) + self._error: BaseException | None = None def start(self): """ @@ -153,6 +159,9 @@ def start(self): "Starting Async Reduce Server", ) aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -162,20 +171,52 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - reduce_servicer = self.servicer - reduce_pb2_grpc.add_ReduceServicer_to_server(reduce_servicer, server) + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + reduce_pb2_grpc.add_ReduceServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Reducer] - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py index 44db8077..3ea646e3 100644 --- a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import AsyncIterable from google.protobuf import empty_pb2 as _empty_pb2 @@ -12,7 +13,7 @@ WindowOperation, ) from pynumaflow.reducer.servicer.task_manager import TaskManager -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -53,6 +54,12 @@ def __init__( ): # The Reduce handler can be a function or a builder class instance. self.__reduce_handler: ReduceAsyncCallable | _ReduceBuilderClass = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def ReduceFn( self, @@ -103,9 +110,13 @@ async def ReduceFn( await task_manager.append_task(request) except BaseException as e: _LOGGER.critical("Reduce Error", exc_info=True) - # Send a context abort signal for the rpc, this is required for numa container to get - # the correct grpc error - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() + return # send EOF to all the tasks once the request iterator is exhausted # This will signal the tasks to stop reading the data on their @@ -134,9 +145,13 @@ async def ReduceFn( yield reduce_pb2.ReduceResponse(window=window, EOF=True) except BaseException as e: _LOGGER.critical("Reduce Error", exc_info=True) - # Send a context abort signal for the rpc, this is required for numa container to get - # the correct grpc error - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() + return async def IsReady( self, request: _empty_pb2.Empty, context: NumaflowServicerContext diff --git a/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py index 2e21d60a..bfc802a7 100644 --- a/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py +++ b/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py @@ -2,8 +2,6 @@ from datetime import datetime, timezone from collections.abc import AsyncIterable -import grpc - from pynumaflow.exceptions import UDFError from pynumaflow.proto.reducer import reduce_pb2 from pynumaflow.shared.asynciter import NonBlockingIterator @@ -21,7 +19,7 @@ ReduceAsyncCallable, ReduceWindow, ) -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -169,14 +167,9 @@ async def __invoke_reduce( msgs = await new_instance(keys, request_iterator, md) except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Send a context abort signal for the rpc, this is required for numa container to get - # the correct grpc error - await asyncio.gather( - self.context.abort(grpc.StatusCode.UNKNOWN, details=repr(err)), - return_exceptions=True, - ) - exit_on_error(err=repr(err), parent=False, context=self.context, update_context=False) - return + err_msg = f"ReduceError: {repr(err)}" + update_context_err(self.context, err, err_msg) + raise datum_responses = [] for msg in msgs: diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index 3986e0dc..64a2fd03 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import io import multiprocessing @@ -14,7 +13,6 @@ from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor import grpc -import psutil from pynumaflow._constants import ( _LOGGER, @@ -90,7 +88,7 @@ def _run_server( udf_type: str, server_info_file: str | None = None, server_info: ServerInfo | None = None, - shutdown_event: threading.Event | None = None, + shutdown_event: threading.Event | multiprocessing.Event | None = None, ) -> None: """ Starts the Synchronous server instance on the given UNIX socket @@ -151,6 +149,7 @@ def start_multiproc_server( server_info: ServerInfo | None = None, server_options=None, udf_type: str = UDFType.Map, + shutdown_event: multiprocessing.Event | None = None, ): """ Start N grpc servers in different processes where N = The number of CPUs or the @@ -179,6 +178,7 @@ def start_multiproc_server( worker = multiprocessing.Process( target=_run_server, args=(servicer, bind_address, max_threads, server_options, udf_type), + kwargs={"shutdown_event": shutdown_event} if shutdown_event else {}, ) worker.start() workers.append(worker) @@ -278,37 +278,6 @@ def get_grpc_status(err: str, detail: str | None = None): return rpc_status.to_status(status) -def exit_on_error( - context: NumaflowServicerContext, err: str, parent: bool = False, update_context=True -): - """ - Exit the current/parent process on an error. - - Args: - context (NumaflowServicerContext): The gRPC context. - err (str): The error message. - parent (bool, optional): Whether this is the parent process. - Defaults to False. - update_context(bool, optional) : Is there a need to update - the context with the error codes - """ - if update_context: - # Create a status object with the error details - grpc_status = get_grpc_status(err) - - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(err) - context.set_trailing_metadata(grpc_status.trailing_metadata) - - p = psutil.Process(os.getpid()) - # If the parent flag is true, we exit from the parent process - # Use this for Multiproc right now to exit from the parent fork - if parent: - p = psutil.Process(os.getppid()) - _LOGGER.info("Killing process: Got exception %s", err) - p.kill() - - def update_context_err(context: NumaflowServicerContext, e: BaseException, err_msg: str): """ Update the context with the error and log the exception. @@ -330,15 +299,3 @@ def get_exception_traceback_str(exc) -> str: return file.getvalue().rstrip() -async def handle_async_error( - context: NumaflowServicerContext, exception: BaseException, exception_type: str -): - """ - Handle exceptions for async servers by updating the context and exiting. - """ - err_msg = f"{exception_type}: {repr(exception)}" - update_context_err(context, exception, err_msg) - await asyncio.gather( - context.abort(grpc.StatusCode.INTERNAL, details=err_msg), return_exceptions=True - ) - exit_on_error(err=err_msg, parent=False, context=context, update_context=False) diff --git a/packages/pynumaflow/pynumaflow/sideinput/server.py b/packages/pynumaflow/pynumaflow/sideinput/server.py index 7bb27b86..445e23e1 100644 --- a/packages/pynumaflow/pynumaflow/sideinput/server.py +++ b/packages/pynumaflow/pynumaflow/sideinput/server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType from pynumaflow.shared import NumaflowServer from pynumaflow.shared.server import sync_server_start @@ -99,4 +101,8 @@ def start(self): server_options=self._server_options, udf_type=UDFType.SideInput, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py b/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py index 7f46bf68..da836948 100644 --- a/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py +++ b/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py @@ -1,3 +1,5 @@ +import threading + from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow._constants import ( @@ -5,7 +7,7 @@ ERR_UDF_EXCEPTION_STRING, ) from pynumaflow.proto.sideinput import sideinput_pb2_grpc, sideinput_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow.sideinput._dtypes import RetrieverCallable from pynumaflow.types import NumaflowServicerContext @@ -16,6 +18,8 @@ def __init__( handler: RetrieverCallable, ): self.__retrieve_handler: RetrieverCallable = handler + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def RetrieveSideInput( self, request: _empty_pb2.Empty, context: NumaflowServicerContext @@ -30,7 +34,9 @@ def RetrieveSideInput( except BaseException as err: err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) - exit_on_error(context, err_msg) + update_context_err(context, err, err_msg) + self.error = err + self.shutdown_event.set() return return sideinput_pb2.SideInputResponse(value=rspn.value, no_broadcast=rspn.no_broadcast) diff --git a/packages/pynumaflow/pynumaflow/sourcer/async_server.py b/packages/pynumaflow/pynumaflow/sourcer/async_server.py index 3bca9dfb..2f54b158 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcer/async_server.py @@ -1,6 +1,11 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION from pynumaflow.sourcer.servicer.async_servicer import AsyncSourceServicer @@ -10,10 +15,12 @@ NUM_THREADS_DEFAULT, SOURCE_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + _LOGGER, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.proto.sourcer import source_pb2_grpc -from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.shared.server import NumaflowServer from pynumaflow.sourcer._dtypes import SourceCallable @@ -153,6 +160,7 @@ async def partitions_handler(self) -> PartitionsResponse: ] self.servicer = AsyncSourceServicer(source_handler=sourcer_instance) + self._error: BaseException | None = None def start(self): """ @@ -160,6 +168,9 @@ def start(self): so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -168,20 +179,52 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - source_servicer = self.servicer - source_pb2_grpc.add_SourceServicer_to_server(source_servicer, server) + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + source_pb2_grpc.add_SourceServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sourcer] - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py index 3e0839c4..0f8a4db9 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py @@ -5,7 +5,7 @@ from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.sourcer import ReadRequest, Offset, NackRequest, AckRequest, SourceCallable from pynumaflow.proto.sourcer import source_pb2 from pynumaflow.proto.sourcer import source_pb2_grpc @@ -71,6 +71,12 @@ def __init__(self, source_handler: SourceCallable): self.source_handler = source_handler self.__initialize_handlers() self.cleanup_coroutines = [] + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event def __initialize_handlers(self): """Initialize handler methods from the provided source handler.""" @@ -110,7 +116,12 @@ async def ReadFn( async for resp in riter: if isinstance(resp, BaseException): - await handle_async_error(context, resp, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(resp)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, resp, err_msg) + self._error = resp + if self._shutdown_event is not None: + self._shutdown_event.set() return yield _create_read_response(resp) @@ -121,7 +132,12 @@ async def ReadFn( yield _create_eot_response() except BaseException as err: _LOGGER.critical("User-Defined Source ReadFn error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return async def __invoke_read( self, req: source_pb2.ReadRequest, niter: NonBlockingIterator[Message | Exception] @@ -169,7 +185,12 @@ async def AckFn( yield _create_ack_response() except BaseException as err: _LOGGER.critical("User-Defined Source AckFn error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return async def NackFn( self, @@ -186,7 +207,12 @@ async def NackFn( await self.__source_nack_handler(NackRequest(offsets=offsets)) except BaseException as err: _LOGGER.critical("User-Defined Source NackFn error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return return source_pb2.NackResponse( result=source_pb2.NackResponse.Result(success=_empty_pb2.Empty()) ) @@ -211,8 +237,14 @@ async def PendingFn( count = await self.__source_pending_handler() except BaseException as err: _LOGGER.critical("PendingFn Error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) - raise + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PendingResponse( + result=source_pb2.PendingResponse.Result(count=0) + ) resp = source_pb2.PendingResponse.Result(count=count.count) return source_pb2.PendingResponse(result=resp) @@ -226,8 +258,14 @@ async def PartitionsFn( partitions = await self.__source_partitions_handler() except BaseException as err: _LOGGER.critical("PartitionsFn Error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) - raise + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PartitionsResponse( + result=source_pb2.PartitionsResponse.Result(partitions=[]) + ) resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions) return source_pb2.PartitionsResponse(result=resp) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py index 990e4587..16ce1496 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py @@ -1,3 +1,7 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc @@ -7,17 +11,17 @@ MAX_NUM_THREADS, SOURCE_TRANSFORMER_SOCK_PATH, SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH, + _LOGGER, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType, ) from pynumaflow.proto.sourcetransformer import transform_pb2_grpc -from pynumaflow.shared.server import ( - NumaflowServer, - start_async_server, -) +from pynumaflow.shared.server import NumaflowServer from pynumaflow.sourcetransformer._dtypes import SourceTransformAsyncCallable from pynumaflow.sourcetransformer.servicer._async_servicer import SourceTransformAsyncServicer @@ -115,6 +119,7 @@ def __init__( ("grpc.max_receive_message_length", self.max_message_size), ] self.servicer = SourceTransformAsyncServicer(handler=source_transform_instance) + self._error: BaseException | None = None def start(self) -> None: """ @@ -122,32 +127,66 @@ def start(self) -> None: so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self) -> None: """ Starts the Async gRPC server on the given UNIX socket with given max threads. """ - # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context + server = grpc.aio.server(options=self._server_options) + server.add_insecure_port(self.sock_path) - server_new = grpc.aio.server(options=self._server_options) - server_new.add_insecure_port(self.sock_path) - transform_pb2_grpc.add_SourceTransformServicer_to_server(self.servicer, server_new) + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + transform_pb2_grpc.add_SourceTransformServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ ContainerType.Sourcetransformer ] - # Start the async server - await start_async_server( - server_async=server_new, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py index dbc8b7b5..f1ff372e 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py @@ -1,9 +1,13 @@ +import multiprocessing +import sys + from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType from pynumaflow.sourcetransformer.servicer._servicer import SourceTransformServicer from pynumaflow.shared.server import start_multiproc_server from pynumaflow._constants import ( + _LOGGER, MAX_MESSAGE_SIZE, SOURCE_TRANSFORMER_SOCK_PATH, NUM_THREADS_DEFAULT, @@ -129,6 +133,11 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: self._process_count = min(server_count, 2 * _PROCESS_COUNT) self.servicer = SourceTransformServicer(handler=source_transform_instance, multiproc=True) + # Shared event across all worker processes for coordinated shutdown. + # When any worker's servicer sets this event, all workers' watcher + # threads trigger server.stop() for a graceful coordinated exit. + self._shutdown_event = multiprocessing.Event() + def start(self): """ Starts the N gRPC servers on the given socket path with given max threads. @@ -148,4 +157,9 @@ def start(self): server_options=self._server_options, udf_type=UDFType.SourceTransformer, server_info=serv_info, + shutdown_event=self._shutdown_event, ) + + if self._shutdown_event.is_set(): + _LOGGER.critical("Server exiting due to worker error") + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/server.py index 7069e2b6..c410adce 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ContainerType, MINIMUM_NUMAFLOW_VERSION, ServerInfo from pynumaflow._constants import ( MAX_MESSAGE_SIZE, @@ -128,4 +130,8 @@ def start(self): server_options=self._server_options, udf_type=UDFType.SourceTransformer, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py index da7384c2..d85fcebd 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py @@ -8,7 +8,7 @@ from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow.proto.sourcetransformer import transform_pb2, transform_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.sourcetransformer import Datum from pynumaflow.sourcetransformer._dtypes import SourceTransformAsyncCallable from pynumaflow.types import NumaflowServicerContext @@ -28,6 +28,12 @@ def __init__( ): self.background_tasks = set() self.__transform_handler: SourceTransformAsyncCallable = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def SourceTransformFn( self, @@ -61,7 +67,12 @@ async def SourceTransformFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window response back to the client else: @@ -69,8 +80,12 @@ async def SourceTransformFn( # wait for the producer task to complete await producer except BaseException as e: - _LOGGER.critical("SourceTransformFnError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def _process_inputs( diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py index 2091bf47..1b93b96c 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py @@ -2,10 +2,11 @@ from concurrent.futures import ThreadPoolExecutor from collections.abc import Iterable +import grpc from google.protobuf import empty_pb2 as _empty_pb2 from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow.shared.synciter import SyncIterator from pynumaflow.sourcetransformer import Datum from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable @@ -46,6 +47,10 @@ def __init__(self, handler: SourceTransformCallable, multiproc: bool = False): self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) + # Graceful shutdown: when set, a watcher thread in _run_server() calls + # server.stop() instead of hard-killing the process via psutil. + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def SourceTransformFn( self, @@ -56,6 +61,8 @@ def SourceTransformFn( Applies a function to each datum element. The pascal case function name comes from the generated transform_pb2_grpc.py file. """ + # Initialize before try so it's accessible in except blocks + result_queue = None try: # The first message to be received should be a valid handshake req = next(request_iterator) @@ -78,10 +85,18 @@ def SourceTransformFn( for res in result_queue.read_iterator(): # if error handler accordingly if isinstance(res, BaseException): - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc - ) + if isinstance(res, grpc.RpcError): + # Client disconnected mid-stream — the reader thread + # surfaced the error via the queue. Not a UDF fault. + _LOGGER.warning("gRPC stream closed, shutting down the server.") + result_queue.close() + self.shutdown_event.set() + return + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}" + update_context_err(context, res, err_msg) + result_queue.close() + self.error = res + self.shutdown_event.set() return # return the result yield res @@ -90,12 +105,21 @@ def SourceTransformFn( reader_thread.join() self.executor.shutdown(cancel_futures=True) + except grpc.RpcError: + _LOGGER.warning("gRPC stream closed, shutting down the server.") + if result_queue is not None: + result_queue.close() + self.shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc - ) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + if result_queue is not None: + result_queue.close() + self.error = err + self.shutdown_event.set() return def _process_requests( diff --git a/packages/pynumaflow/tests/map/test_sync_map_shutdown.py b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py new file mode 100644 index 00000000..d8df92f0 --- /dev/null +++ b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py @@ -0,0 +1,121 @@ +""" +Shutdown-event tests for the synchronous Map servicer. + +Mirrors the sinker shutdown test pattern (tests/sink/test_server.py lines 345-461). +Each test verifies that the servicer sets shutdown_event (and optionally captures the +error) under a specific failure mode, enabling graceful server stop via the watcher +thread in _run_server() instead of a hard process kill. +""" + +from unittest import mock + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.mapper import Datum, Messages, Message +from pynumaflow.mapper._servicer._sync_servicer import SyncMapServicer +from pynumaflow.proto.mapper import map_pb2 +from tests.map.utils import map_handler, err_map_handler, get_test_datums + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + servicer = SyncMapServicer(handler=err_map_handler) + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + servicer = SyncMapServicer(handler=map_handler) + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + # Send a data message without a handshake first + test_datums = get_test_datums(handshake=False) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=1, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, details = method.termination() + assert code == StatusCode.INTERNAL + assert "MapFn: expected handshake as the first message" in details + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, + result_queue is None so close is skipped.""" + servicer = SyncMapServicer(handler=map_handler) + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_processing(): + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler + thread) and shutdown_event is set.""" + servicer = SyncMapServicer(handler=map_handler) + + test_datums = get_test_datums(handshake=True) + + def _cancelled_iter(): + yield test_datums[0] # handshake + yield test_datums[1] # first data message + raise grpc.RpcError() + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + # Should have at least the handshake response + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None From aa9496f301997d2ae8982be72138f6cccfc83842 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 16 Mar 2026 15:39:37 +0530 Subject: [PATCH 2/3] file formatting Signed-off-by: Sreekanth --- packages/pynumaflow/pynumaflow/shared/server.py | 2 -- .../pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py | 4 +--- packages/pynumaflow/tests/map/test_sync_map_shutdown.py | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index 64a2fd03..d2ba610a 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -297,5 +297,3 @@ def get_exception_traceback_str(exc) -> str: file = io.StringIO() traceback.print_exception(exc, value=exc, tb=exc.__traceback__, file=file) return file.getvalue().rstrip() - - diff --git a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py index 0f8a4db9..1929e536 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py @@ -242,9 +242,7 @@ async def PendingFn( self._error = err if self._shutdown_event is not None: self._shutdown_event.set() - return source_pb2.PendingResponse( - result=source_pb2.PendingResponse.Result(count=0) - ) + return source_pb2.PendingResponse(result=source_pb2.PendingResponse.Result(count=0)) resp = source_pb2.PendingResponse.Result(count=count.count) return source_pb2.PendingResponse(result=resp) diff --git a/packages/pynumaflow/tests/map/test_sync_map_shutdown.py b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py index d8df92f0..cf8523c1 100644 --- a/packages/pynumaflow/tests/map/test_sync_map_shutdown.py +++ b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py @@ -13,7 +13,6 @@ from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow.mapper import Datum, Messages, Message from pynumaflow.mapper._servicer._sync_servicer import SyncMapServicer from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import map_handler, err_map_handler, get_test_datums From 6c37cb5e1591bdb25b3cf46c0f7fb541d765cf13 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 16 Mar 2026 20:33:15 +0530 Subject: [PATCH 3/3] fix CI Signed-off-by: Sreekanth --- packages/pynumaflow/pynumaflow/shared/server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index d2ba610a..e5aec76c 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -1,6 +1,7 @@ import contextlib import io import multiprocessing +import multiprocessing.synchronize import os import socket import threading @@ -88,7 +89,7 @@ def _run_server( udf_type: str, server_info_file: str | None = None, server_info: ServerInfo | None = None, - shutdown_event: threading.Event | multiprocessing.Event | None = None, + shutdown_event: threading.Event | None = None, ) -> None: """ Starts the Synchronous server instance on the given UNIX socket @@ -149,7 +150,7 @@ def start_multiproc_server( server_info: ServerInfo | None = None, server_options=None, udf_type: str = UDFType.Map, - shutdown_event: multiprocessing.Event | None = None, + shutdown_event: multiprocessing.synchronize.Event | None = None, ): """ Start N grpc servers in different processes where N = The number of CPUs or the