Skip to content

Commit 14e369a

Browse files
cosmic-chichu and srao12 authored
feat: add metadata support for source transformer (#312)
Signed-off-by: srao12 <Shrivardhan_Rao@intuit.com> Co-authored-by: srao12 <Shrivardhan_Rao@intuit.com>
1 parent 27ad810 commit 14e369a

8 files changed

Lines changed: 434 additions & 35 deletions

File tree

packages/pynumaflow/pynumaflow/sourcetransformer/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer
99
from pynumaflow.sourcetransformer.server import SourceTransformServer
1010
from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer
11+
from pynumaflow._metadata import UserMetadata, SystemMetadata
1112

1213
__all__ = [
1314
"Message",
@@ -18,4 +19,6 @@
1819
"SourceTransformer",
1920
"SourceTransformMultiProcServer",
2021
"SourceTransformAsyncServer",
22+
"UserMetadata",
23+
"SystemMetadata",
2124
]

packages/pynumaflow/pynumaflow/sourcetransformer/_dtypes.py

Lines changed: 42 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from warnings import warn
88

99
from pynumaflow._constants import DROP
10+
from pynumaflow._metadata import UserMetadata, SystemMetadata
1011

1112
M = TypeVar("M", bound="Message")
1213
Ms = TypeVar("Ms", bound="Messages")
@@ -22,17 +23,24 @@ class Message:
2223
event_time: event time of the message, usually extracted from the payload.
2324
keys: []string keys for vertex (optional)
2425
tags: []string tags for conditional forwarding (optional)
26+
user_metadata: metadata for the message (optional)
2527
"""
2628

27-
__slots__ = ("_value", "_keys", "_tags", "_event_time")
29+
__slots__ = ("_value", "_keys", "_tags", "_event_time", "_user_metadata")
2830

2931
_keys: list[str]
3032
_tags: list[str]
3133
_value: bytes
3234
_event_time: datetime
35+
_user_metadata: UserMetadata
3336

3437
def __init__(
35-
self, value: bytes, event_time: datetime, keys: list[str] = None, tags: list[str] = None
38+
self,
39+
value: bytes,
40+
event_time: datetime,
41+
keys: list[str] = None,
42+
tags: list[str] = None,
43+
user_metadata: Optional[UserMetadata] = None,
3644
):
3745
"""
3846
Creates a Message object to send value to a vertex.
@@ -43,6 +51,7 @@ def __init__(
4351
# There is no year 0, so setting following as default event time.
4452
self._event_time = event_time or datetime(1, 1, 1, 0, 0)
4553
self._value = value or b""
54+
self._user_metadata = user_metadata or UserMetadata()
4655

4756
@classmethod
4857
def to_drop(cls: type[M], event_time: datetime) -> M:
@@ -64,6 +73,10 @@ def value(self) -> bytes:
6473
def tags(self) -> list[str]:
6574
return self._tags
6675

76+
@property
77+
def user_metadata(self) -> UserMetadata:
78+
return self._user_metadata
79+
6780

6881
class Messages(Sequence[M]):
6982
"""
@@ -119,6 +132,8 @@ class Datum:
119132
event_time: the event time of the event.
120133
watermark: the watermark of the event.
121134
headers: the headers of the event.
135+
user_metadata: the user metadata of the event.
136+
system_metadata: the system metadata of the event.
122137
123138
Example:
124139
```py
@@ -135,13 +150,23 @@ class Datum:
135150
```
136151
"""
137152

138-
__slots__ = ("_keys", "_value", "_event_time", "_watermark", "_headers")
153+
__slots__ = (
154+
"_keys",
155+
"_value",
156+
"_event_time",
157+
"_watermark",
158+
"_headers",
159+
"_user_metadata",
160+
"_system_metadata",
161+
)
139162

140163
_keys: list[str]
141164
_value: bytes
142165
_event_time: datetime
143166
_watermark: datetime
144167
_headers: dict[str, str]
168+
_user_metadata: UserMetadata
169+
_system_metadata: SystemMetadata
145170

146171
def __init__(
147172
self,
@@ -150,6 +175,8 @@ def __init__(
150175
event_time: datetime,
151176
watermark: datetime,
152177
headers: Optional[dict[str, str]] = None,
178+
user_metadata: Optional[UserMetadata] = None,
179+
system_metadata: Optional[SystemMetadata] = None,
153180
):
154181
self._keys = keys or list()
155182
self._value = value or b""
@@ -160,6 +187,8 @@ def __init__(
160187
raise TypeError(f"Wrong data type: {type(watermark)} for Datum.watermark")
161188
self._watermark = watermark
162189
self._headers = headers or {}
190+
self._user_metadata = user_metadata or UserMetadata()
191+
self._system_metadata = system_metadata or SystemMetadata()
163192

164193
@property
165194
def keys(self) -> list[str]:
@@ -186,6 +215,16 @@ def headers(self) -> dict[str, str]:
186215
"""Returns the headers of the event."""
187216
return self._headers.copy()
188217

218+
@property
219+
def user_metadata(self) -> UserMetadata:
220+
"""Returns the user metadata of the event."""
221+
return self._user_metadata
222+
223+
@property
224+
def system_metadata(self) -> SystemMetadata:
225+
"""Returns the system metadata of the event."""
226+
return self._system_metadata
227+
189228

190229
class SourceTransformer(metaclass=ABCMeta):
191230
"""

packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from google.protobuf import timestamp_pb2 as _timestamp_pb2
66

77
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
8+
from pynumaflow._metadata import _user_and_system_metadata_from_proto
89
from pynumaflow.proto.sourcetransformer import transform_pb2, transform_pb2_grpc
910
from pynumaflow.shared.asynciter import NonBlockingIterator
1011
from pynumaflow.shared.server import handle_async_error
@@ -105,12 +106,17 @@ async def _invoke_transform(
105106
Invokes the user defined function.
106107
"""
107108
try:
109+
user_metadata, system_metadata = _user_and_system_metadata_from_proto(
110+
request.request.metadata
111+
)
108112
datum = Datum(
109113
keys=list(request.request.keys),
110114
value=request.request.value,
111115
event_time=request.request.event_time.ToDatetime(),
112116
watermark=request.request.watermark.ToDatetime(),
113117
headers=dict(request.request.headers),
118+
user_metadata=user_metadata,
119+
system_metadata=system_metadata,
114120
)
115121
msgs = await self.__transform_handler(list(request.request.keys), datum)
116122
results = []
@@ -123,6 +129,7 @@ async def _invoke_transform(
123129
value=msg.value,
124130
tags=msg.tags,
125131
event_time=event_time_timestamp,
132+
metadata=msg.user_metadata._to_proto(),
126133
)
127134
)
128135
await result_queue.put(

packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from pynumaflow.shared.synciter import SyncIterator
1010
from pynumaflow.sourcetransformer import Datum
1111
from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable
12+
from pynumaflow._metadata import _user_and_system_metadata_from_proto
1213
from pynumaflow.proto.sourcetransformer import transform_pb2
1314
from pynumaflow.proto.sourcetransformer import transform_pb2_grpc
1415
from pynumaflow.types import NumaflowServicerContext
@@ -119,12 +120,17 @@ def _invoke_transformer(
119120
self, context, request: transform_pb2.SourceTransformRequest, result_queue: SyncIterator
120121
):
121122
try:
123+
user_metadata, system_metadata = _user_and_system_metadata_from_proto(
124+
request.request.metadata
125+
)
122126
d = Datum(
123127
keys=list(request.request.keys),
124128
value=request.request.value,
125129
event_time=request.request.event_time.ToDatetime(),
126130
watermark=request.request.watermark.ToDatetime(),
127131
headers=dict(request.request.headers),
132+
user_metadata=user_metadata,
133+
system_metadata=system_metadata,
128134
)
129135
responses = self.__transform_handler(list(request.request.keys), d)
130136

@@ -138,6 +144,7 @@ def _invoke_transformer(
138144
value=resp.value,
139145
tags=resp.tags,
140146
event_time=event_time_timestamp,
147+
metadata=resp.user_metadata._to_proto(),
141148
)
142149
)
143150
result_queue.put(

packages/pynumaflow/tests/sourcetransform/test_async.py

Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111

1212
from pynumaflow import setup_logging
1313
from pynumaflow._constants import MAX_MESSAGE_SIZE
14+
from pynumaflow.proto.common import metadata_pb2
1415
from pynumaflow.proto.sourcetransformer import transform_pb2_grpc
1516
from pynumaflow.sourcetransformer import Datum, Messages, Message, SourceTransformer
1617
from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer
@@ -267,6 +268,122 @@ def test_max_threads(self):
267268
self.assertEqual(server.max_threads, 4)
268269

269270

271+
class MetadataAsyncSourceTransformer(SourceTransformer):
    """Async source transformer used to exercise metadata handling.

    The handler raises unless the datum's system metadata carries the
    expected version value, and otherwise emits a single message that
    carries the datum's user metadata unchanged.
    """

    async def handler(self, keys: list[str], datum: Datum) -> Messages:
        # Reject the datum when the system metadata version is not b"1.0.0".
        version = datum.system_metadata.value("numaflow_version_info", "version")
        if version != b"1.0.0":
            raise ValueError("System metadata version mismatch")

        template = "payload:{} event_time:{} "
        text = template.format(datum.value.decode("utf-8"), datum.event_time)
        payload = text.encode("utf-8")

        result = Messages()
        # Hand the incoming user metadata to the outgoing message as-is.
        result.append(
            Message(
                payload,
                mock_new_event_time(),
                keys=keys,
                user_metadata=datum.user_metadata,
            )
        )
        return result
291+
292+
293+
# Event loop that hosts the metadata test server; assigned in setUpClass.
_metadata_loop = None
# Handle to the grpc.aio server; assigned by start_metadata_server.
_metadata_s: Server = None
# Client channel bound to the same unix socket the metadata server listens on.
_metadata_channel = grpc.insecure_channel("unix:///tmp/async_st_metadata.sock")
296+
297+
298+
def metadata_startup_callable(loop):
    """Thread target: make *loop* the current event loop and run it until stopped."""
    asyncio.set_event_loop(loop)
    loop.run_forever()
301+
302+
303+
def new_metadata_async_st():
    """Return the servicer of an async server wrapping a new ``MetadataAsyncSourceTransformer``."""
    return SourceTransformAsyncServer(
        source_transform_instance=MetadataAsyncSourceTransformer()
    ).servicer
307+
308+
309+
async def start_metadata_server(udfs):
    """Serve *udfs* over gRPC on the metadata test socket until termination.

    The created server is also stored in the module-level ``_metadata_s``.
    """
    global _metadata_s

    listen_addr = "unix:///tmp/async_st_metadata.sock"
    size_limits = [
        ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
        ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
    ]

    grpc_server = grpc.aio.server(options=size_limits)
    transform_pb2_grpc.add_SourceTransformServicer_to_server(udfs, grpc_server)
    grpc_server.add_insecure_port(listen_addr)
    logging.info("Starting metadata server on %s", listen_addr)

    _metadata_s = grpc_server
    await grpc_server.start()
    await grpc_server.wait_for_termination()
323+
324+
325+
@patch("psutil.Process.kill", mock_terminate_on_stop)
class TestAsyncTransformerMetadata(unittest.TestCase):
    """End-to-end check that metadata flows through the async source transformer.

    A gRPC server backed by ``MetadataAsyncSourceTransformer`` is started on a
    dedicated event loop running in a daemon thread for the lifetime of the class.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Start the event-loop thread and the server, then wait until it is reachable."""
        global _metadata_loop
        loop = asyncio.new_event_loop()
        _metadata_loop = loop
        _thread = threading.Thread(target=metadata_startup_callable, args=(loop,), daemon=True)
        _thread.start()
        udfs = new_metadata_async_st()
        asyncio.run_coroutine_threadsafe(start_metadata_server(udfs), loop=loop)
        # Keep probing until the channel reports ready; each probe waits up to 10s.
        while True:
            try:
                with grpc.insecure_channel("unix:///tmp/async_st_metadata.sock") as channel:
                    f = grpc.channel_ready_future(channel)
                    f.result(timeout=10)
                    if f.done():
                        break
            except grpc.FutureTimeoutError as e:
                LOGGER.error("error trying to connect to grpc server")
                LOGGER.error(e)

    @classmethod
    def tearDownClass(cls) -> None:
        """Ask the background event loop to stop; failures are logged, not raised."""
        try:
            # The loop runs in another thread and loop.stop() is not thread-safe,
            # so the stop is scheduled onto the loop instead of called directly.
            _metadata_loop.call_soon_threadsafe(_metadata_loop.stop)
            LOGGER.info("stopped the metadata event loop")
        except Exception as e:
            LOGGER.error(e)

    def test_source_transformer_with_metadata(self) -> None:
        """User metadata sent in each request comes back on the matching result."""
        stub = transform_pb2_grpc.SourceTransformStub(_metadata_channel)
        request = get_test_datums(with_metadata=True)
        try:
            generator_response = stub.SourceTransformFn(request_iterator=request_generator(request))
        except grpc.RpcError as e:
            logging.error(e)
            raise

        responses = list(generator_response)

        # 1 handshake + 3 data responses
        self.assertEqual(4, len(responses))
        self.assertTrue(responses[0].handshake.sot)

        # Verify metadata is passed through correctly
        for idx, resp in enumerate(responses[1:], 1):
            _id = "test-id-" + str(idx)
            self.assertEqual(_id, resp.id)
            self.assertEqual(1, len(resp.results))
            # Verify user metadata is returned
            self.assertEqual(
                resp.results[0].metadata.user_metadata["custom_info"],
                metadata_pb2.KeyValueGroup(key_value={"version": f"{idx}.0.0".encode()}),
            )
            # System metadata should be empty in responses (user cannot set it)
            self.assertEqual(resp.results[0].metadata.sys_metadata, {})
385+
386+
270387
if __name__ == "__main__":
271388
logging.basicConfig(level=logging.DEBUG)
272389
unittest.main()

0 commit comments

Comments
 (0)