From d7d7c7183ba47121f91bb5dba052f166b0f2de5b Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 12 Mar 2026 00:42:09 +0000 Subject: [PATCH 1/7] Validate message fields before protobuf encoding for better error messages Signed-off-by: Sreekanth --- packages/pynumaflow/pynumaflow/_validate.py | 14 +++++++++ .../pynumaflow/accumulator/_dtypes.py | 2 ++ .../pynumaflow/batchmapper/_dtypes.py | 2 ++ .../pynumaflow/pynumaflow/mapper/_dtypes.py | 2 ++ .../pynumaflow/mapstreamer/_dtypes.py | 2 ++ .../pynumaflow/pynumaflow/reducer/_dtypes.py | 2 ++ .../pynumaflow/reducestreamer/_dtypes.py | 2 ++ .../pynumaflow/reducestreamer/async_server.py | 7 ++++- .../reducestreamer/servicer/async_servicer.py | 2 +- .../reducestreamer/servicer/task_manager.py | 29 ++++++++++++++++--- .../pynumaflow/pynumaflow/sinker/_dtypes.py | 2 ++ .../pynumaflow/pynumaflow/sourcer/_dtypes.py | 2 ++ .../pynumaflow/sourcetransformer/_dtypes.py | 2 ++ 13 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 packages/pynumaflow/pynumaflow/_validate.py diff --git a/packages/pynumaflow/pynumaflow/_validate.py b/packages/pynumaflow/pynumaflow/_validate.py new file mode 100644 index 00000000..4621ed69 --- /dev/null +++ b/packages/pynumaflow/pynumaflow/_validate.py @@ -0,0 +1,14 @@ +def _validate_message_fields(value, keys, tags): + """Validate common Message fields at construction time. + + Raises TypeError with a clear message pointing at the caller's code + rather than letting bad types propagate to protobuf serialization. + """ + if value is not None and not isinstance(value, bytes): + raise TypeError(f"Message 'value' must be bytes, got {type(value).__name__}") + if keys is not None: + if not isinstance(keys, list) or not all(isinstance(k, str) for k in keys): + raise TypeError(f"Message 'keys' must be a list of strings, got {keys!r}") + if tags is not None: + if not isinstance(tags, list) or not all(isinstance(t, str) for t in tags): + raise TypeError(f"Message 'tags' must be a list of strings, got {tags!r}") diff --git a/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py b/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py index 62f388c7..c6899529 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py @@ -8,6 +8,7 @@ from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") @@ -389,6 +390,7 @@ def __init__( """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py b/packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py index 0e1eea92..754f4916 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py @@ -6,6 +6,7 @@ from collections.abc import AsyncIterable, Callable from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") B = TypeVar("B", bound="BatchResponse") @@ -31,6 +32,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str] """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/mapper/_dtypes.py b/packages/pynumaflow/pynumaflow/mapper/_dtypes.py index 31245767..06fa35d2 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/mapper/_dtypes.py @@ -8,6 +8,7 @@ from pynumaflow._constants import DROP from pynumaflow._metadata import UserMetadata, SystemMetadata +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") Ms = TypeVar("Ms", bound="Messages") @@ -40,6 +41,7 @@ def __init__( """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py b/packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py index 81bc02e9..8c8659e6 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py @@ -7,6 +7,7 @@ from warnings import warn from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") Ms = TypeVar("Ms", bound="Messages") @@ -31,6 +32,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str] """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/reducer/_dtypes.py b/packages/pynumaflow/pynumaflow/reducer/_dtypes.py index f39a7f82..98885021 100644 --- a/packages/pynumaflow/pynumaflow/reducer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/reducer/_dtypes.py @@ -10,6 +10,7 @@ from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") Ms = TypeVar("Ms", bound="Messages") @@ -48,6 +49,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str] """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py b/packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py index dfc652de..f8749795 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py @@ -8,6 +8,7 @@ from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") @@ -270,6 +271,7 @@ def __init__( """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py index 0d0f251b..2ed3f199 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py @@ -171,7 +171,12 @@ def start(self): _LOGGER.info( "Starting Async Reduce Stream Server", ) - aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + def _shutdown_handler(loop): + _LOGGER.info("Received graceful shutdown signal, shutting down ReduceStream server") + if self.shutdown_callback: + self.shutdown_callback(loop) + + aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=_shutdown_handler) if self._error: _LOGGER.critical("Server exiting due to UDF error: %s", self._error) sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py index 569ad8c9..3369c522 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py @@ -101,7 +101,7 @@ async def ReduceFn( # If the message is an exception, we raise the exception if isinstance(msg, BaseException): err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" - _LOGGER.critical(err_msg, exc_info=True) + _LOGGER.critical(err_msg, exc_info=msg) update_context_err(context, msg, err_msg) self._error = msg if self._shutdown_event is not None: diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py index 2e436178..324c46f2 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py @@ -275,6 +275,18 @@ async def process_input_stream(self, request_iterator: AsyncIterable[reduce_pb2. _LOGGER.critical(err_msg, exc_info=True) await self.global_result_queue.put(e) + # Cancel and await remaining tasks to suppress "never retrieved" warnings + for task in self.get_tasks(): + for fut in (task.future, task.consumer_future): + if fut and not fut.done(): + fut.cancel() + for fut in (task.future, task.consumer_future): + if fut: + try: + await fut + except (asyncio.CancelledError, BaseException): + pass + async def write_to_global_queue( self, input_queue: NonBlockingIterator, output_queue: NonBlockingIterator, window ): @@ -284,10 +296,19 @@ async def write_to_global_queue( to the global result queue """ reader = input_queue.read_iterator() - async for msg in reader: - res = reduce_pb2.ReduceResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags) - out = reduce_pb2.ReduceResponse(result=res, window=window) - await output_queue.put(out) + try: + async for msg in reader: + res = reduce_pb2.ReduceResponse.Result( + keys=msg.keys, value=msg.value, tags=msg.tags + ) + out = reduce_pb2.ReduceResponse(result=res, window=window) + await output_queue.put(out) + except Exception as e: + # Using Exception (not BaseException) so that asyncio.CancelledError + # (a BaseException subclass in Python 3.9+) propagates normally + # when the task is cancelled during shutdown. + _LOGGER.critical("Error serializing reduce result: %s", e, exc_info=True) + await output_queue.put(e) def clean_background(self, task): """ diff --git a/packages/pynumaflow/pynumaflow/sinker/_dtypes.py b/packages/pynumaflow/pynumaflow/sinker/_dtypes.py index d807f591..656fed7f 100644 --- a/packages/pynumaflow/pynumaflow/sinker/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/sinker/_dtypes.py @@ -7,6 +7,7 @@ from warnings import warn from pynumaflow._metadata import SystemMetadata, UserMetadata +from pynumaflow._validate import _validate_message_fields R = TypeVar("R", bound="Response") Rs = TypeVar("Rs", bound="Responses") @@ -35,6 +36,7 @@ def __init__( keys: list[str] | None = None, user_metadata: UserMetadata | None = None, ): + _validate_message_fields(value, keys, None) self._value = value self._keys = keys self._user_metadata = user_metadata diff --git a/packages/pynumaflow/pynumaflow/sourcer/_dtypes.py b/packages/pynumaflow/pynumaflow/sourcer/_dtypes.py index 99b0d5a8..4f407483 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/sourcer/_dtypes.py @@ -6,6 +6,7 @@ from typing import TypeAlias from pynumaflow._metadata import UserMetadata +from pynumaflow._validate import _validate_message_fields from pynumaflow.shared.asynciter import NonBlockingIterator @@ -77,6 +78,7 @@ def __init__( """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(payload, keys, None) self._payload = payload self._offset = offset self._event_time = event_time diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py b/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py index b53f61f8..66ffb195 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py @@ -8,6 +8,7 @@ from pynumaflow._constants import DROP from pynumaflow._metadata import UserMetadata, SystemMetadata +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") Ms = TypeVar("Ms", bound="Messages") @@ -43,6 +44,7 @@ def __init__( """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._tags = tags or [] self._keys = keys or [] From 7035915bb9c51d96c9bfec4b6e4621e5448aa5f6 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 12 Mar 2026 00:48:04 +0000 Subject: [PATCH 2/7] file formatting Branch-Creation-Time: 2026-03-12T00:42:31+0000 Signed-off-by: Sreekanth --- packages/pynumaflow/pynumaflow/reducestreamer/async_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py index 2ed3f199..63123c88 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py @@ -171,6 +171,7 @@ def start(self): _LOGGER.info( "Starting Async Reduce Stream Server", ) + def _shutdown_handler(loop): _LOGGER.info("Received graceful shutdown signal, shutting down ReduceStream server") if self.shutdown_callback: From 30c3065a00cf80ba9ac04b15fad24410808354c1 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 12 Mar 2026 04:06:31 +0000 Subject: [PATCH 3/7] Unit tests Signed-off-by: Sreekanth --- .../reducestreamer/test_async_reduce_err.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py index eaaea013..1fa60624 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py @@ -346,6 +346,95 @@ async def requests(): ctx.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL) +async def _blocking_handler(keys, datums, output, md): + """Handler that blocks forever reading datums (never finishes on its own).""" + async for _ in datums: + pass + await output.put(Message(b"done", keys=keys)) + + +def test_cancel_and_await_remaining_tasks_on_post_processing_error(): + """ + When a BaseException occurs during post-processing (after the input stream + is exhausted), the TaskManager should cancel and await all remaining task + futures to suppress 'never retrieved' warnings. + """ + from unittest.mock import patch + + tm = TaskManager(_blocking_handler) + + request, _ = start_request(multiple_window=False) + # Use OPEN so create_task is called + request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.OPEN + + async def _run(): + async def requests(): + yield request + + # Patch stream_send_eof to raise after the task is created but before + # it completes, so the task futures are still running when the except + # block executes. + with patch.object(tm, "stream_send_eof", side_effect=RuntimeError("send_eof boom")): + await tm.process_input_stream(requests()) + + # After process_input_stream returns, verify the error was placed in + # the global result queue. + reader = tm.global_result_queue.read_iterator() + first_item = await reader.__anext__() + assert isinstance(first_item, RuntimeError) + assert "send_eof boom" in str(first_item) + + # Verify all task futures completed (cancelled or finished). + for task in tm.get_tasks(): + assert task.future.done(), "task.future should be done after cleanup" + assert task.consumer_future.done(), "task.consumer_future should be done after cleanup" + + asyncio.run(_run()) + + +def test_cancel_and_await_with_already_done_futures(): + """ + When post-processing fails but some futures are already done, + the cleanup code should handle them gracefully (skip cancellation). + """ + from unittest.mock import patch + + async def _fast_handler(keys, datums, output, md): + """Handler that finishes immediately without reading datums.""" + await output.put(Message(b"fast", keys=keys)) + + tm = TaskManager(_fast_handler) + request, _ = start_request(multiple_window=False) + request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.OPEN + + async def _run(): + async def requests(): + yield request + + # Let the real stream_send_eof run (which sends EOF to the handler), + # then patch get_unique_windows to raise after all tasks complete. + original_send_eof = tm.stream_send_eof + + async def send_eof_then_wait_and_raise(): + await original_send_eof() + # Wait for the task futures to finish + for task in tm.get_tasks(): + await task.future + await task.result_queue.put("__STREAM_EOF__") + await task.consumer_future + raise RuntimeError("late post-processing error") + + with patch.object(tm, "stream_send_eof", side_effect=send_eof_then_wait_and_raise): + await tm.process_input_stream(requests()) + + # Verify cleanup completed without issues + for task in tm.get_tasks(): + assert task.future.done() + assert task.consumer_future.done() + + asyncio.run(_run()) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() From 7af7468b4a8de20822f64eb514cfda3134722e36 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 12 Mar 2026 04:21:10 +0000 Subject: [PATCH 4/7] Unit tests Signed-off-by: Sreekanth --- .../reducestreamer/test_async_reduce_err.py | 55 ++++++++++++++----- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py index 1fa60624..79d9f8cf 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py @@ -353,23 +353,44 @@ async def _blocking_handler(keys, datums, output, md): await output.put(Message(b"done", keys=keys)) +def _make_reduce_request(operation_event): + """Create a ReduceRequest DTO (not raw protobuf) matching what datum_generator produces.""" + from pynumaflow.reducestreamer._dtypes import ReduceRequest as ReduceRequestDTO + + event_time_timestamp, watermark_timestamp = get_time_args() + window = reduce_pb2.Window( + start=mock_interval_window_start(), + end=mock_interval_window_end(), + slot="slot-0", + ) + payload = Datum( + keys=["test_key"], + value=mock_message(), + event_time=event_time_timestamp.ToDatetime(), + watermark=watermark_timestamp.ToDatetime(), + ) + return ReduceRequestDTO( + operation=operation_event, + windows=[window], + payload=payload, + ) + + def test_cancel_and_await_remaining_tasks_on_post_processing_error(): """ When a BaseException occurs during post-processing (after the input stream is exhausted), the TaskManager should cancel and await all remaining task - futures to suppress 'never retrieved' warnings. + futures that are still running. """ from unittest.mock import patch + from pynumaflow.reducestreamer._dtypes import WindowOperation tm = TaskManager(_blocking_handler) - - request, _ = start_request(multiple_window=False) - # Use OPEN so create_task is called - request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.OPEN + req = _make_reduce_request(int(WindowOperation.OPEN)) async def _run(): async def requests(): - yield request + yield req # Patch stream_send_eof to raise after the task is created but before # it completes, so the task futures are still running when the except @@ -377,6 +398,9 @@ async def requests(): with patch.object(tm, "stream_send_eof", side_effect=RuntimeError("send_eof boom")): await tm.process_input_stream(requests()) + # Verify tasks were actually created + assert len(tm.get_tasks()) > 0, "tasks should have been created" + # After process_input_stream returns, verify the error was placed in # the global result queue. reader = tm.global_result_queue.read_iterator() @@ -395,38 +419,41 @@ async def requests(): def test_cancel_and_await_with_already_done_futures(): """ When post-processing fails but some futures are already done, - the cleanup code should handle them gracefully (skip cancellation). + the cleanup code should skip cancellation for those (fut.done() is True). """ from unittest.mock import patch + from pynumaflow.reducestreamer._dtypes import WindowOperation + from pynumaflow._constants import STREAM_EOF async def _fast_handler(keys, datums, output, md): """Handler that finishes immediately without reading datums.""" await output.put(Message(b"fast", keys=keys)) tm = TaskManager(_fast_handler) - request, _ = start_request(multiple_window=False) - request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.OPEN + req = _make_reduce_request(int(WindowOperation.OPEN)) async def _run(): async def requests(): - yield request + yield req - # Let the real stream_send_eof run (which sends EOF to the handler), - # then patch get_unique_windows to raise after all tasks complete. original_send_eof = tm.stream_send_eof async def send_eof_then_wait_and_raise(): + # Let the real stream_send_eof run (sends EOF to handler input) await original_send_eof() - # Wait for the task futures to finish + # Wait for all task futures to complete so they are .done() for task in tm.get_tasks(): await task.future - await task.result_queue.put("__STREAM_EOF__") + await task.result_queue.put(STREAM_EOF) await task.consumer_future raise RuntimeError("late post-processing error") with patch.object(tm, "stream_send_eof", side_effect=send_eof_then_wait_and_raise): await tm.process_input_stream(requests()) + # Verify tasks were actually created + assert len(tm.get_tasks()) > 0, "tasks should have been created" + # Verify cleanup completed without issues for task in tm.get_tasks(): assert task.future.done() From 7ee08fc7533c41306d3270ff56455e7239af0497 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 12 Mar 2026 04:28:25 +0000 Subject: [PATCH 5/7] Unit tests Signed-off-by: Sreekanth --- .../tests/reducestreamer/test_async_reduce.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py index fd7a49bd..59974bcb 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py @@ -295,6 +295,70 @@ def test_max_threads(self): self.assertEqual(server.max_threads, 4) + def test_start_shutdown_handler_without_callback(self): + """Test that _shutdown_handler logs and works when no shutdown_callback is set.""" + from unittest.mock import patch, MagicMock + + server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass) + self.assertIsNone(server.shutdown_callback) + + def close_coro(coro, **kwargs): + coro.close() + + with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun: + mock_aiorun.run.side_effect = close_coro + server.start() + + # Extract the shutdown_callback passed to aiorun.run + call_kwargs = mock_aiorun.run.call_args[1] + shutdown_handler = call_kwargs["shutdown_callback"] + + # Invoke the handler — should not raise even without a callback + mock_loop = MagicMock() + shutdown_handler(mock_loop) + + def test_start_shutdown_handler_with_callback(self): + """Test that _shutdown_handler invokes the user-provided shutdown_callback.""" + from unittest.mock import patch, MagicMock + + user_callback = MagicMock() + server = ReduceStreamAsyncServer( + reduce_stream_instance=ExampleClass, shutdown_callback=user_callback + ) + + def close_coro(coro, **kwargs): + coro.close() + + with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun: + mock_aiorun.run.side_effect = close_coro + server.start() + + shutdown_handler = mock_aiorun.run.call_args[1]["shutdown_callback"] + mock_loop = MagicMock() + shutdown_handler(mock_loop) + + user_callback.assert_called_once_with(mock_loop) + + def test_start_exits_on_error(self): + """Test that start() calls sys.exit(1) when servicer reports an error.""" + from unittest.mock import patch + + server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass) + + def fake_aiorun_run(coro, **kwargs): + # Simulate aiorun completing after a UDF error was recorded + coro.close() # prevent "coroutine never awaited" warning + server._error = RuntimeError("UDF failure") + + with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun, patch( + "pynumaflow.reducestreamer.async_server.sys" + ) as mock_sys: + mock_aiorun.run.side_effect = fake_aiorun_run + server.start() + + mock_sys.exit.assert_called_once_with(1) + + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main() From a0a2eb94c40fc95f63500df85172a4e3b8a3fd84 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 12 Mar 2026 04:29:57 +0000 Subject: [PATCH 6/7] file formatting Signed-off-by: Sreekanth --- packages/pynumaflow/tests/reducestreamer/test_async_reduce.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py index 59974bcb..6232b1fb 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py @@ -294,7 +294,6 @@ def test_max_threads(self): server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass) self.assertEqual(server.max_threads, 4) - def test_start_shutdown_handler_without_callback(self): """Test that _shutdown_handler logs and works when no shutdown_callback is set.""" from unittest.mock import patch, MagicMock From a4715258204b415f4318992b6a828e364ac5a9a8 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 12 Mar 2026 04:36:07 +0000 Subject: [PATCH 7/7] more tests Signed-off-by: Sreekanth --- packages/pynumaflow/tests/test_validate.py | 31 ++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 packages/pynumaflow/tests/test_validate.py diff --git a/packages/pynumaflow/tests/test_validate.py b/packages/pynumaflow/tests/test_validate.py new file mode 100644 index 00000000..a82d550c --- /dev/null +++ b/packages/pynumaflow/tests/test_validate.py @@ -0,0 +1,31 @@ +import pytest + +from pynumaflow._validate import _validate_message_fields + + +class TestValidateMessageFields: + def test_invalid_value_type_raises(self): + with pytest.raises(TypeError, match="Message 'value' must be bytes, got str"): + _validate_message_fields("not bytes", None, None) + + def test_invalid_keys_type_raises(self): + with pytest.raises(TypeError, match="Message 'keys' must be a list of strings"): + _validate_message_fields(None, "not-a-list", None) + + def test_invalid_keys_element_type_raises(self): + with pytest.raises(TypeError, match="Message 'keys' must be a list of strings"): + _validate_message_fields(None, [1, 2], None) + + def test_invalid_tags_type_raises(self): + with pytest.raises(TypeError, match="Message 'tags' must be a list of strings"): + _validate_message_fields(None, None, "not-a-list") + + def test_invalid_tags_element_type_raises(self): + with pytest.raises(TypeError, match="Message 'tags' must be a list of strings"): + _validate_message_fields(None, None, [1, 2]) + + def test_valid_inputs_no_error(self): + _validate_message_fields(b"data", ["key1"], ["tag1"]) + + def test_all_none_no_error(self): + _validate_message_fields(None, None, None)