diff --git a/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py b/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py index 658c76a4..62f388c7 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py +++ b/packages/pynumaflow/pynumaflow/accumulator/_dtypes.py @@ -17,9 +17,9 @@ class WindowOperation(IntEnum): Enumerate the type of Window operation received. """ - OPEN = (0,) - CLOSE = (1,) - APPEND = (2,) + OPEN = 0 + CLOSE = 1 + APPEND = 2 @dataclass(init=False, slots=True) diff --git a/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py index 77c3b294..a9758bf7 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py +++ b/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py @@ -14,6 +14,7 @@ _AccumulatorBuilderClass, AccumulatorAsyncCallable, WindowOperation, + AccumulatorRequest, ) from pynumaflow.proto.accumulator import accumulator_pb2 from pynumaflow.shared.asynciter import NonBlockingIterator @@ -93,7 +94,7 @@ async def stream_send_eof(self): for unified_key in task_keys: await self.tasks[unified_key].iterator.put(STREAM_EOF) - async def close_task(self, req): + async def close_task(self, req: AccumulatorRequest): """ Closes a running accumulator task for a given key. Based on the request we compute the unique key, and then @@ -104,8 +105,9 @@ async def close_task(self, req): 3. Wait for all the results from the task to be written to the global result queue 4. Remove the task from the tracker """ - d = req.payload - keys = d.keys + # Use keyed_window.keys for task lookup since payload.keys may be empty + # (e.g., CLOSE operations don't carry data, so payload.keys is not populated). 
+ keys = req.keyed_window.keys unified_key = build_unique_key_name(keys) curr_task = self.tasks.get(unified_key, None) @@ -120,14 +122,16 @@ async def close_task(self, req): # Put the exception in the result queue await self.global_result_queue.put(err) - async def create_task(self, req): + async def create_task(self, req: AccumulatorRequest): """ Creates a new accumulator task for the given request. Based on the request we compute a unique key, and then it creates a new task or appends the request to the existing task. """ d = req.payload - keys = d.keys + # Use keyed_window.keys for task lookup — the authoritative key identity + # for the window, consistent across all operation types (OPEN, APPEND, CLOSE). + keys = req.keyed_window.keys unified_key = build_unique_key_name(keys) curr_task = self.tasks.get(unified_key, None) @@ -138,7 +142,7 @@ async def create_task(self, req): # Create a new result queue for the current task # We create a new result queue for each task, so that # the results of the accumulator operation can be sent to the - # the global result queue, which in turn sends the results + # global result queue, which in turn sends the results # to the client. res_queue = NonBlockingIterator() @@ -172,13 +176,14 @@ async def create_task(self, req): # Put the request in the iterator await curr_task.iterator.put(d) - async def send_datum_to_task(self, req): + async def send_datum_to_task(self, req: AccumulatorRequest): """ Appends the request to the existing window reduce task. If the task does not exist, create it. """ d = req.payload - keys = d.keys + # Use keyed_window.keys for task lookup to match the key used in create_task/close_task. 
+ keys = req.keyed_window.keys unified_key = build_unique_key_name(keys) result = self.tasks.get(unified_key, None) if not result: @@ -215,9 +220,7 @@ async def __invoke_accumulator( # Put the exception in the result queue await self.global_result_queue.put(err) - async def process_input_stream( - self, request_iterator: AsyncIterable[accumulator_pb2.AccumulatorRequest] - ): + async def process_input_stream(self, request_iterator: AsyncIterable[AccumulatorRequest]): # Start iterating through the request iterator and create tasks # based on the operation type received. try: @@ -226,15 +229,15 @@ async def process_input_stream( request_count += 1 # check whether the request is an open, append, or close operation match request.operation: - case int(WindowOperation.OPEN): + case WindowOperation.OPEN: # create a new task for the open operation and # put the request in the task iterator await self.create_task(request) - case int(WindowOperation.APPEND): + case WindowOperation.APPEND: # append the task data to the existing task # if the task does not exist, create a new task await self.send_datum_to_task(request) - case int(WindowOperation.CLOSE): + case WindowOperation.CLOSE: # close the current task for req await self.close_task(request) case _: diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py index e0927f8e..fa8f0d60 100644 --- a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py @@ -30,9 +30,11 @@ def request_generator(count, request, resetkey: bool = False, send_close: bool = False): for i in range(count): if resetkey: - # Clear previous keys and add new ones + # Update keys on both payload and keyedWindow to match real platform behavior del request.payload.keys[:] request.payload.keys.extend([f"key-{i}"]) + del request.operation.keyedWindow.keys[:] + request.operation.keyedWindow.keys.extend([f"key-{i}"]) # Set operation based on index - first is OPEN, rest are APPEND if i == 0: @@ -52,9 +54,11 @@ def request_generator_append_only(count, request, resetkey: bool = False): for i in range(count): if resetkey: - # Clear previous keys and add new ones + # Update keys on both payload and keyedWindow to match real platform behavior del request.payload.keys[:] request.payload.keys.extend([f"key-{i}"]) + del request.operation.keyedWindow.keys[:] + request.operation.keyedWindow.keys.extend([f"key-{i}"]) # Set operation to APPEND for all requests request.operation.event = accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND @@ -64,9 +68,11 @@ def request_generator_mixed(count, request, resetkey: bool = False): for i in range(count): if resetkey: - # Clear previous keys and add new ones + # Update keys on both payload and keyedWindow to match real platform behavior del request.payload.keys[:] request.payload.keys.extend([f"key-{i}"]) + del request.operation.keyedWindow.keys[:] + request.operation.keyedWindow.keys.extend([f"key-{i}"]) if i % 2 == 0: # Set operation to APPEND for even requests @@ -107,7 +113,12 @@ def start_request() -> accumulator_pb2.AccumulatorRequest: def start_request_without_open() -> accumulator_pb2.AccumulatorRequest: event_time_timestamp, watermark_timestamp = get_time_args() - + window = accumulator_pb2.KeyedWindow( start=mock_interval_window_start(), end=mock_interval_window_end(), slot="slot-0", keys=["test_key"], ) payload = accumulator_pb2.Payload( keys=["test_key"], value=mock_message(), event_time=event_time_timestamp, watermark=watermark_timestamp, id="test_id", ) - + operation = accumulator_pb2.AccumulatorRequest.WindowOperation( + event=accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND, + keyedWindow=window, + ) request = accumulator_pb2.AccumulatorRequest( payload=payload, + operation=operation, ) return request