Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions packages/pynumaflow/pynumaflow/_validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
def _validate_message_fields(value, keys, tags):
"""Validate common Message fields at construction time.

Raises TypeError with a clear message pointing at the caller's code
rather than letting bad types propagate to protobuf serialization.
"""
if value is not None and not isinstance(value, bytes):
raise TypeError(f"Message 'value' must be bytes, got {type(value).__name__}")
if keys is not None:
if not isinstance(keys, list) or not all(isinstance(k, str) for k in keys):
raise TypeError(f"Message 'keys' must be a list of strings, got {keys!r}")
if tags is not None:
if not isinstance(tags, list) or not all(isinstance(t, str) for t in tags):
raise TypeError(f"Message 'tags' must be a list of strings, got {tags!r}")
2 changes: 2 additions & 0 deletions packages/pynumaflow/pynumaflow/accumulator/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pynumaflow.shared.asynciter import NonBlockingIterator
from pynumaflow._constants import DROP
from pynumaflow._validate import _validate_message_fields

M = TypeVar("M", bound="Message")

Expand Down Expand Up @@ -389,6 +390,7 @@ def __init__(
"""
Creates a Message object to send value to a vertex.
"""
_validate_message_fields(value, keys, tags)
self._keys = keys or []
self._tags = tags or []
self._value = value or b""
Expand Down
2 changes: 2 additions & 0 deletions packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import AsyncIterable, Callable

from pynumaflow._constants import DROP
from pynumaflow._validate import _validate_message_fields

M = TypeVar("M", bound="Message")
B = TypeVar("B", bound="BatchResponse")
Expand All @@ -31,6 +32,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str]
"""
Creates a Message object to send value to a vertex.
"""
_validate_message_fields(value, keys, tags)
self._keys = keys or []
self._tags = tags or []
self._value = value or b""
Expand Down
2 changes: 2 additions & 0 deletions packages/pynumaflow/pynumaflow/mapper/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pynumaflow._constants import DROP
from pynumaflow._metadata import UserMetadata, SystemMetadata
from pynumaflow._validate import _validate_message_fields

M = TypeVar("M", bound="Message")
Ms = TypeVar("Ms", bound="Messages")
Expand Down Expand Up @@ -40,6 +41,7 @@ def __init__(
"""
Creates a Message object to send value to a vertex.
"""
_validate_message_fields(value, keys, tags)
self._keys = keys or []
self._tags = tags or []
self._value = value or b""
Expand Down
2 changes: 2 additions & 0 deletions packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from warnings import warn

from pynumaflow._constants import DROP
from pynumaflow._validate import _validate_message_fields

M = TypeVar("M", bound="Message")
Ms = TypeVar("Ms", bound="Messages")
Expand All @@ -31,6 +32,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str]
"""
Creates a Message object to send value to a vertex.
"""
_validate_message_fields(value, keys, tags)
self._keys = keys or []
self._tags = tags or []
self._value = value or b""
Expand Down
2 changes: 2 additions & 0 deletions packages/pynumaflow/pynumaflow/reducer/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from pynumaflow.shared.asynciter import NonBlockingIterator
from pynumaflow._constants import DROP
from pynumaflow._validate import _validate_message_fields

M = TypeVar("M", bound="Message")
Ms = TypeVar("Ms", bound="Messages")
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str]
"""
Creates a Message object to send value to a vertex.
"""
_validate_message_fields(value, keys, tags)
self._keys = keys or []
self._tags = tags or []
self._value = value or b""
Expand Down
2 changes: 2 additions & 0 deletions packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pynumaflow.shared.asynciter import NonBlockingIterator
from pynumaflow._constants import DROP
from pynumaflow._validate import _validate_message_fields

M = TypeVar("M", bound="Message")

Expand Down Expand Up @@ -270,6 +271,7 @@ def __init__(
"""
Creates a Message object to send value to a vertex.
"""
_validate_message_fields(value, keys, tags)
self._keys = keys or []
self._tags = tags or []
self._value = value or b""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,13 @@ def start(self):
_LOGGER.info(
"Starting Async Reduce Stream Server",
)
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)

def _shutdown_handler(loop):
_LOGGER.info("Received graceful shutdown signal, shutting down ReduceStream server")
if self.shutdown_callback:
self.shutdown_callback(loop)

aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=_shutdown_handler)
if self._error:
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
sys.exit(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def ReduceFn(
# If the message is an exception, we raise the exception
if isinstance(msg, BaseException):
err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(msg)}"
_LOGGER.critical(err_msg, exc_info=True)
_LOGGER.critical(err_msg, exc_info=msg)
update_context_err(context, msg, err_msg)
self._error = msg
if self._shutdown_event is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,18 @@ async def process_input_stream(self, request_iterator: AsyncIterable[reduce_pb2.
_LOGGER.critical(err_msg, exc_info=True)
await self.global_result_queue.put(e)

# Cancel and await remaining tasks to suppress "never retrieved" warnings
for task in self.get_tasks():
for fut in (task.future, task.consumer_future):
if fut and not fut.done():
fut.cancel()
for fut in (task.future, task.consumer_future):
if fut:
try:
await fut
except (asyncio.CancelledError, BaseException):
pass

async def write_to_global_queue(
self, input_queue: NonBlockingIterator, output_queue: NonBlockingIterator, window
):
Expand All @@ -284,10 +296,19 @@ async def write_to_global_queue(
to the global result queue
"""
reader = input_queue.read_iterator()
async for msg in reader:
res = reduce_pb2.ReduceResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
out = reduce_pb2.ReduceResponse(result=res, window=window)
await output_queue.put(out)
try:
async for msg in reader:
res = reduce_pb2.ReduceResponse.Result(
keys=msg.keys, value=msg.value, tags=msg.tags
)
out = reduce_pb2.ReduceResponse(result=res, window=window)
await output_queue.put(out)
except Exception as e:
# Using Exception (not BaseException) so that asyncio.CancelledError
# (a BaseException subclass in Python 3.9+) propagates normally
# when the task is cancelled during shutdown.
_LOGGER.critical("Error serializing reduce result: %s", e, exc_info=True)
await output_queue.put(e)

def clean_background(self, task):
"""
Expand Down
2 changes: 2 additions & 0 deletions packages/pynumaflow/pynumaflow/sinker/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from warnings import warn

from pynumaflow._metadata import SystemMetadata, UserMetadata
from pynumaflow._validate import _validate_message_fields

R = TypeVar("R", bound="Response")
Rs = TypeVar("Rs", bound="Responses")
Expand Down Expand Up @@ -35,6 +36,7 @@ def __init__(
keys: list[str] | None = None,
user_metadata: UserMetadata | None = None,
):
_validate_message_fields(value, keys, None)
self._value = value
self._keys = keys
self._user_metadata = user_metadata
Expand Down
2 changes: 2 additions & 0 deletions packages/pynumaflow/pynumaflow/sourcer/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TypeAlias

from pynumaflow._metadata import UserMetadata
from pynumaflow._validate import _validate_message_fields
from pynumaflow.shared.asynciter import NonBlockingIterator


Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
"""
Creates a Message object to send value to a vertex.
"""
_validate_message_fields(payload, keys, None)
self._payload = payload
self._offset = offset
self._event_time = event_time
Expand Down
2 changes: 2 additions & 0 deletions packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pynumaflow._constants import DROP
from pynumaflow._metadata import UserMetadata, SystemMetadata
from pynumaflow._validate import _validate_message_fields

M = TypeVar("M", bound="Message")
Ms = TypeVar("Ms", bound="Messages")
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(
"""
Creates a Message object to send value to a vertex.
"""
_validate_message_fields(value, keys, tags)
self._tags = tags or []
self._keys = keys or []

Expand Down
63 changes: 63 additions & 0 deletions packages/pynumaflow/tests/reducestreamer/test_async_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,69 @@ def test_max_threads(self):
server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass)
self.assertEqual(server.max_threads, 4)

def test_start_shutdown_handler_without_callback(self):
"""Test that _shutdown_handler logs and works when no shutdown_callback is set."""
from unittest.mock import patch, MagicMock

server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass)
self.assertIsNone(server.shutdown_callback)

def close_coro(coro, **kwargs):
coro.close()

with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun:
mock_aiorun.run.side_effect = close_coro
server.start()

# Extract the shutdown_callback passed to aiorun.run
call_kwargs = mock_aiorun.run.call_args[1]
shutdown_handler = call_kwargs["shutdown_callback"]

# Invoke the handler — should not raise even without a callback
mock_loop = MagicMock()
shutdown_handler(mock_loop)

def test_start_shutdown_handler_with_callback(self):
"""Test that _shutdown_handler invokes the user-provided shutdown_callback."""
from unittest.mock import patch, MagicMock

user_callback = MagicMock()
server = ReduceStreamAsyncServer(
reduce_stream_instance=ExampleClass, shutdown_callback=user_callback
)

def close_coro(coro, **kwargs):
coro.close()

with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun:
mock_aiorun.run.side_effect = close_coro
server.start()

shutdown_handler = mock_aiorun.run.call_args[1]["shutdown_callback"]
mock_loop = MagicMock()
shutdown_handler(mock_loop)

user_callback.assert_called_once_with(mock_loop)

def test_start_exits_on_error(self):
"""Test that start() calls sys.exit(1) when servicer reports an error."""
from unittest.mock import patch

server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass)

def fake_aiorun_run(coro, **kwargs):
# Simulate aiorun completing after a UDF error was recorded
coro.close() # prevent "coroutine never awaited" warning
server._error = RuntimeError("UDF failure")

with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun, patch(
"pynumaflow.reducestreamer.async_server.sys"
) as mock_sys:
mock_aiorun.run.side_effect = fake_aiorun_run
server.start()

mock_sys.exit.assert_called_once_with(1)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
Expand Down
Loading
Loading