Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 52 additions & 9 deletions packages/pynumaflow/pynumaflow/accumulator/async_server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import asyncio
import contextlib
import inspect
import sys

import aiorun
import grpc

from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer
from pynumaflow.info.server import write as info_server_write
from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION
from pynumaflow.proto.accumulator import accumulator_pb2_grpc

Expand All @@ -15,6 +19,7 @@
MAX_NUM_THREADS,
ACCUMULATOR_SOCK_PATH,
ACCUMULATOR_SERVER_INFO_FILE_PATH,
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
)

from pynumaflow.accumulator._dtypes import (
Expand All @@ -23,7 +28,7 @@
Accumulator,
)

from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server
from pynumaflow.shared.server import NumaflowServer, check_instance


def get_handler(
Expand Down Expand Up @@ -157,6 +162,7 @@ def __init__(
]
# Get the servicer instance for the async server
self.servicer = AsyncAccumulatorServicer(self.accumulator_handler)
self._error: BaseException | None = None

def start(self):
"""
Expand All @@ -167,6 +173,9 @@ def start(self):
"Starting Async Accumulator Server",
)
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
if self._error:
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
sys.exit(1)

async def aexec(self):
"""
Expand All @@ -176,18 +185,52 @@ async def aexec(self):
# As the server is async, we need to create a new server instance in the
# same thread as the event loop so that all the async calls are made in the
# same context
# Create a new async server instance and add the servicer to it
server = grpc.aio.server(options=self._server_options)
server.add_insecure_port(self.sock_path)

# The asyncio.Event must be created here (inside aexec) rather than in __init__,
# because it must be bound to the running event loop that aiorun creates.
# At __init__ time no event loop exists yet.
shutdown_event = asyncio.Event()
self.servicer.set_shutdown_event(shutdown_event)

accumulator_pb2_grpc.add_AccumulatorServicer_to_server(self.servicer, server)

serv_info = ServerInfo.get_default_server_info()
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Accumulator]
await start_async_server(
server_async=server,
sock_path=self.sock_path,
max_threads=self.max_threads,
cleanup_coroutines=list(),
server_info_file=self.server_info_file,
server_info=serv_info,

await server.start()
info_server_write(server_info=serv_info, info_file=self.server_info_file)

_LOGGER.info(
"Async GRPC Server listening on: %s with max threads: %s",
self.sock_path,
self.max_threads,
)

async def _watch_for_shutdown():
"""Wait for the shutdown event and stop the server with a grace period."""
await shutdown_event.wait()
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
# Stop accepting new requests and wait for a maximum of
# NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)

shutdown_task = asyncio.create_task(_watch_for_shutdown())
await server.wait_for_termination()

# Propagate error so start() can exit with a non-zero code
self._error = self.servicer._error

shutdown_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await shutdown_task

_LOGGER.info("Stopping event loop...")
# We use aiorun to manage the event loop. The aiorun.run() runs
# forever until loop.stop() is called. If we don't stop the
# event loop explicitly here, the python process will not exit.
# It remains stuck for 5 minutes until the liveness and readiness probes
# fail enough times and k8s sends a SIGTERM
asyncio.get_event_loop().stop()
_LOGGER.info("Event loop stopped")
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from google.protobuf import empty_pb2 as _empty_pb2

from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING
from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING
from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc
from pynumaflow.accumulator._dtypes import (
Datum,
Expand All @@ -13,7 +13,7 @@
KeyedWindow,
)
from pynumaflow.accumulator.servicer.task_manager import TaskManager
from pynumaflow.shared.server import handle_async_error
from pynumaflow.shared.server import update_context_err
from pynumaflow.types import NumaflowServicerContext


Expand Down Expand Up @@ -57,6 +57,12 @@ def __init__(
):
# The accumulator handler can be a function or a builder class instance.
self.__accumulator_handler: AccumulatorAsyncCallable | _AccumulatorBuilderClass = handler
self._shutdown_event: asyncio.Event | None = None
self._error: BaseException | None = None

def set_shutdown_event(self, event: asyncio.Event):
"""Wire up the shutdown event created by the server's aexec() coroutine."""
self._shutdown_event = event

async def AccumulateFn(
self,
Expand Down Expand Up @@ -104,20 +110,35 @@ async def AccumulateFn(
async for msg in consumer:
# If the message is an exception, report the error and trigger a server shutdown
if isinstance(msg, BaseException):
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, msg, err_msg)
self._error = msg
if self._shutdown_event is not None:
self._shutdown_event.set()
return
# Send window EOF response or Window result response
# back to the client
else:
yield msg
except BaseException as e:
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, e, err_msg)
self._error = e
if self._shutdown_event is not None:
self._shutdown_event.set()
return
# Wait for the process_input_stream task to finish for a clean exit
try:
await producer
except BaseException as e:
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, e, err_msg)
self._error = e
if self._shutdown_event is not None:
self._shutdown_event.set()
return

async def IsReady(
Expand Down
69 changes: 54 additions & 15 deletions packages/pynumaflow/pynumaflow/batchmapper/async_server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import asyncio
import contextlib
import sys

import aiorun
import grpc

Expand All @@ -8,9 +12,11 @@
BATCH_MAP_SOCK_PATH,
MAP_SERVER_INFO_FILE_PATH,
MAX_NUM_THREADS,
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
)
from pynumaflow.batchmapper._dtypes import BatchMapCallable
from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer
from pynumaflow.info.server import write as info_server_write
from pynumaflow.info.types import (
ServerInfo,
MAP_MODE_KEY,
Expand All @@ -19,7 +25,7 @@
ContainerType,
)
from pynumaflow.proto.mapper import map_pb2_grpc
from pynumaflow.shared.server import NumaflowServer, start_async_server
from pynumaflow.shared.server import NumaflowServer


class BatchMapAsyncServer(NumaflowServer):
Expand Down Expand Up @@ -92,13 +98,17 @@ async def handler(
]

self.servicer = AsyncBatchMapServicer(handler=self.batch_mapper_instance)
self._error: BaseException | None = None

def start(self):
"""
Starter function for the Async Batch Map server, we need a separate caller
to the aexec so that all the async coroutines can be started from a single context
"""
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
if self._error:
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
sys.exit(1)

async def aexec(self):
"""
Expand All @@ -108,25 +118,54 @@ async def aexec(self):
# As the server is async, we need to create a new server instance in the
# same thread as the event loop so that all the async calls are made in the
# same context
# Create a new async server instance and add the servicer to it
server = grpc.aio.server(options=self._server_options)
server.add_insecure_port(self.sock_path)
map_pb2_grpc.add_MapServicer_to_server(
self.servicer,
server,
)
_LOGGER.info("Starting Batch Map Server")

# The asyncio.Event must be created here (inside aexec) rather than in __init__,
# because it must be bound to the running event loop that aiorun creates.
# At __init__ time no event loop exists yet.
shutdown_event = asyncio.Event()
self.servicer.set_shutdown_event(shutdown_event)

map_pb2_grpc.add_MapServicer_to_server(self.servicer, server)

serv_info = ServerInfo.get_default_server_info()
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper]
# Add the MAP_MODE metadata to the server info for the correct map mode
serv_info.metadata[MAP_MODE_KEY] = MapMode.BatchMap

# Start the async server
await start_async_server(
server_async=server,
sock_path=self.sock_path,
max_threads=self.max_threads,
cleanup_coroutines=list(),
server_info_file=self.server_info_file,
server_info=serv_info,
await server.start()
info_server_write(server_info=serv_info, info_file=self.server_info_file)

_LOGGER.info(
"Async GRPC Server listening on: %s with max threads: %s",
self.sock_path,
self.max_threads,
)

async def _watch_for_shutdown():
"""Wait for the shutdown event and stop the server with a grace period."""
await shutdown_event.wait()
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
# Stop accepting new requests and wait for a maximum of
# NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)

shutdown_task = asyncio.create_task(_watch_for_shutdown())
await server.wait_for_termination()

# Propagate error so start() can exit with a non-zero code
self._error = self.servicer._error

shutdown_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await shutdown_task

_LOGGER.info("Stopping event loop...")
# We use aiorun to manage the event loop. The aiorun.run() runs
# forever until loop.stop() is called. If we don't stop the
# event loop explicitly here, the python process will not exit.
# It remains stuck for 5 minutes until the liveness and readiness probes
# fail enough times and k8s sends a SIGTERM
asyncio.get_event_loop().stop()
_LOGGER.info("Event loop stopped")
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pynumaflow.batchmapper._dtypes import BatchMapCallable, BatchMapError
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.asynciter import NonBlockingIterator
from pynumaflow.shared.server import handle_async_error
from pynumaflow.shared.server import update_context_err
from pynumaflow.types import NumaflowServicerContext
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING

Expand All @@ -26,6 +26,12 @@ def __init__(
):
self.background_tasks = set()
self.__batch_map_handler: BatchMapCallable = handler
self._shutdown_event: asyncio.Event | None = None
self._error: BaseException | None = None

def set_shutdown_event(self, event: asyncio.Event):
"""Wire up the shutdown event created by the server's aexec() coroutine."""
self._shutdown_event = event

async def MapFn(
self,
Expand Down Expand Up @@ -97,8 +103,12 @@ async def MapFn(
await req_queue.put(datum)

except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, err, err_msg)
self._error = err
if self._shutdown_event is not None:
self._shutdown_event.set()
return

async def IsReady(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pynumaflow.mapper._dtypes import MapAsyncCallable, Datum, MapError, Message, Messages
from pynumaflow._metadata import _user_and_system_metadata_from_proto
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.server import handle_async_error
from pynumaflow.shared.server import update_context_err
from pynumaflow.types import NumaflowServicerContext


Expand All @@ -25,6 +25,12 @@ def __init__(
):
self.background_tasks = set()
self.__map_handler: MapAsyncCallable = handler
self._shutdown_event: asyncio.Event | None = None
self._error: BaseException | None = None

def set_shutdown_event(self, event: asyncio.Event):
"""Wire up the shutdown event created by the server's aexec() coroutine."""
self._shutdown_event = event

async def MapFn(
self,
Expand Down Expand Up @@ -57,16 +63,25 @@ async def MapFn(
async for msg in consumer:
# If the message is an exception, report the error and trigger a server shutdown
if isinstance(msg, BaseException):
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, msg, err_msg)
self._error = msg
if self._shutdown_event is not None:
self._shutdown_event.set()
return
# Send window response back to the client
else:
yield msg
# wait for the producer task to complete
await producer
except BaseException as e:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, e, err_msg)
self._error = e
if self._shutdown_event is not None:
self._shutdown_event.set()
return

async def _process_inputs(
Expand Down
Loading
Loading