|
11 | 11 |
|
12 | 12 | from pynumaflow import setup_logging |
13 | 13 | from pynumaflow._constants import MAX_MESSAGE_SIZE |
| 14 | +from pynumaflow.proto.common import metadata_pb2 |
14 | 15 | from pynumaflow.proto.sourcetransformer import transform_pb2_grpc |
15 | 16 | from pynumaflow.sourcetransformer import Datum, Messages, Message, SourceTransformer |
16 | 17 | from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer |
@@ -267,6 +268,122 @@ def test_max_threads(self): |
267 | 268 | self.assertEqual(server.max_threads, 4) |
268 | 269 |
|
269 | 270 |
|
| 271 | +class MetadataAsyncSourceTransformer(SourceTransformer): |
| 272 | + """Source transformer that validates and passes through metadata.""" |
| 273 | + |
| 274 | + async def handler(self, keys: list[str], datum: Datum) -> Messages: |
| 275 | + # Validate system metadata |
| 276 | + if datum.system_metadata.value("numaflow_version_info", "version") != b"1.0.0": |
| 277 | + raise ValueError("System metadata version mismatch") |
| 278 | + |
| 279 | + val = datum.value |
| 280 | + msg = "payload:{} event_time:{} ".format( |
| 281 | + val.decode("utf-8"), |
| 282 | + datum.event_time, |
| 283 | + ) |
| 284 | + val = bytes(msg, encoding="utf-8") |
| 285 | + messages = Messages() |
| 286 | + # Pass user metadata to the output message |
| 287 | + messages.append( |
| 288 | + Message(val, mock_new_event_time(), keys=keys, user_metadata=datum.user_metadata) |
| 289 | + ) |
| 290 | + return messages |
| 291 | + |
| 292 | + |
| 293 | +_metadata_s: Server = None |
| 294 | +_metadata_channel = grpc.insecure_channel("unix:///tmp/async_st_metadata.sock") |
| 295 | +_metadata_loop = None |
| 296 | + |
| 297 | + |
| 298 | +def metadata_startup_callable(loop): |
| 299 | + asyncio.set_event_loop(loop) |
| 300 | + loop.run_forever() |
| 301 | + |
| 302 | + |
| 303 | +def new_metadata_async_st(): |
| 304 | + handle = MetadataAsyncSourceTransformer() |
| 305 | + server = SourceTransformAsyncServer(source_transform_instance=handle) |
| 306 | + return server.servicer |
| 307 | + |
| 308 | + |
| 309 | +async def start_metadata_server(udfs): |
| 310 | + _server_options = [ |
| 311 | + ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), |
| 312 | + ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), |
| 313 | + ] |
| 314 | + server = grpc.aio.server(options=_server_options) |
| 315 | + transform_pb2_grpc.add_SourceTransformServicer_to_server(udfs, server) |
| 316 | + listen_addr = "unix:///tmp/async_st_metadata.sock" |
| 317 | + server.add_insecure_port(listen_addr) |
| 318 | + logging.info("Starting metadata server on %s", listen_addr) |
| 319 | + global _metadata_s |
| 320 | + _metadata_s = server |
| 321 | + await server.start() |
| 322 | + await server.wait_for_termination() |
| 323 | + |
| 324 | + |
| 325 | +@patch("psutil.Process.kill", mock_terminate_on_stop) |
| 326 | +class TestAsyncTransformerMetadata(unittest.TestCase): |
| 327 | + @classmethod |
| 328 | + def setUpClass(cls) -> None: |
| 329 | + global _metadata_loop |
| 330 | + loop = asyncio.new_event_loop() |
| 331 | + _metadata_loop = loop |
| 332 | + _thread = threading.Thread(target=metadata_startup_callable, args=(loop,), daemon=True) |
| 333 | + _thread.start() |
| 334 | + udfs = new_metadata_async_st() |
| 335 | + asyncio.run_coroutine_threadsafe(start_metadata_server(udfs), loop=loop) |
| 336 | + while True: |
| 337 | + try: |
| 338 | + with grpc.insecure_channel("unix:///tmp/async_st_metadata.sock") as channel: |
| 339 | + f = grpc.channel_ready_future(channel) |
| 340 | + f.result(timeout=10) |
| 341 | + if f.done(): |
| 342 | + break |
| 343 | + except grpc.FutureTimeoutError as e: |
| 344 | + LOGGER.error("error trying to connect to grpc server") |
| 345 | + LOGGER.error(e) |
| 346 | + |
| 347 | + @classmethod |
| 348 | + def tearDownClass(cls) -> None: |
| 349 | + try: |
| 350 | + _metadata_loop.stop() |
| 351 | + LOGGER.info("stopped the metadata event loop") |
| 352 | + except Exception as e: |
| 353 | + LOGGER.error(e) |
| 354 | + |
| 355 | + def test_source_transformer_with_metadata(self) -> None: |
| 356 | + stub = transform_pb2_grpc.SourceTransformStub(_metadata_channel) |
| 357 | + request = get_test_datums(with_metadata=True) |
| 358 | + generator_response = None |
| 359 | + try: |
| 360 | + generator_response = stub.SourceTransformFn(request_iterator=request_generator(request)) |
| 361 | + except grpc.RpcError as e: |
| 362 | + logging.error(e) |
| 363 | + raise |
| 364 | + |
| 365 | + responses = [] |
| 366 | + for r in generator_response: |
| 367 | + responses.append(r) |
| 368 | + |
| 369 | + # 1 handshake + 3 data responses |
| 370 | + self.assertEqual(4, len(responses)) |
| 371 | + self.assertTrue(responses[0].handshake.sot) |
| 372 | + |
| 373 | + # Verify metadata is passed through correctly |
| 374 | + for idx, resp in enumerate(responses[1:], 1): |
| 375 | + _id = "test-id-" + str(idx) |
| 376 | + self.assertEqual(_id, resp.id) |
| 377 | + self.assertEqual(1, len(resp.results)) |
| 378 | + # Verify user metadata is returned |
| 379 | + self.assertEqual( |
| 380 | + resp.results[0].metadata.user_metadata["custom_info"], |
| 381 | + metadata_pb2.KeyValueGroup(key_value={"version": f"{idx}.0.0".encode()}), |
| 382 | + ) |
| 383 | + # System metadata should be empty in responses (user cannot set it) |
| 384 | + self.assertEqual(resp.results[0].metadata.sys_metadata, {}) |
| 385 | + |
| 386 | + |
270 | 387 | if __name__ == "__main__": |
271 | 388 | logging.basicConfig(level=logging.DEBUG) |
272 | 389 | unittest.main() |
0 commit comments