From 0fd79a41eab5eddf8f6e7a156c11877178b4f95e Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 12 Feb 2026 09:27:50 +0800 Subject: [PATCH 1/8] add retry mech for kv_batch_get to make sure all fields are fetched Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 51 +++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 8e4f8c3..9787508 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import importlib.resources as pkg_resources import logging import math @@ -40,6 +41,9 @@ _TRANSFER_QUEUE_CLIENT: Any = None _TRANSFER_QUEUE_STORAGE: Any = None +TQ_KV_POLLING_METADATA_TIMEOUT = int(os.environ.get("TQ_KV_POLLING_METADATA_TIMEOUT", 10)) +TQ_KV_POLLING_METADATA_CHECK_INTERVAL = float(os.environ.get("TQ_KV_POLLING_METADATA_CHECK_INTERVAL", 0.5)) + def _maybe_create_transferqueue_client( conf: Optional[DictConfig] = None, @@ -386,6 +390,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list Raises: RuntimeError: If keys or partition are not found RuntimeError: If empty fields exist in any key (sample) + RuntimeError: If any user-specified fields are not retrived after TQ_KV_POLLING_METADATA_TIMEOUT Example: >>> import transfer_queue as tq @@ -409,6 +414,29 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list if fields is not None: if isinstance(fields, str): fields = [fields] + + target_fields = set(fields) + current_fields = set(batch_meta.field_names) + + not_ready_fields = target_fields - current_fields + begin_polling_time = time.time() + while not_ready_fields: + if time.time() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: + raise RuntimeError( + f"Timeout for kv_batch_get. 
Missing fields: {not_ready_fields}" + f" after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds." + ) + + logger.warning( + f"Fields {list(not_ready_fields)} are not ready yet! " + f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." + ) + + time.sleep(TQ_KV_POLLING_METADATA_CHECK_INTERVAL) + batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + current_fields = set(batch_meta.field_names) + not_ready_fields = target_fields - current_fields + batch_meta = batch_meta.select_fields(fields) if not batch_meta.is_ready: @@ -647,6 +675,7 @@ async def async_kv_batch_get( Raises: RuntimeError: If keys or partition are not found RuntimeError: If empty fields exist in any key (sample) + RuntimeError: If any user-specified fields are not retrived after TQ_KV_POLLING_METADATA_TIMEOUT Example: >>> import transfer_queue as tq @@ -670,6 +699,28 @@ async def async_kv_batch_get( if fields is not None: if isinstance(fields, str): fields = [fields] + target_fields = set(fields) + current_fields = set(batch_meta.field_names) + + not_ready_fields = target_fields - current_fields + begin_polling_time = time.time() + while not_ready_fields: + if time.time() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: + raise RuntimeError( + f"Timeout for async_kv_batch_get. Missing fields: {not_ready_fields}" + f" after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds." + ) + + logger.warning( + f"Fields {list(not_ready_fields)} are not ready yet! " + f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." 
+ ) + + await asyncio.sleep(TQ_KV_POLLING_METADATA_CHECK_INTERVAL) + batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + current_fields = set(batch_meta.field_names) + not_ready_fields = target_fields - current_fields + batch_meta = batch_meta.select_fields(fields) if not batch_meta.is_ready: From 7ffde7d1a4d2838591fab75c80b2d38c8a7e452f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 12 Feb 2026 10:18:03 +0800 Subject: [PATCH 2/8] fix typo Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 9787508..73558bb 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -390,7 +390,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list Raises: RuntimeError: If keys or partition are not found RuntimeError: If empty fields exist in any key (sample) - RuntimeError: If any user-specified fields are not retrived after TQ_KV_POLLING_METADATA_TIMEOUT + RuntimeError: If any user-specified fields are not retrieved after TQ_KV_POLLING_METADATA_TIMEOUT Example: >>> import transfer_queue as tq @@ -675,7 +675,7 @@ async def async_kv_batch_get( Raises: RuntimeError: If keys or partition are not found RuntimeError: If empty fields exist in any key (sample) - RuntimeError: If any user-specified fields are not retrived after TQ_KV_POLLING_METADATA_TIMEOUT + RuntimeError: If any user-specified fields are not retrieved after TQ_KV_POLLING_METADATA_TIMEOUT Example: >>> import transfer_queue as tq From 6761b862439ce27d3fa031ca886e6110762af2ce Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 12 Feb 2026 10:50:07 +0800 Subject: [PATCH 3/8] add UT Signed-off-by: 0oshowero0 --- tests/test_kv_interface.py | 509 +++++++++++++++++++++++++++++++++++++ 1 file changed, 509 insertions(+) create mode 100644 tests/test_kv_interface.py diff --git 
a/tests/test_kv_interface.py b/tests/test_kv_interface.py new file mode 100644 index 0000000..5e12e9e --- /dev/null +++ b/tests/test_kv_interface.py @@ -0,0 +1,509 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the high-level KV interface in transfer_queue.interface. + +This module tests the kv_batch_get and async_kv_batch_get functions, specifically +the polling and timeout behavior when fields are not immediately available. 
+""" + +import sys +import threading +import time +from pathlib import Path +from threading import Thread +from unittest.mock import patch + +import pytest +import torch +import zmq +from tensordict import TensorDict + +# Add parent directory to path +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + +from transfer_queue import TransferQueueClient # noqa: E402 +from transfer_queue.metadata import ( # noqa: E402 + BatchMeta, + FieldMeta, + SampleMeta, +) +from transfer_queue.utils.enum_utils import ProductionStatus, TransferQueueRole # noqa: E402 +from transfer_queue.utils.zmq_utils import ( # noqa: E402 + ZMQMessage, + ZMQRequestType, + ZMQServerInfo, +) + +# ============================================================================= +# Mock Controllers for Testing Polling/Timeout Behavior +# ============================================================================= + + +class MockControllerWithFieldDelay: + """Mock controller that simulates fields becoming available over time. + + This mock is used to test the polling behavior of kv_batch_get when + fields are not immediately available (simulating async writes). 
+ """ + + def __init__(self, controller_id="controller_delay"): + self.controller_id = controller_id + self.context = zmq.Context() + + # Socket for data requests + self.request_socket = self.context.socket(zmq.ROUTER) + self.request_port = self._bind_to_random_port(self.request_socket) + + self.zmq_server_info = ZMQServerInfo( + role=TransferQueueRole.CONTROLLER, + id=controller_id, + ip="127.0.0.1", + ports={ + "request_handle_socket": self.request_port, + }, + ) + + self.running = True + self.request_thread = Thread(target=self._handle_requests, daemon=True) + self.request_thread.start() + + # Track call counts to simulate delayed field availability + self.kv_retrieve_call_count = {} + self._lock = threading.Lock() + + def _bind_to_random_port(self, socket): + port = socket.bind_to_random_port("tcp://127.0.0.1") + return port + + def _handle_requests(self): + poller = zmq.Poller() + poller.register(self.request_socket, zmq.POLLIN) + + while self.running: + try: + socks = dict(poller.poll(100)) + if self.request_socket in socks: + messages = self.request_socket.recv_multipart() + identity = messages.pop(0) + serialized_msg = messages + request_msg = ZMQMessage.deserialize(serialized_msg) + + if request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: + response_body = self._mock_kv_retrieve_keys_delayed(request_msg.body) + response_type = ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE + else: + response_body = {"error": f"Unknown request type: {request_msg.request_type}"} + response_type = ZMQRequestType.CLEAR_META_RESPONSE + + response_msg = ZMQMessage.create( + request_type=response_type, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body=response_body, + ) + self.request_socket.send_multipart([identity, *response_msg.serialize()]) + except zmq.Again: + continue + except Exception as e: + if self.running: + print(f"MockControllerWithFieldDelay exception: {e}") + else: + print(f"MockControllerWithFieldDelay ERROR: {e}") + raise + + def 
_mock_kv_retrieve_keys_delayed(self, request_body): + """Mock KV retrieve that simulates fields becoming available over time.""" + keys = request_body.get("keys", []) + partition_id = request_body.get("partition_id", "") + + # Create a unique key for tracking call count + call_key = f"{partition_id}:{','.join(keys) if isinstance(keys, list) else keys}" + + with self._lock: + if call_key not in self.kv_retrieve_call_count: + self.kv_retrieve_call_count[call_key] = 0 + self.kv_retrieve_call_count[call_key] += 1 + call_number = self.kv_retrieve_call_count[call_key] + + # Simulate: first 2 calls return only "input_ids", after that return all fields + if call_number <= 2: + # Only input_ids available initially + field_names = ["input_ids"] + else: + # All fields available + field_names = ["input_ids", "attention_mask", "response"] + + # Generate global indexes + if not hasattr(self, "_kv_index_map"): + self._kv_index_map = {} + if partition_id not in self._kv_index_map: + self._kv_index_map[partition_id] = 0 + start_index = self._kv_index_map.get(partition_id, 0) + global_indexes = list(range(start_index, start_index + len(keys))) + self._kv_index_map[partition_id] = global_indexes[-1] + 1 + + # Create metadata for each key + samples = [] + for i, key in enumerate(keys): + fields = {} + for field_name in field_names: + field_meta = FieldMeta( + name=field_name, + dtype=torch.int64 if field_name == "input_ids" else torch.float32, + shape=torch.Size([1, 10]) if field_name == "input_ids" else torch.Size([1, 5]), + production_status=ProductionStatus.READY_FOR_CONSUME, + ) + fields[field_name] = field_meta + sample = SampleMeta( + partition_id=partition_id, + global_index=global_indexes[i], + fields=fields, + ) + samples.append(sample) + + metadata = BatchMeta(samples=samples) + return {"metadata": metadata} + + def reset_call_counts(self): + """Reset the call count tracking for testing.""" + with self._lock: + self.kv_retrieve_call_count = {} + + def stop(self): + 
self.running = False + time.sleep(0.2) + self.request_socket.close() + self.context.term() + + +class MockControllerForTimeout: + """Mock controller that never provides certain fields (for timeout testing).""" + + def __init__(self, controller_id="controller_timeout"): + self.controller_id = controller_id + self.context = zmq.Context() + + self.request_socket = self.context.socket(zmq.ROUTER) + self.request_port = self._bind_to_random_port(self.request_socket) + + self.zmq_server_info = ZMQServerInfo( + role=TransferQueueRole.CONTROLLER, + id=controller_id, + ip="127.0.0.1", + ports={ + "request_handle_socket": self.request_port, + }, + ) + + self.running = True + self.request_thread = Thread(target=self._handle_requests, daemon=True) + self.request_thread.start() + + def _bind_to_random_port(self, socket): + return socket.bind_to_random_port("tcp://127.0.0.1") + + def _handle_requests(self): + poller = zmq.Poller() + poller.register(self.request_socket, zmq.POLLIN) + + while self.running: + try: + socks = dict(poller.poll(100)) + if self.request_socket in socks: + messages = self.request_socket.recv_multipart() + identity = messages.pop(0) + serialized_msg = messages + request_msg = ZMQMessage.deserialize(serialized_msg) + + if request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: + response_body = self._mock_kv_retrieve_keys_never_available(request_msg.body) + response_type = ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE + else: + response_body = {"error": f"Unknown request type: {request_msg.request_type}"} + response_type = ZMQRequestType.CLEAR_META_RESPONSE + + response_msg = ZMQMessage.create( + request_type=response_type, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body=response_body, + ) + self.request_socket.send_multipart([identity, *response_msg.serialize()]) + except zmq.Again: + continue + except Exception as e: + if self.running: + print(f"MockControllerForTimeout exception: {e}") + else: + print(f"MockControllerForTimeout 
ERROR: {e}") + raise + + def _mock_kv_retrieve_keys_never_available(self, request_body): + """Mock KV retrieve that never provides certain fields.""" + keys = request_body.get("keys", []) + partition_id = request_body.get("partition_id", "") + + # Only provide "input_ids" - "attention_mask" and "response" will never be available + field_names = ["input_ids"] + + if not hasattr(self, "_kv_index_map"): + self._kv_index_map = {} + if partition_id not in self._kv_index_map: + self._kv_index_map[partition_id] = 0 + start_index = self._kv_index_map.get(partition_id, 0) + global_indexes = list(range(start_index, start_index + len(keys))) + self._kv_index_map[partition_id] = global_indexes[-1] + 1 + + samples = [] + for i, key in enumerate(keys): + fields = {} + for field_name in field_names: + field_meta = FieldMeta( + name=field_name, + dtype=torch.int64, + shape=torch.Size([1, 10]), + production_status=ProductionStatus.READY_FOR_CONSUME, + ) + fields[field_name] = field_meta + sample = SampleMeta( + partition_id=partition_id, + global_index=global_indexes[i], + fields=fields, + ) + samples.append(sample) + + metadata = BatchMeta(samples=samples) + return {"metadata": metadata} + + def stop(self): + self.running = False + time.sleep(0.2) + self.request_socket.close() + self.context.term() + + +# ============================================================================= +# Test Fixtures +# ============================================================================= + + +@pytest.fixture +def mock_controller_delay(): + """Fixture providing a mock controller with delayed field availability.""" + controller = MockControllerWithFieldDelay() + yield controller + controller.stop() + + +@pytest.fixture +def mock_controller_timeout(): + """Fixture providing a mock controller that never provides certain fields.""" + controller = MockControllerForTimeout() + yield controller + controller.stop() + + +def create_mock_client(mock_controller): + """Create a TransferQueueClient 
connected to the given mock controller. + + Note: Storage methods are mocked at high level, so no actual storage is needed. + """ + client = TransferQueueClient( + client_id="client_test", + controller_info=mock_controller.zmq_server_info, + ) + + with patch( + "transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller" + ): + # Create a dummy zmq_server_info for storage (not actually used since we mock storage methods) + storage_info = ZMQServerInfo( + role=TransferQueueRole.STORAGE, + id="dummy_storage", + ip="127.0.0.1", + ports={"put_get_socket": 9999}, + ) + + config = { + "controller_info": mock_controller.zmq_server_info, + "zmq_info": {"dummy_storage": storage_info}, + } + client.initialize_storage_manager(manager_type="SimpleStorage", config=config) + + # Mock storage methods at high level + async def mock_put_data(data, metadata): + pass + + async def mock_get_data(metadata): + # Return test data matching the expected fields + return TensorDict( + { + "input_ids": torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]), + "response": torch.tensor([[100, 101, 102, 103, 104]]), + }, + batch_size=1, + ) + + async def mock_clear_data(metadata): + pass + + client.storage_manager.put_data = mock_put_data + client.storage_manager.get_data = mock_get_data + client.storage_manager.clear_data = mock_clear_data + + return client + + +# ============================================================================= +# Sync KV Interface Tests +# ============================================================================= + + +class TestKVMixedFieldPolling: + """Tests for kv_batch_get polling behavior when fields become available.""" + + def test_kv_batch_get_polls_until_fields_available(self, mock_controller_delay): + """Test that kv_batch_get polls and waits for fields to become available. + + This test simulates the scenario where: + 1. 
Initial kv_retrieve_keys call returns only "input_ids" + 2. Subsequent calls (after polling) return all fields including "response" + 3. kv_batch_get should eventually succeed after polling + """ + import transfer_queue.interface as interface + + client = create_mock_client(mock_controller_delay) + + # Patch the client creation to use our mock + original_client = interface._TRANSFER_QUEUE_CLIENT + try: + interface._TRANSFER_QUEUE_CLIENT = client + + # This should poll until all fields are available and succeed + result = interface.kv_batch_get( + keys="test_key", partition_id="test_partition", fields=["input_ids", "attention_mask", "response"] + ) + + # Verify we got all requested fields + assert "input_ids" in result.keys() + assert "attention_mask" in result.keys() + assert "response" in result.keys() + + finally: + interface._TRANSFER_QUEUE_CLIENT = original_client + + def test_kv_batch_get_timeout_on_missing_fields(self, mock_controller_timeout): + """Test that kv_batch_get raises timeout error when fields never become available. + + This test simulates the scenario where: + 1. kv_retrieve_keys only returns "input_ids" + 2. We request "attention_mask" and "response" which never become available + 3. 
kv_batch_get should raise RuntimeError after timeout + """ + import transfer_queue.interface as interface + from transfer_queue.interface import TQ_KV_POLLING_METADATA_TIMEOUT + + # Temporarily reduce timeout for faster test + original_timeout = TQ_KV_POLLING_METADATA_TIMEOUT + interface.TQ_KV_POLLING_METADATA_TIMEOUT = 1 # 1 second for testing + + client = create_mock_client(mock_controller_timeout) + + original_client = interface._TRANSFER_QUEUE_CLIENT + try: + interface._TRANSFER_QUEUE_CLIENT = client + + with pytest.raises(RuntimeError, match="Timeout for kv_batch_get"): + interface.kv_batch_get( + keys="test_key", partition_id="test_partition", fields=["input_ids", "attention_mask", "response"] + ) + + finally: + interface._TRANSFER_QUEUE_CLIENT = original_client + interface.TQ_KV_POLLING_METADATA_TIMEOUT = original_timeout + + +# ============================================================================= +# Async KV Interface Tests +# ============================================================================= + + +@pytest.mark.asyncio +class TestAsyncKVMixedFieldPolling: + """Tests for async_kv_batch_get polling behavior.""" + + async def test_async_kv_batch_get_polls_until_fields_available(self, mock_controller_delay): + """Test that async_kv_batch_get polls and waits for fields to become available.""" + import transfer_queue.interface as interface + + client = create_mock_client(mock_controller_delay) + + original_client = interface._TRANSFER_QUEUE_CLIENT + try: + interface._TRANSFER_QUEUE_CLIENT = client + + # This should poll until all fields are available and succeed + result = await interface.async_kv_batch_get( + keys="test_key", partition_id="test_partition", fields=["input_ids", "attention_mask", "response"] + ) + + # Verify we got all requested fields + assert "input_ids" in result.keys() + assert "attention_mask" in result.keys() + assert "response" in result.keys() + + finally: + interface._TRANSFER_QUEUE_CLIENT = original_client + + async def 
test_async_kv_batch_get_timeout_on_missing_fields(self, mock_controller_timeout): + """Test that async_kv_batch_get raises timeout error when fields never become available.""" + import transfer_queue.interface as interface + from transfer_queue.interface import TQ_KV_POLLING_METADATA_TIMEOUT + + # Temporarily reduce timeout for faster test + original_timeout = TQ_KV_POLLING_METADATA_TIMEOUT + interface.TQ_KV_POLLING_METADATA_TIMEOUT = 1 # 1 second for testing + + client = create_mock_client(mock_controller_timeout) + + original_client = interface._TRANSFER_QUEUE_CLIENT + try: + interface._TRANSFER_QUEUE_CLIENT = client + + with pytest.raises(RuntimeError, match="Timeout for async_kv_batch_get"): + await interface.async_kv_batch_get( + keys="test_key", partition_id="test_partition", fields=["input_ids", "attention_mask", "response"] + ) + + finally: + interface._TRANSFER_QUEUE_CLIENT = original_client + interface.TQ_KV_POLLING_METADATA_TIMEOUT = original_timeout + + +# ============================================================================= +# Run Tests +# ============================================================================= + + +def run_tests(): + """Run all tests manually for debugging.""" + pytest.main([__file__, "-v", "-s"]) + + +if __name__ == "__main__": + run_tests() From 73ae9bf7b8001b594762a3b538b65afa118f8bbf Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 12 Feb 2026 11:08:14 +0800 Subject: [PATCH 4/8] fix comments Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 4 ++-- transfer_queue/interface.py | 26 ++++++++++++++------------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 3c7c3f9..eb98853 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1193,7 +1193,7 @@ def get_metadata( if batch_size is None: raise ValueError("must provide batch_size in fetch mode") - start_time = time.time() + start_time = time.monotonic() 
while True: # ready_for_consume_indexes: samples where all required fields are produced # (production status is ready) and not yet consumed @@ -1207,7 +1207,7 @@ def get_metadata( f" Returning None due to polling mode." ) return BatchMeta.empty() - if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT: + if time.monotonic() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT: raise TimeoutError( f"Timeout while waiting for sufficient data for task {task_name}. " f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}" diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 73558bb..b8cf2e5 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -419,16 +419,17 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list current_fields = set(batch_meta.field_names) not_ready_fields = target_fields - current_fields - begin_polling_time = time.time() + begin_polling_time = time.monotonic() while not_ready_fields: - if time.time() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: + if time.monotonic() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: raise RuntimeError( - f"Timeout for kv_batch_get. Missing fields: {not_ready_fields}" - f" after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds." + f"Timeout for kv_batch_get. Missing fields: {list(sorted(not_ready_fields))} " + f"after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds. " + f"Extra info for debug: partition: {partition_id}, keys: {keys}" ) - logger.warning( - f"Fields {list(not_ready_fields)} are not ready yet! " + logger.info( + f"Fields {list(sorted(not_ready_fields))} are not ready yet! " f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." 
) @@ -703,16 +704,17 @@ async def async_kv_batch_get( current_fields = set(batch_meta.field_names) not_ready_fields = target_fields - current_fields - begin_polling_time = time.time() + begin_polling_time = time.monotonic() while not_ready_fields: - if time.time() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: + if time.monotonic() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: raise RuntimeError( - f"Timeout for async_kv_batch_get. Missing fields: {not_ready_fields}" - f" after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds." + f"Timeout for async_kv_batch_get. Missing fields: {list(sorted(not_ready_fields))} " + f"after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds. " + f"Extra info for debug: partition: {partition_id}, keys: {keys}" ) - logger.warning( - f"Fields {list(not_ready_fields)} are not ready yet! " + logger.info( + f"Fields {list(sorted(not_ready_fields))} are not ready yet! " f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." ) From cf1f3b8eb67ca79d03d4229564954a8d17294f18 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 12 Feb 2026 11:50:48 +0800 Subject: [PATCH 5/8] change to warnning Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index b8cf2e5..3a610c4 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -428,8 +428,8 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list f"Extra info for debug: partition: {partition_id}, keys: {keys}" ) - logger.info( - f"Fields {list(sorted(not_ready_fields))} are not ready yet! " + logger.warning( + f"Try kv_batch_get with fields {list(sorted(not_ready_fields))} are not ready! " f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." 
) @@ -713,8 +713,8 @@ async def async_kv_batch_get( f"Extra info for debug: partition: {partition_id}, keys: {keys}" ) - logger.info( - f"Fields {list(sorted(not_ready_fields))} are not ready yet! " + logger.warning( + f"Try async_kv_batch_get with fields {list(sorted(not_ready_fields))} are not ready! " f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." ) From 502492d48b9cead181b768689df00de52979dbbc Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 12 Feb 2026 19:59:33 +0800 Subject: [PATCH 6/8] improved stability Signed-off-by: 0oshowero0 try Signed-off-by: 0oshowero0 minor improve Signed-off-by: 0oshowero0 --- transfer_queue/controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index eb98853..05c6af8 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -208,7 +208,7 @@ class DataPartitionStatus: # Production status tensor - dynamically expandable # Values: 0 = not produced, 1 = ready for consumption - production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8) + production_status: Tensor = field(default_factory=lambda: torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)) # Consumption status per task - task_name -> consumption_tensor # Each tensor tracks which samples have been consumed by that task From 060bc4a1e0a0f9fa5ed99cc2e3bc61b84ec2e190 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 12 Feb 2026 21:13:31 +0800 Subject: [PATCH 7/8] add switch for kv_batch_get fields check Signed-off-by: 0oshowero0 --- tests/test_kv_interface.py | 20 +++++++-- transfer_queue/interface.py | 88 ++++++++++++++++++++----------------- 2 files changed, 63 insertions(+), 45 deletions(-) diff --git a/tests/test_kv_interface.py b/tests/test_kv_interface.py index 5e12e9e..9101db3 100644 --- a/tests/test_kv_interface.py +++ b/tests/test_kv_interface.py @@ -396,7 +396,10 @@ def 
test_kv_batch_get_polls_until_fields_available(self, mock_controller_delay): # This should poll until all fields are available and succeed result = interface.kv_batch_get( - keys="test_key", partition_id="test_partition", fields=["input_ids", "attention_mask", "response"] + keys="test_key", + partition_id="test_partition", + fields=["input_ids", "attention_mask", "response"], + strict=True, ) # Verify we got all requested fields @@ -430,7 +433,10 @@ def test_kv_batch_get_timeout_on_missing_fields(self, mock_controller_timeout): with pytest.raises(RuntimeError, match="Timeout for kv_batch_get"): interface.kv_batch_get( - keys="test_key", partition_id="test_partition", fields=["input_ids", "attention_mask", "response"] + keys="test_key", + partition_id="test_partition", + fields=["input_ids", "attention_mask", "response"], + strict=True, ) finally: @@ -459,7 +465,10 @@ async def test_async_kv_batch_get_polls_until_fields_available(self, mock_contro # This should poll until all fields are available and succeed result = await interface.async_kv_batch_get( - keys="test_key", partition_id="test_partition", fields=["input_ids", "attention_mask", "response"] + keys="test_key", + partition_id="test_partition", + fields=["input_ids", "attention_mask", "response"], + strict=True, ) # Verify we got all requested fields @@ -487,7 +496,10 @@ async def test_async_kv_batch_get_timeout_on_missing_fields(self, mock_controlle with pytest.raises(RuntimeError, match="Timeout for async_kv_batch_get"): await interface.async_kv_batch_get( - keys="test_key", partition_id="test_partition", fields=["input_ids", "attention_mask", "response"] + keys="test_key", + partition_id="test_partition", + fields=["input_ids", "attention_mask", "response"], + strict=True, ) finally: diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 3a610c4..41eecb6 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -374,7 +374,9 @@ def kv_batch_put( 
tq_client.set_custom_meta(batch_meta) -def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict: +def kv_batch_get( + keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None, strict: bool = False +) -> TensorDict: """Get data from TransferQueue using user-specified keys. This is a convenience method for retrieving data using keys instead of indexes. @@ -383,6 +385,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list keys: Single key or list of keys to retrieve partition_id: Partition containing the keys fields: Optional field(s) to retrieve. If None, retrieves all fields + strict: If True, raises an error if specified fields do not exist Returns: TensorDict with the requested data @@ -415,28 +418,28 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list if isinstance(fields, str): fields = [fields] - target_fields = set(fields) - current_fields = set(batch_meta.field_names) - - not_ready_fields = target_fields - current_fields - begin_polling_time = time.monotonic() - while not_ready_fields: - if time.monotonic() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: - raise RuntimeError( - f"Timeout for kv_batch_get. Missing fields: {list(sorted(not_ready_fields))} " - f"after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds. " - f"Extra info for debug: partition: {partition_id}, keys: {keys}" - ) - - logger.warning( - f"Try kv_batch_get with fields {list(sorted(not_ready_fields))} are not ready! " - f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." 
- ) - - time.sleep(TQ_KV_POLLING_METADATA_CHECK_INTERVAL) - batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + if strict: + target_fields = set(fields) current_fields = set(batch_meta.field_names) not_ready_fields = target_fields - current_fields + begin_polling_time = time.monotonic() + while not_ready_fields: + if time.monotonic() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: + raise RuntimeError( + f"Timeout for kv_batch_get. Missing fields: {list(sorted(not_ready_fields))} " + f"after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds. " + f"Extra info for debug: partition: {partition_id}, keys: {keys}" + ) + + logger.warning( + f"Try kv_batch_get with fields {list(sorted(not_ready_fields))} are not ready! " + f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." + ) + + time.sleep(TQ_KV_POLLING_METADATA_CHECK_INTERVAL) + batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + current_fields = set(batch_meta.field_names) + not_ready_fields = target_fields - current_fields batch_meta = batch_meta.select_fields(fields) @@ -659,7 +662,7 @@ async def async_kv_batch_put( async def async_kv_batch_get( - keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None + keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None, strict: bool = False ) -> TensorDict: """Asynchronously get data from TransferQueue using user-specified keys. @@ -669,6 +672,7 @@ async def async_kv_batch_get( keys: Single key or list of keys to retrieve partition_id: Partition containing the keys fields: Optional field(s) to retrieve. 
If None, retrieves all fields + strict: If True, raises an error if specified fields do not exist Returns: TensorDict with the requested data @@ -700,28 +704,30 @@ async def async_kv_batch_get( if fields is not None: if isinstance(fields, str): fields = [fields] - target_fields = set(fields) - current_fields = set(batch_meta.field_names) - - not_ready_fields = target_fields - current_fields - begin_polling_time = time.monotonic() - while not_ready_fields: - if time.monotonic() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: - raise RuntimeError( - f"Timeout for async_kv_batch_get. Missing fields: {list(sorted(not_ready_fields))} " - f"after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds. " - f"Extra info for debug: partition: {partition_id}, keys: {keys}" - ) - - logger.warning( - f"Try async_kv_batch_get with fields {list(sorted(not_ready_fields))} are not ready! " - f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." - ) - await asyncio.sleep(TQ_KV_POLLING_METADATA_CHECK_INTERVAL) - batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + if strict: + target_fields = set(fields) current_fields = set(batch_meta.field_names) + not_ready_fields = target_fields - current_fields + begin_polling_time = time.monotonic() + while not_ready_fields: + if time.monotonic() - begin_polling_time > TQ_KV_POLLING_METADATA_TIMEOUT: + raise RuntimeError( + f"Timeout for async_kv_batch_get. Missing fields: {list(sorted(not_ready_fields))} " + f"after {TQ_KV_POLLING_METADATA_TIMEOUT} seconds. " + f"Extra info for debug: partition: {partition_id}, keys: {keys}" + ) + + logger.warning( + f"Try async_kv_batch_get with fields {list(sorted(not_ready_fields))} are not ready! " + f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds." 
+ ) + + await asyncio.sleep(TQ_KV_POLLING_METADATA_CHECK_INTERVAL) + batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + current_fields = set(batch_meta.field_names) + not_ready_fields = target_fields - current_fields batch_meta = batch_meta.select_fields(fields) From 8a4191ff9dda79d2839f1bc581731ba987085d13 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 12 Feb 2026 21:33:54 +0800 Subject: [PATCH 8/8] improve readability Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 41eecb6..53bf2ed 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -375,7 +375,7 @@ def kv_batch_put( def kv_batch_get( - keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None, strict: bool = False + keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None, wait_for_fields: bool = False ) -> TensorDict: """Get data from TransferQueue using user-specified keys. @@ -385,7 +385,8 @@ def kv_batch_get( keys: Single key or list of keys to retrieve partition_id: Partition containing the keys fields: Optional field(s) to retrieve. If None, retrieves all fields - strict: If True, raises an error if specified fields do not exist + wait_for_fields: If True, enters a polling loop waiting for the specified fields + to become ready (up to a timeout). 
If False, directly returns currently available data

     Returns:
         TensorDict with the requested data
@@ -418,7 +419,7 @@ def kv_batch_get(
         if isinstance(fields, str):
             fields = [fields]

-        if strict:
+        if wait_for_fields:
             target_fields = set(fields)
             current_fields = set(batch_meta.field_names)
             not_ready_fields = target_fields - current_fields
@@ -432,8 +433,8 @@ def kv_batch_get(
                 )

                 logger.warning(
-                    f"Try kv_batch_get with fields {list(sorted(not_ready_fields))} are not ready! "
-                    f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds."
+                    f"Requested metadata fields {list(sorted(not_ready_fields))} are not yet available; "
+                    f"retrying kv_batch_get in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds."
                 )

                 time.sleep(TQ_KV_POLLING_METADATA_CHECK_INTERVAL)
@@ -444,6 +445,7 @@ def kv_batch_get(
         batch_meta = batch_meta.select_fields(fields)

         if not batch_meta.is_ready:
+            # this is a double check that should not happen
             raise RuntimeError("Some fields are not ready in all the requested keys!")

         data = tq_client.get_data(batch_meta)
@@ -662,7 +664,7 @@ async def async_kv_batch_put(


 async def async_kv_batch_get(
-    keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None, strict: bool = False
+    keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None, wait_for_fields: bool = False
 ) -> TensorDict:
     """Asynchronously get data from TransferQueue using user-specified keys.

@@ -672,7 +674,8 @@ async def async_kv_batch_get(
     keys: Single key or list of keys to retrieve
     partition_id: Partition containing the keys
     fields: Optional field(s) to retrieve. If None, retrieves all fields
-    strict: If True, raises an error if specified fields do not exist
+    wait_for_fields: If True, enters a polling loop waiting for the specified fields
+        to become ready (up to a timeout). 
If False, directly returns currently available data

     Returns:
         TensorDict with the requested data
@@ -705,7 +708,7 @@ async def async_kv_batch_get(
         if isinstance(fields, str):
             fields = [fields]

-        if strict:
+        if wait_for_fields:
             target_fields = set(fields)
             current_fields = set(batch_meta.field_names)

@@ -720,8 +723,8 @@ async def async_kv_batch_get(
                     )

                 logger.warning(
-                    f"Try async_kv_batch_get with fields {list(sorted(not_ready_fields))} are not ready! "
-                    f"Retry in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds."
+                    f"Requested metadata fields {list(sorted(not_ready_fields))} are not yet available; "
+                    f"retrying async_kv_batch_get in {TQ_KV_POLLING_METADATA_CHECK_INTERVAL} seconds."
                 )

                 await asyncio.sleep(TQ_KV_POLLING_METADATA_CHECK_INTERVAL)
@@ -732,6 +735,7 @@ async def async_kv_batch_get(
         batch_meta = batch_meta.select_fields(fields)

         if not batch_meta.is_ready:
+            # this is a double check that should not happen
             raise RuntimeError("Some fields are not ready in all the requested keys!")

         data = await tq_client.async_get_data(batch_meta)