Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions packages/pynumaflow/pynumaflow/accumulator/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class WindowOperation(IntEnum):
Enumerate the type of Window operation received.
"""

OPEN = (0,)
CLOSE = (1,)
APPEND = (2,)
OPEN = 0
CLOSE = 1
APPEND = 2


@dataclass(init=False, slots=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_AccumulatorBuilderClass,
AccumulatorAsyncCallable,
WindowOperation,
AccumulatorRequest,
)
from pynumaflow.proto.accumulator import accumulator_pb2
from pynumaflow.shared.asynciter import NonBlockingIterator
Expand Down Expand Up @@ -93,7 +94,7 @@ async def stream_send_eof(self):
for unified_key in task_keys:
await self.tasks[unified_key].iterator.put(STREAM_EOF)

async def close_task(self, req):
async def close_task(self, req: AccumulatorRequest):
"""
Closes a running accumulator task for a given key.
Based on the request we compute the unique key, and then
Expand All @@ -104,8 +105,9 @@ async def close_task(self, req):
3. Wait for all the results from the task to be written to the global result queue
4. Remove the task from the tracker
"""
d = req.payload
keys = d.keys
# Use keyed_window.keys for task lookup since payload.keys may be empty
# (e.g., CLOSE operations don't carry data, so payload.keys is not populated).
keys = req.keyed_window.keys
unified_key = build_unique_key_name(keys)
curr_task = self.tasks.get(unified_key, None)

Expand All @@ -120,14 +122,16 @@ async def close_task(self, req):
# Put the exception in the result queue
await self.global_result_queue.put(err)

async def create_task(self, req):
async def create_task(self, req: AccumulatorRequest):
"""
Creates a new accumulator task for the given request.
Based on the request we compute a unique key, and then
it creates a new task or appends the request to the existing task.
"""
d = req.payload
keys = d.keys
# Use keyed_window.keys for task lookup — the authoritative key identity
# for the window, consistent across all operation types (OPEN, APPEND, CLOSE).
keys = req.keyed_window.keys
unified_key = build_unique_key_name(keys)
curr_task = self.tasks.get(unified_key, None)

Expand All @@ -138,7 +142,7 @@ async def create_task(self, req):
# Create a new result queue for the current task
# We create a new result queue for each task, so that
# the results of the accumulator operation can be sent to the
# the global result queue, which in turn sends the results
# global result queue, which in turn sends the results
# to the client.
res_queue = NonBlockingIterator()

Expand Down Expand Up @@ -172,13 +176,14 @@ async def create_task(self, req):
# Put the request in the iterator
await curr_task.iterator.put(d)

async def send_datum_to_task(self, req):
async def send_datum_to_task(self, req: AccumulatorRequest):
"""
Appends the request to the existing window reduce task.
If the task does not exist, create it.
"""
d = req.payload
keys = d.keys
# Use keyed_window.keys for task lookup to match the key used in create_task/close_task.
keys = req.keyed_window.keys
unified_key = build_unique_key_name(keys)
result = self.tasks.get(unified_key, None)
if not result:
Expand Down Expand Up @@ -215,9 +220,7 @@ async def __invoke_accumulator(
# Put the exception in the result queue
await self.global_result_queue.put(err)

async def process_input_stream(
self, request_iterator: AsyncIterable[accumulator_pb2.AccumulatorRequest]
):
async def process_input_stream(self, request_iterator: AsyncIterable[AccumulatorRequest]):
# Start iterating through the request iterator and create tasks
# based on the operation type received.
try:
Expand All @@ -226,15 +229,15 @@ async def process_input_stream(
request_count += 1
# check whether the request is an open, append, or close operation
match request.operation:
case int(WindowOperation.OPEN):
case WindowOperation.OPEN:
# create a new task for the open operation and
# put the request in the task iterator
await self.create_task(request)
case int(WindowOperation.APPEND):
case WindowOperation.APPEND:
# append the task data to the existing task
# if the task does not exist, create a new task
await self.send_datum_to_task(request)
case int(WindowOperation.CLOSE):
case WindowOperation.CLOSE:
# close the current task for req
await self.close_task(request)
case _:
Expand Down
25 changes: 20 additions & 5 deletions packages/pynumaflow/tests/accumulator/test_async_accumulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
def request_generator(count, request, resetkey: bool = False, send_close: bool = False):
for i in range(count):
if resetkey:
# Clear previous keys and add new ones
# Update keys on both payload and keyedWindow to match real platform behavior
del request.payload.keys[:]
request.payload.keys.extend([f"key-{i}"])
del request.operation.keyedWindow.keys[:]
request.operation.keyedWindow.keys.extend([f"key-{i}"])

# Set operation based on index - first is OPEN, rest are APPEND
if i == 0:
Expand All @@ -52,9 +54,11 @@ def request_generator(count, request, resetkey: bool = False, send_close: bool =
def request_generator_append_only(count, request, resetkey: bool = False):
for i in range(count):
if resetkey:
# Clear previous keys and add new ones
# Update keys on both payload and keyedWindow to match real platform behavior
del request.payload.keys[:]
request.payload.keys.extend([f"key-{i}"])
del request.operation.keyedWindow.keys[:]
request.operation.keyedWindow.keys.extend([f"key-{i}"])

# Set operation to APPEND for all requests
request.operation.event = accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND
Expand All @@ -64,9 +68,11 @@ def request_generator_append_only(count, request, resetkey: bool = False):
def request_generator_mixed(count, request, resetkey: bool = False):
for i in range(count):
if resetkey:
# Clear previous keys and add new ones
# Update keys on both payload and keyedWindow to match real platform behavior
del request.payload.keys[:]
request.payload.keys.extend([f"key-{i}"])
del request.operation.keyedWindow.keys[:]
request.operation.keyedWindow.keys.extend([f"key-{i}"])

if i % 2 == 0:
# Set operation to APPEND for even requests
Expand Down Expand Up @@ -107,17 +113,26 @@ def start_request() -> accumulator_pb2.AccumulatorRequest:

def start_request_without_open() -> accumulator_pb2.AccumulatorRequest:
    """Build a standalone APPEND AccumulatorRequest (no preceding OPEN).

    The request carries both a payload and a keyed window, each keyed on
    "test_key", mimicking what the platform would send mid-stream.
    """
    event_time, watermark = get_time_args()
    window_keys = ["test_key"]

    keyed_window = accumulator_pb2.KeyedWindow(
        start=mock_interval_window_start(),
        end=mock_interval_window_end(),
        slot="slot-0",
        keys=window_keys,
    )
    return accumulator_pb2.AccumulatorRequest(
        payload=accumulator_pb2.Payload(
            keys=window_keys,
            value=mock_message(),
            event_time=event_time,
            watermark=watermark,
            id="test_id",
        ),
        operation=accumulator_pb2.AccumulatorRequest.WindowOperation(
            event=accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND,
            keyedWindow=keyed_window,
        ),
    )

Expand Down