@@ -898,31 +898,39 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None:
898898
899899
900900# ==================== KV Interface API ====================
901-
902-
903901async def async_kv_put (
904902 key : str , partition_id : str , fields : Optional [TensorDict | dict [str , Any ]], tag : Optional [dict [str , Any ]]
905903) -> None :
906904 """Asynchronously put a single key-value pair to TransferQueue.
907905
908- See kv_put for detailed documentation.
906+ This is a convenience method for putting data using a user-specified key
907+ instead of BatchMeta. Internally, the key is translated to a BatchMeta
908+ and the data is stored using the regular put mechanism.
909909
910910 Args:
911- key: User-specified key for the data
912- partition_id: Partition to store the data in
913- fields: Data fields to store
911+ key: User-specified key for the data sample (in row)
912+ partition_id: Logical partition to store the data in
913+ fields: Data fields to store. Can be a TensorDict or a dict of tensors.
914+ Each key in `fields` will be treated as a column for the data sample.
915+ If dict is provided, tensors will be unsqueezed to add batch dimension.
914916 tag: Optional metadata tag to associate with the key
915917
918+ Raises:
919+ ValueError: If neither fields nor tag is provided
920+ ValueError: If nested tensors are provided (use kv_batch_put instead)
921+ RuntimeError: If retrieved BatchMeta size is not 1 for the single input `key`
922+
916923 Example:
917924 >>> import transfer_queue as tq
918925 >>> import torch
919926 >>> tq.init()
927+ >>> # Put with both fields and tag
920928 >>> await tq.async_kv_put(
921929 ... key="sample_1",
922930 ... partition_id="train",
923931 ... fields={"input_ids": torch.tensor([1, 2, 3])},
924932 ... tag={"score": 0.95}
925- ... )
933+ ... )
926934 """
927935
928936 if fields is None and tag is None :
@@ -933,6 +941,9 @@ async def async_kv_put(
933941 # 1. translate user-specified key to BatchMeta
934942 batch_meta = await tq_client .async_kv_retrieve_keys (keys = [key ], partition_id = partition_id , create = True )
935943
944+ if batch_meta .size != 1 :
945+ raise RuntimeError (f"Retrieved BatchMeta size { batch_meta .size } does not match with input `key` size of 1!" )
946+
936947 # 2. register the user-specified tag to BatchMeta
937948 if tag :
938949 batch_meta .update_custom_meta ([tag ])
@@ -965,32 +976,51 @@ async def async_kv_batch_put(
965976) -> None :
966977 """Asynchronously put multiple key-value pairs to TransferQueue in batch.
967978
968- See kv_batch_put for detailed documentation.
979+ This method stores multiple key-value pairs in a single operation, which is more
980+ efficient than calling kv_put multiple times.
969981
970982 Args:
971983 keys: List of user-specified keys for the data
972- partition_id: Partition to store the data in
973- fields: TensorDict containing data for all keys
984+ partition_id: Logical partition to store the data in
985+ fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
974986 tags: List of metadata tags, one for each key
975987
988+ Raises:
989+ ValueError: If neither `fields` nor `tags` is provided
990+ ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
991+ RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
992+
976993 Example:
977994 >>> import transfer_queue as tq
978995 >>> tq.init()
979996 >>> keys = ["sample_1", "sample_2", "sample_3"]
980997 >>> fields = TensorDict({
981998 ... "input_ids": torch.randn(3, 10),
999+ ... "attention_mask": torch.ones(3, 10),
9821000 ... }, batch_size=3)
9831001 >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
9841002 >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
9851003 """
1004+
9861005 if fields is None and tags is None :
9871006 raise ValueError ("Please provide at least one parameter of fields or tag." )
9881007
1008+ if fields .batch_size [0 ] != len (keys ):
1009+ raise ValueError (
1010+ f"`keys` with length { len (keys )} does not match the `fields` TensorDict with "
1011+ f"batch_size { fields .batch_size [0 ]} "
1012+ )
1013+
9891014 tq_client = _maybe_create_transferqueue_client ()
9901015
9911016 # 1. translate user-specified key to BatchMeta
9921017 batch_meta = await tq_client .async_kv_retrieve_keys (keys = keys , partition_id = partition_id , create = True )
9931018
1019+ if batch_meta .size != len (keys ):
1020+ raise RuntimeError (
1021+ f"Retrieved BatchMeta size { batch_meta .size } does not match with input `keys` size { len (keys )} !"
1022+ )
1023+
9941024 # 2. register the user-specified tags to BatchMeta
9951025 if tags :
9961026 if len (tags ) != len (keys ):
@@ -1010,28 +1040,38 @@ async def async_kv_get(
10101040) -> TensorDict :
10111041 """Asynchronously get data from TransferQueue using user-specified keys.
10121042
1013- See kv_get for detailed documentation .
1043+ This is a convenience method for retrieving data using keys instead of indexes .
10141044
10151045 Args:
10161046 keys: Single key or list of keys to retrieve
10171047 partition_id: Partition containing the keys
1018- fields: Optional field(s) to retrieve
1048+ fields: Optional field(s) to retrieve. If None, retrieves all fields
10191049
10201050 Returns:
10211051 TensorDict with the requested data
10221052
1053+ Raises:
1054+ RuntimeError: If keys or partition are not found
1055+
10231056 Example:
10241057 >>> import transfer_queue as tq
10251058 >>> tq.init()
1059+ >>> # Get single key with all fields
1060+ >>> data = await tq.async_kv_get(keys="sample_1", partition_id="train")
1061+ >>> # Get multiple keys with specific fields
10261062 >>> data = await tq.async_kv_get(
10271063 ... keys=["sample_1", "sample_2"],
1028- ... partition_id="train"
1064+ ... partition_id="train",
1065+ ... fields="input_ids"
10291066 ... )
10301067 """
10311068 tq_client = _maybe_create_transferqueue_client ()
10321069
10331070 batch_meta = await tq_client .async_kv_retrieve_keys (keys = keys , partition_id = partition_id , create = False )
10341071
1072+ if batch_meta .size == 0 :
1073+ raise RuntimeError ("keys or partition were not found!" )
1074+
10351075 if fields is not None :
10361076 if isinstance (fields , str ):
10371077 fields = [fields ]
@@ -1045,18 +1085,20 @@ async def async_kv_get(
10451085async def async_kv_list (partition_id : str ) -> tuple [list [str ], list [dict [str , Any ]]]:
10461086 """Asynchronously list all keys and their metadata in a partition.
10471087
1048- See kv_list for detailed documentation.
1049-
10501088 Args:
10511089 partition_id: Partition to list keys from
10521090
10531091 Returns:
1054- Tuple of (keys list, tags list)
1092+ Tuple of:
1093+ - List of keys in the partition
1094+ - List of custom metadata (tags) associated with each key
10551095
10561096 Example:
10571097 >>> import transfer_queue as tq
10581098 >>> tq.init()
10591099 >>> keys, tags = await tq.async_kv_list(partition_id="train")
1100+ >>> print(f"Keys: {keys}")
1101+ >>> print(f"Tags: {tags}")
10601102 """
10611103 tq_client = _maybe_create_transferqueue_client ()
10621104
@@ -1068,7 +1110,8 @@ async def async_kv_list(partition_id: str) -> tuple[list[str], list[dict[str, An
10681110async def async_kv_clear (keys : list [str ] | str , partition_id : str ) -> None :
10691111 """Asynchronously clear key-value pairs from TransferQueue.
10701112
1071- See kv_clear for detailed documentation.
1113+ This removes the specified keys and their associated data from both
1114+ the controller and storage units.
10721115
10731116 Args:
10741117 keys: Single key or list of keys to clear
@@ -1077,6 +1120,9 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
10771120 Example:
10781121 >>> import transfer_queue as tq
10791122 >>> tq.init()
1123+ >>> # Clear single key
1124+ >>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
1125+ >>> # Clear multiple keys
10801126 >>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
10811127 """
10821128
0 commit comments