diff --git a/packages/pynumaflow/pynumaflow/_validate.py b/packages/pynumaflow/pynumaflow/_validate.py new file mode 100644 index 00000000..4621ed69 --- /dev/null +++ b/packages/pynumaflow/pynumaflow/_validate.py @@ -0,0 +1,14 @@ +def _validate_message_fields(value, keys, tags): + """Validate common Message fields at construction time. + + Raises TypeError with a clear message pointing at the caller's code + rather than letting bad types propagate to protobuf serialization. + """ + if value is not None and not isinstance(value, bytes): + raise TypeError(f"Message 'value' must be bytes, got {type(value).__name__}") + if keys is not None: + if not isinstance(keys, list) or not all(isinstance(k, str) for k in keys): + raise TypeError(f"Message 'keys' must be a list of strings, got {keys!r}") + if tags is not None: + if not isinstance(tags, list) or not all(isinstance(t, str) for t in tags): + raise TypeError(f"Message 'tags' must be a list of strings, got {tags!r}") diff --git a/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py b/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py index 62f388c7..c6899529 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py @@ -8,6 +8,7 @@ from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") @@ -389,6 +390,7 @@ def __init__( """ Creates a Message object to send value to a vertex. 
""" + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py b/packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py index 0e1eea92..754f4916 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py @@ -6,6 +6,7 @@ from collections.abc import AsyncIterable, Callable from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") B = TypeVar("B", bound="BatchResponse") @@ -31,6 +32,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str] """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/mapper/_dtypes.py b/packages/pynumaflow/pynumaflow/mapper/_dtypes.py index 31245767..06fa35d2 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/mapper/_dtypes.py @@ -8,6 +8,7 @@ from pynumaflow._constants import DROP from pynumaflow._metadata import UserMetadata, SystemMetadata +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") Ms = TypeVar("Ms", bound="Messages") @@ -40,6 +41,7 @@ def __init__( """ Creates a Message object to send value to a vertex. 
""" + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py b/packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py index 81bc02e9..8c8659e6 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py @@ -7,6 +7,7 @@ from warnings import warn from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") Ms = TypeVar("Ms", bound="Messages") @@ -31,6 +32,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str] """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/reducer/_dtypes.py b/packages/pynumaflow/pynumaflow/reducer/_dtypes.py index f39a7f82..98885021 100644 --- a/packages/pynumaflow/pynumaflow/reducer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/reducer/_dtypes.py @@ -10,6 +10,7 @@ from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") Ms = TypeVar("Ms", bound="Messages") @@ -48,6 +49,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str] """ Creates a Message object to send value to a vertex. 
""" + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py b/packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py index dfc652de..f8749795 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py @@ -8,6 +8,7 @@ from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow._constants import DROP +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") @@ -270,6 +271,7 @@ def __init__( """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._keys = keys or [] self._tags = tags or [] self._value = value or b"" diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py index 0d0f251b..63123c88 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py @@ -171,7 +171,13 @@ def start(self): _LOGGER.info( "Starting Async Reduce Stream Server", ) - aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + + def _shutdown_handler(loop): + _LOGGER.info("Received graceful shutdown signal, shutting down ReduceStream server") + if self.shutdown_callback: + self.shutdown_callback(loop) + + aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=_shutdown_handler) if self._error: _LOGGER.critical("Server exiting due to UDF error: %s", self._error) sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py index 569ad8c9..3369c522 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py +++ 
b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py @@ -101,7 +101,7 @@ async def ReduceFn( # If the message is an exception, we raise the exception if isinstance(msg, BaseException): err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" - _LOGGER.critical(err_msg, exc_info=True) + _LOGGER.critical(err_msg, exc_info=msg) update_context_err(context, msg, err_msg) self._error = msg if self._shutdown_event is not None: diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py index 2e436178..324c46f2 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py @@ -275,6 +275,18 @@ async def process_input_stream(self, request_iterator: AsyncIterable[reduce_pb2. _LOGGER.critical(err_msg, exc_info=True) await self.global_result_queue.put(e) + # Cancel and await remaining tasks to suppress "never retrieved" warnings + for task in self.get_tasks(): + for fut in (task.future, task.consumer_future): + if fut and not fut.done(): + fut.cancel() + for fut in (task.future, task.consumer_future): + if fut: + try: + await fut + except BaseException: + pass + async def write_to_global_queue( self, input_queue: NonBlockingIterator, output_queue: NonBlockingIterator, window ): @@ -284,10 +296,19 @@ async def write_to_global_queue( to the global result queue """ reader = input_queue.read_iterator() - async for msg in reader: - res = reduce_pb2.ReduceResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags) - out = reduce_pb2.ReduceResponse(result=res, window=window) - await output_queue.put(out) + try: + async for msg in reader: + res = reduce_pb2.ReduceResponse.Result( + keys=msg.keys, value=msg.value, tags=msg.tags + ) + out = reduce_pb2.ReduceResponse(result=res, window=window) + await output_queue.put(out) + except 
Exception as e: + # Using Exception (not BaseException) so that asyncio.CancelledError + # (a BaseException subclass in Python 3.8+) propagates normally + # when the task is cancelled during shutdown. + _LOGGER.critical("Error serializing reduce result: %s", e, exc_info=True) + await output_queue.put(e) def clean_background(self, task): """ diff --git a/packages/pynumaflow/pynumaflow/sinker/_dtypes.py b/packages/pynumaflow/pynumaflow/sinker/_dtypes.py index d807f591..656fed7f 100644 --- a/packages/pynumaflow/pynumaflow/sinker/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/sinker/_dtypes.py @@ -7,6 +7,7 @@ from warnings import warn from pynumaflow._metadata import SystemMetadata, UserMetadata +from pynumaflow._validate import _validate_message_fields R = TypeVar("R", bound="Response") Rs = TypeVar("Rs", bound="Responses") @@ -35,6 +36,7 @@ def __init__( keys: list[str] | None = None, user_metadata: UserMetadata | None = None, ): + _validate_message_fields(value, keys, None) self._value = value self._keys = keys self._user_metadata = user_metadata diff --git a/packages/pynumaflow/pynumaflow/sourcer/_dtypes.py b/packages/pynumaflow/pynumaflow/sourcer/_dtypes.py index 99b0d5a8..4f407483 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/sourcer/_dtypes.py @@ -6,6 +6,7 @@ from typing import TypeAlias from pynumaflow._metadata import UserMetadata +from pynumaflow._validate import _validate_message_fields from pynumaflow.shared.asynciter import NonBlockingIterator @@ -77,6 +78,7 @@ def __init__( """ Creates a Message object to send value to a vertex. 
""" + _validate_message_fields(payload, keys, None) self._payload = payload self._offset = offset self._event_time = event_time diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py b/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py index b53f61f8..66ffb195 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py @@ -8,6 +8,7 @@ from pynumaflow._constants import DROP from pynumaflow._metadata import UserMetadata, SystemMetadata +from pynumaflow._validate import _validate_message_fields M = TypeVar("M", bound="Message") Ms = TypeVar("Ms", bound="Messages") @@ -43,6 +44,7 @@ def __init__( """ Creates a Message object to send value to a vertex. """ + _validate_message_fields(value, keys, tags) self._tags = tags or [] self._keys = keys or [] diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py index fd7a49bd..6232b1fb 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py @@ -294,6 +294,69 @@ def test_max_threads(self): server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass) self.assertEqual(server.max_threads, 4) + def test_start_shutdown_handler_without_callback(self): + """Test that _shutdown_handler logs and works when no shutdown_callback is set.""" + from unittest.mock import patch, MagicMock + + server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass) + self.assertIsNone(server.shutdown_callback) + + def close_coro(coro, **kwargs): + coro.close() + + with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun: + mock_aiorun.run.side_effect = close_coro + server.start() + + # Extract the shutdown_callback passed to aiorun.run + call_kwargs = mock_aiorun.run.call_args[1] + shutdown_handler = call_kwargs["shutdown_callback"] + + # Invoke the 
handler — should not raise even without a callback + mock_loop = MagicMock() + shutdown_handler(mock_loop) + + def test_start_shutdown_handler_with_callback(self): + """Test that _shutdown_handler invokes the user-provided shutdown_callback.""" + from unittest.mock import patch, MagicMock + + user_callback = MagicMock() + server = ReduceStreamAsyncServer( + reduce_stream_instance=ExampleClass, shutdown_callback=user_callback + ) + + def close_coro(coro, **kwargs): + coro.close() + + with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun: + mock_aiorun.run.side_effect = close_coro + server.start() + + shutdown_handler = mock_aiorun.run.call_args[1]["shutdown_callback"] + mock_loop = MagicMock() + shutdown_handler(mock_loop) + + user_callback.assert_called_once_with(mock_loop) + + def test_start_exits_on_error(self): + """Test that start() calls sys.exit(1) when servicer reports an error.""" + from unittest.mock import patch + + server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass) + + def fake_aiorun_run(coro, **kwargs): + # Simulate aiorun completing after a UDF error was recorded + coro.close() # prevent "coroutine never awaited" warning + server._error = RuntimeError("UDF failure") + + with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun, patch( + "pynumaflow.reducestreamer.async_server.sys" + ) as mock_sys: + mock_aiorun.run.side_effect = fake_aiorun_run + server.start() + + mock_sys.exit.assert_called_once_with(1) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py index eaaea013..79d9f8cf 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py @@ -346,6 +346,122 @@ async def requests(): ctx.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL) 
+async def _blocking_handler(keys, datums, output, md): + """Handler that blocks forever reading datums (never finishes on its own).""" + async for _ in datums: + pass + await output.put(Message(b"done", keys=keys)) + + +def _make_reduce_request(operation_event): + """Create a ReduceRequest DTO (not raw protobuf) matching what datum_generator produces.""" + from pynumaflow.reducestreamer._dtypes import ReduceRequest as ReduceRequestDTO + + event_time_timestamp, watermark_timestamp = get_time_args() + window = reduce_pb2.Window( + start=mock_interval_window_start(), + end=mock_interval_window_end(), + slot="slot-0", + ) + payload = Datum( + keys=["test_key"], + value=mock_message(), + event_time=event_time_timestamp.ToDatetime(), + watermark=watermark_timestamp.ToDatetime(), + ) + return ReduceRequestDTO( + operation=operation_event, + windows=[window], + payload=payload, + ) + + +def test_cancel_and_await_remaining_tasks_on_post_processing_error(): + """ + When a BaseException occurs during post-processing (after the input stream + is exhausted), the TaskManager should cancel and await all remaining task + futures that are still running. + """ + from unittest.mock import patch + from pynumaflow.reducestreamer._dtypes import WindowOperation + + tm = TaskManager(_blocking_handler) + req = _make_reduce_request(int(WindowOperation.OPEN)) + + async def _run(): + async def requests(): + yield req + + # Patch stream_send_eof to raise after the task is created but before + # it completes, so the task futures are still running when the except + # block executes. + with patch.object(tm, "stream_send_eof", side_effect=RuntimeError("send_eof boom")): + await tm.process_input_stream(requests()) + + # Verify tasks were actually created + assert len(tm.get_tasks()) > 0, "tasks should have been created" + + # After process_input_stream returns, verify the error was placed in + # the global result queue. 
+ reader = tm.global_result_queue.read_iterator() + first_item = await reader.__anext__() + assert isinstance(first_item, RuntimeError) + assert "send_eof boom" in str(first_item) + + # Verify all task futures completed (cancelled or finished). + for task in tm.get_tasks(): + assert task.future.done(), "task.future should be done after cleanup" + assert task.consumer_future.done(), "task.consumer_future should be done after cleanup" + + asyncio.run(_run()) + + +def test_cancel_and_await_with_already_done_futures(): + """ + When post-processing fails but some futures are already done, + the cleanup code should skip cancellation for those (fut.done() is True). + """ + from unittest.mock import patch + from pynumaflow.reducestreamer._dtypes import WindowOperation + from pynumaflow._constants import STREAM_EOF + + async def _fast_handler(keys, datums, output, md): + """Handler that finishes immediately without reading datums.""" + await output.put(Message(b"fast", keys=keys)) + + tm = TaskManager(_fast_handler) + req = _make_reduce_request(int(WindowOperation.OPEN)) + + async def _run(): + async def requests(): + yield req + + original_send_eof = tm.stream_send_eof + + async def send_eof_then_wait_and_raise(): + # Let the real stream_send_eof run (sends EOF to handler input) + await original_send_eof() + # Wait for all task futures to complete so they are .done() + for task in tm.get_tasks(): + await task.future + await task.result_queue.put(STREAM_EOF) + await task.consumer_future + raise RuntimeError("late post-processing error") + + with patch.object(tm, "stream_send_eof", side_effect=send_eof_then_wait_and_raise): + await tm.process_input_stream(requests()) + + # Verify tasks were actually created + assert len(tm.get_tasks()) > 0, "tasks should have been created" + + # Verify cleanup completed without issues + for task in tm.get_tasks(): + assert task.future.done() + assert task.consumer_future.done() + + asyncio.run(_run()) + + if __name__ == "__main__": 
logging.basicConfig(level=logging.DEBUG) unittest.main() diff --git a/packages/pynumaflow/tests/test_validate.py b/packages/pynumaflow/tests/test_validate.py new file mode 100644 index 00000000..a82d550c --- /dev/null +++ b/packages/pynumaflow/tests/test_validate.py @@ -0,0 +1,31 @@ +import pytest + +from pynumaflow._validate import _validate_message_fields + + +class TestValidateMessageFields: + def test_invalid_value_type_raises(self): + with pytest.raises(TypeError, match="Message 'value' must be bytes, got str"): + _validate_message_fields("not bytes", None, None) + + def test_invalid_keys_type_raises(self): + with pytest.raises(TypeError, match="Message 'keys' must be a list of strings"): + _validate_message_fields(None, "not-a-list", None) + + def test_invalid_keys_element_type_raises(self): + with pytest.raises(TypeError, match="Message 'keys' must be a list of strings"): + _validate_message_fields(None, [1, 2], None) + + def test_invalid_tags_type_raises(self): + with pytest.raises(TypeError, match="Message 'tags' must be a list of strings"): + _validate_message_fields(None, None, "not-a-list") + + def test_invalid_tags_element_type_raises(self): + with pytest.raises(TypeError, match="Message 'tags' must be a list of strings"): + _validate_message_fields(None, None, [1, 2]) + + def test_valid_inputs_no_error(self): + _validate_message_fields(b"data", ["key1"], ["tag1"]) + + def test_all_none_no_error(self): + _validate_message_fields(None, None, None)