Skip to content

Commit 5813c41

Browse files
committed
update
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent d4da4d3 commit 5813c41

1 file changed

Lines changed: 63 additions & 17 deletions

File tree

transfer_queue/interface.py

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -898,31 +898,39 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None:
898898

899899

900900
# ==================== KV Interface API ====================
901-
902-
903901
async def async_kv_put(
904902
key: str, partition_id: str, fields: Optional[TensorDict | dict[str, Any]], tag: Optional[dict[str, Any]]
905903
) -> None:
906904
"""Asynchronously put a single key-value pair to TransferQueue.
907905
908-
See kv_put for detailed documentation.
906+
This is a convenience method for putting data using a user-specified key
907+
instead of BatchMeta. Internally, the key is translated to a BatchMeta
908+
and the data is stored using the regular put mechanism.
909909
910910
Args:
911-
key: User-specified key for the data
912-
partition_id: Partition to store the data in
913-
fields: Data fields to store
911+
key: User-specified key for the data sample (in row)
912+
partition_id: Logical partition to store the data in
913+
fields: Data fields to store. Can be a TensorDict or a dict of tensors.
914+
Each key in `fields` will be treated as a column for the data sample.
915+
If dict is provided, tensors will be unsqueezed to add batch dimension.
914916
tag: Optional metadata tag to associate with the key
915917
918+
Raises:
919+
ValueError: If neither fields nor tag is provided
920+
ValueError: If nested tensors are provided (use kv_batch_put instead)
921+
RuntimeError: If retrieved BatchMeta size is not exactly 1 (a single `key` is provided)
922+
916923
Example:
917924
>>> import transfer_queue as tq
918925
>>> import torch
919926
>>> tq.init()
927+
>>> # Put with both fields and tag
920928
>>> await tq.async_kv_put(
921929
... key="sample_1",
922930
... partition_id="train",
923931
... fields={"input_ids": torch.tensor([1, 2, 3])},
924932
... tag={"score": 0.95}
925-
... )
933+
... )
926934
"""
927935

928936
if fields is None and tag is None:
@@ -933,6 +941,9 @@ async def async_kv_put(
933941
# 1. translate user-specified key to BatchMeta
934942
batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True)
935943

944+
if batch_meta.size != 1:
945+
raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")
946+
936947
# 2. register the user-specified tag to BatchMeta
937948
if tag:
938949
batch_meta.update_custom_meta([tag])
@@ -965,32 +976,51 @@ async def async_kv_batch_put(
965976
) -> None:
966977
"""Asynchronously put multiple key-value pairs to TransferQueue in batch.
967978
968-
See kv_batch_put for detailed documentation.
979+
This method stores multiple key-value pairs in a single operation, which is more
980+
efficient than calling kv_put multiple times.
969981
970982
Args:
971983
keys: List of user-specified keys for the data
972-
partition_id: Partition to store the data in
973-
fields: TensorDict containing data for all keys
984+
partition_id: Logical partition to store the data in
985+
fields: TensorDict containing data for all keys. Must have batch_size == len(keys)
974986
tags: List of metadata tags, one for each key
975987
988+
Raises:
989+
ValueError: If neither `fields` nor `tags` is provided
990+
ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
991+
RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
992+
976993
Example:
977994
>>> import transfer_queue as tq
978995
>>> tq.init()
979996
>>> keys = ["sample_1", "sample_2", "sample_3"]
980997
>>> fields = TensorDict({
981998
... "input_ids": torch.randn(3, 10),
999+
... "attention_mask": torch.ones(3, 10),
9821000
... }, batch_size=3)
9831001
>>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
9841002
>>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
9851003
"""
1004+
9861005
if fields is None and tags is None:
9871006
raise ValueError("Please provide at least one parameter of fields or tag.")
9881007

1008+
if fields.batch_size[0] != len(keys):
1009+
raise ValueError(
1010+
f"`keys` with length {len(keys)} does not match the `fields` TensorDict with "
1011+
f"batch_size {fields.batch_size[0]}"
1012+
)
1013+
9891014
tq_client = _maybe_create_transferqueue_client()
9901015

9911016
# 1. translate user-specified key to BatchMeta
9921017
batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True)
9931018

1019+
if batch_meta.size != len(keys):
1020+
raise RuntimeError(
1021+
f"Retrieved BatchMeta size {batch_meta.size} does not match with input `keys` size {len(keys)}!"
1022+
)
1023+
9941024
# 2. register the user-specified tags to BatchMeta
9951025
if tags:
9961026
if len(tags) != len(keys):
@@ -1010,28 +1040,38 @@ async def async_kv_get(
10101040
) -> TensorDict:
10111041
"""Asynchronously get data from TransferQueue using user-specified keys.
10121042
1013-
See kv_get for detailed documentation.
1043+
This is a convenience method for retrieving data using keys instead of indexes.
10141044
10151045
Args:
10161046
keys: Single key or list of keys to retrieve
10171047
partition_id: Partition containing the keys
1018-
fields: Optional field(s) to retrieve
1048+
fields: Optional field(s) to retrieve. If None, retrieves all fields
10191049
10201050
Returns:
10211051
TensorDict with the requested data
10221052
1053+
Raises:
1054+
RuntimeError: If keys or partition are not found
1055+
10231056
Example:
10241057
>>> import transfer_queue as tq
10251058
>>> tq.init()
1059+
>>> # Get single key with all fields
1060+
>>> data = await tq.async_kv_get(keys="sample_1", partition_id="train")
1061+
>>> # Get multiple keys with specific fields
10261062
>>> data = await tq.async_kv_get(
10271063
... keys=["sample_1", "sample_2"],
1028-
... partition_id="train"
1064+
... partition_id="train",
1065+
... fields="input_ids"
10291066
... )
10301067
"""
10311068
tq_client = _maybe_create_transferqueue_client()
10321069

10331070
batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False)
10341071

1072+
if batch_meta.size == 0:
1073+
raise RuntimeError("keys or partition were not found!")
1074+
10351075
if fields is not None:
10361076
if isinstance(fields, str):
10371077
fields = [fields]
@@ -1045,18 +1085,20 @@ async def async_kv_get(
10451085
async def async_kv_list(partition_id: str) -> tuple[list[str], list[dict[str, Any]]]:
10461086
"""Asynchronously list all keys and their metadata in a partition.
10471087
1048-
See kv_list for detailed documentation.
1049-
10501088
Args:
10511089
partition_id: Partition to list keys from
10521090
10531091
Returns:
1054-
Tuple of (keys list, tags list)
1092+
Tuple of:
1093+
- List of keys in the partition
1094+
- List of custom metadata (tags) associated with each key
10551095
10561096
Example:
10571097
>>> import transfer_queue as tq
10581098
>>> tq.init()
10591099
>>> keys, tags = await tq.async_kv_list(partition_id="train")
1100+
>>> print(f"Keys: {keys}")
1101+
>>> print(f"Tags: {tags}")
10601102
"""
10611103
tq_client = _maybe_create_transferqueue_client()
10621104

@@ -1068,7 +1110,8 @@ async def async_kv_list(partition_id: str) -> tuple[list[str], list[dict[str, An
10681110
async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
10691111
"""Asynchronously clear key-value pairs from TransferQueue.
10701112
1071-
See kv_clear for detailed documentation.
1113+
This removes the specified keys and their associated data from both
1114+
the controller and storage units.
10721115
10731116
Args:
10741117
keys: Single key or list of keys to clear
@@ -1077,6 +1120,9 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None:
10771120
Example:
10781121
>>> import transfer_queue as tq
10791122
>>> tq.init()
1123+
>>> # Clear single key
1124+
>>> await tq.async_kv_clear(keys="sample_1", partition_id="train")
1125+
>>> # Clear multiple keys
10801126
>>> await tq.async_kv_clear(keys=["sample_1", "sample_2"], partition_id="train")
10811127
"""
10821128

0 commit comments

Comments
 (0)