From d7759bda756449464971fa84e25fa54ebfe6e79b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 24 Mar 2026 22:57:19 +0800 Subject: [PATCH 01/11] better API for usage Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 4 +- transfer_queue/interface.py | 197 +++++++++++++++++++++++++---- tutorial/02_kv_interface.py | 2 +- 3 files changed, 172 insertions(+), 31 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 8760c321..7a04ba71 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -473,14 +473,14 @@ def test_kv_batch_get_partial_fields(self, controller): tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) # Get only input_ids - retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields="input_ids") + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, select_fields="input_ids") assert "input_ids" in retrieved.keys() assert "attention_mask" not in retrieved.keys() assert "response" not in retrieved.keys() assert_tensor_equal(retrieved["input_ids"], input_ids) # Get multiple specific fields - retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, fields=["input_ids", "response"]) + retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, select_fields=["input_ids", "response"]) assert "input_ids" in retrieved.keys() assert "response" in retrieved.keys() assert "attention_mask" not in retrieved.keys() diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index d0fd2f7f..ce0fdb92 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -28,6 +28,7 @@ from tensordict import TensorDict from tensordict.tensorclass import NonTensorStack +from transfer_queue import KVBatchMeta from transfer_queue.client import TransferQueueClient from transfer_queue.controller import TransferQueueController from transfer_queue.sampler import * # noqa: F401 @@ -359,7 +360,7 @@ def kv_put( partition_id: str, fields: Optional[TensorDict | dict[str, Any]] = None, tag: Optional[dict[str, Any]] = None, -) -> None: +) -> KVBatchMeta: """Put a single key-value pair to TransferQueue. This is a convenience method for putting data using a user-specified key @@ -372,8 +373,14 @@ def kv_put( fields: Data fields to store. Can be a TensorDict or a dict of tensors. Each key in `fields` will be treated as a column for the data sample. If dict is provided, tensors will be unsqueezed to add batch dimension. + If not provided, will only update the newly given tag to the key. tag: Optional metadata tag to associate with the key + Returns: + KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. + The `fields` attribute includes all fields stored for this sample, + including any new fields written by this put operation. + Raises: ValueError: If neither fields nor tag is provided ValueError: If nested tensors are provided (use kv_batch_put instead) @@ -384,12 +391,13 @@ def kv_put( >>> import torch >>> tq.init() >>> # Put with both fields and tag - >>> tq.kv_put( + >>> meta = tq.kv_put( ... key="sample_1", ... partition_id="train", ... fields={"input_ids": torch.tensor([1, 2, 3])}, ... tag={"score": 0.95} ... ) + >>> print(meta.fields) # ['input_ids'] """ if fields is None and tag is None: raise ValueError("Please provide at least one parameter of `fields` or `tag`.") @@ -423,15 +431,26 @@ def kv_put( raise ValueError("field can only be dict or TensorDict") # custom_meta (tag) will be put to controller through the internal put process - tq_client.put(fields, batch_meta) + # After put, batch_meta.field_names() will include the new fields written by user + batch_meta = tq_client.put(fields, batch_meta) + fields_to_return = batch_meta.field_names() else: # directly update custom_meta (tag) to controller tq_client.set_custom_meta(batch_meta) + fields_to_return = batch_meta.field_names() if batch_meta.field_names() else None + + return KVBatchMeta( + keys=[key], + tags=batch_meta.custom_meta, + partition_id=partition_id, + fields=fields_to_return, + extra_info=batch_meta.extra_info, + ) def kv_batch_put( keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None -) -> None: +) -> KVBatchMeta: """Put multiple key-value pairs to TransferQueue in batch. This method stores multiple key-value pairs in a single operation, which is more @@ -440,9 +459,15 @@ def kv_batch_put( Args: keys: List of user-specified keys for the data partition_id: Logical partition to store the data in - fields: TensorDict containing data for all keys. Must have batch_size == len(keys) + fields: TensorDict containing data for all keys. Must have batch_size == len(keys). + If not provided, will only update the newly given tags to the keys. tags: List of metadata tags, one for each key + Returns: + KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. + The `fields` attribute includes all fields stored for these samples, + including any new fields written by this put operation. + Raises: ValueError: If neither `fields` nor `tags` is provided ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict @@ -458,7 +483,8 @@ def kv_batch_put( ... "attention_mask": torch.ones(3, 10), ... }, batch_size=3) >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}] - >>> tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) + >>> meta = tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) + >>> print(meta.fields) # ['input_ids', 'attention_mask'] """ if fields is None and tags is None: @@ -488,13 +514,59 @@ def kv_batch_put( # 3. put data if fields is not None: - tq_client.put(fields, batch_meta) + # After put, batch_meta.field_names() will include the new fields written by user + batch_meta = tq_client.put(fields, batch_meta) + fields_to_return = batch_meta.field_names() else: # directly update custom_meta (tags) to controller tq_client.set_custom_meta(batch_meta) + fields_to_return = batch_meta.field_names() if batch_meta.field_names() else None + + return KVBatchMeta( + keys=keys, + tags=batch_meta.custom_meta, + partition_id=partition_id, + fields=fields_to_return, + extra_info=batch_meta.extra_info, + ) + + +def kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: + """Get data from TransferQueue using KVBatchMeta. + This is a convenience method for retrieving data using KVBatchMeta returned + from a previous put operation. It extracts the keys, partition_id, and fields + from the metadata to fetch the corresponding data. -def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None) -> TensorDict: + Args: + meta: KVBatchMeta object returned from a previous put operation (e.g., kv_put, + kv_batch_put). It contains keys, partition_id, and fields information. + + Returns: + TensorDict with the requested data + + Raises: + RuntimeError: If keys or partition are not found + RuntimeError: If empty fields exist in any key (sample) + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> # First put some data + >>> meta = tq.kv_batch_put( + ... keys=["sample_1", "sample_2"], + ... partition_id="train", + ... fields={"input_ids": torch.randn(2, 10)}, + ... ) + >>> # Then retrieve it using the returned metadata + >>> data = tq.kv_batch_get_by_meta(meta) + """ + return kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=meta.fields) + + +def kv_batch_get( + keys: list[str] | str, partition_id: str, select_fields: Optional[list[str] | str] = None +) -> TensorDict: """Get data from TransferQueue using user-specified keys. This is a convenience method for retrieving data using keys instead of indexes. @@ -502,7 +574,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list Args: keys: Single key or list of keys to retrieve partition_id: Partition containing the keys - fields: Optional field(s) to retrieve. If None, retrieves all fields + select_fields: Optional field(s) to retrieve. If None, retrieves all fields Returns: TensorDict with the requested data @@ -520,7 +592,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list >>> data = tq.kv_batch_get( ... keys=["sample_1", "sample_2"], ... partition_id="train", - ... fields="input_ids" + ... select_fields="input_ids" ... ) """ tq_client = _maybe_create_transferqueue_client() @@ -530,10 +602,10 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list if batch_meta.size == 0: raise RuntimeError("keys or partition were not found!") - if fields is not None: - if isinstance(fields, str): - fields = [fields] - batch_meta = batch_meta.select_fields(fields) + if select_fields is not None: + if isinstance(select_fields, str): + fields_to_fetch = [select_fields] + batch_meta = batch_meta.select_fields(fields_to_fetch) if not batch_meta.is_ready: raise RuntimeError("Some fields are not ready in all the requested keys!") @@ -618,7 +690,7 @@ async def async_kv_put( partition_id: str, fields: Optional[TensorDict | dict[str, Any]] = None, tag: Optional[dict[str, Any]] = None, -) -> None: +) -> KVBatchMeta: """Asynchronously put a single key-value pair to TransferQueue. This is a convenience method for putting data using a user-specified key @@ -631,8 +703,14 @@ async def async_kv_put( fields: Data fields to store. Can be a TensorDict or a dict of tensors. Each key in `fields` will be treated as a column for the data sample. If dict is provided, tensors will be unsqueezed to add batch dimension. + If not provided, will only update the newly given tag to the key. tag: Optional metadata tag to associate with the key + Returns: + KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. + The `fields` attribute includes all fields stored for this sample, + including any new fields written by this put operation. + Raises: ValueError: If neither fields nor tag is provided ValueError: If nested tensors are provided (use kv_batch_put instead) @@ -643,12 +721,13 @@ async def async_kv_put( >>> import torch >>> tq.init() >>> # Put with both fields and tag - >>> await tq.async_kv_put( + >>> meta = await tq.async_kv_put( ... key="sample_1", ... partition_id="train", ... fields={"input_ids": torch.tensor([1, 2, 3])}, ... tag={"score": 0.95} - ... )) + ... ) + >>> print(meta.fields) # ['input_ids'] """ if fields is None and tag is None: @@ -683,15 +762,26 @@ async def async_kv_put( raise ValueError("field can only be dict or TensorDict") # custom_meta (tag) will be put to controller through the put process + # After put, batch_meta.field_names() will include the new fields written by user await tq_client.async_put(fields, batch_meta) + fields_to_return = batch_meta.field_names() else: # directly update custom_meta (tag) to controller await tq_client.async_set_custom_meta(batch_meta) + fields_to_return = batch_meta.field_names() if batch_meta.field_names() else None + + return KVBatchMeta( + keys=[key], + tags=batch_meta.custom_meta, + partition_id=partition_id, + fields=fields_to_return, + extra_info=batch_meta.extra_info, + ) async def async_kv_batch_put( keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None -) -> None: +) -> KVBatchMeta: """Asynchronously put multiple key-value pairs to TransferQueue in batch. This method stores multiple key-value pairs in a single operation, which is more @@ -700,9 +790,15 @@ async def async_kv_batch_put( Args: keys: List of user-specified keys for the data partition_id: Logical partition to store the data in - fields: TensorDict containing data for all keys. Must have batch_size == len(keys) + fields: TensorDict containing data for all keys. Must have batch_size == len(keys). + If not provided, will only update the newly given tags to the keys. tags: List of metadata tags, one for each key + Returns: + KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. + The `fields` attribute includes all fields stored for these samples, + including any new fields written by this put operation. + Raises: ValueError: If neither `fields` nor `tags` is provided ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict @@ -717,7 +813,8 @@ async def async_kv_batch_put( ... "attention_mask": torch.ones(3, 10), ... }, batch_size=3) >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}] - >>> await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) + >>> meta = await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags) + >>> print(meta.fields) # ['input_ids', 'attention_mask'] """ if fields is None and tags is None: @@ -747,14 +844,58 @@ async def async_kv_batch_put( # 3. put data if fields is not None: - await tq_client.async_put(fields, batch_meta) + # After put, batch_meta.field_names() will include the new fields written by user + batch_meta = await tq_client.async_put(fields, batch_meta) + fields_to_return = batch_meta.field_names() else: # directly update custom_meta (tags) to controller await tq_client.async_set_custom_meta(batch_meta) + fields_to_return = batch_meta.field_names() if batch_meta.field_names() else None + + return KVBatchMeta( + keys=keys, + tags=batch_meta.custom_meta, + partition_id=partition_id, + fields=fields_to_return, + extra_info=batch_meta.extra_info, + ) + + +async def async_kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: + """Asynchronously get data from TransferQueue using KVBatchMeta. + + This is a convenience method for retrieving data using KVBatchMeta returned + from a previous put operation. It extracts the keys, partition_id, and fields + from the metadata to fetch the corresponding data. + + Args: + meta: KVBatchMeta object returned from a previous put operation (e.g., async_kv_put, + async_kv_batch_put). It contains keys, partition_id, and fields information. + + Returns: + TensorDict with the requested data + + Raises: + RuntimeError: If keys or partition are not found + RuntimeError: If empty fields exist in any key (sample) + + Example: + >>> import transfer_queue as tq + >>> tq.init() + >>> # First put some data + >>> meta = await tq.async_kv_batch_put( + ... keys=["sample_1", "sample_2"], + ... partition_id="train", + ... fields={"input_ids": torch.randn(2, 10)}, + ... ) + >>> # Then retrieve it using the returned metadata + >>> data = await tq.async_kv_batch_get_by_meta(meta) + """ + return await async_kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=meta.fields) async def async_kv_batch_get( - keys: list[str] | str, partition_id: str, fields: Optional[list[str] | str] = None + keys: list[str] | str, partition_id: str, select_fields: Optional[list[str] | str] = None ) -> TensorDict: """Asynchronously get data from TransferQueue using user-specified keys. @@ -763,7 +904,7 @@ async def async_kv_batch_get( Args: keys: Single key or list of keys to retrieve partition_id: Partition containing the keys - fields: Optional field(s) to retrieve. If None, retrieves all fields + select_fields: Optional field(s) to retrieve. If None, retrieves all fields Returns: TensorDict with the requested data @@ -781,7 +922,7 @@ async def async_kv_batch_get( >>> data = await tq.async_kv_batch_get( ... keys=["sample_1", "sample_2"], ... partition_id="train", - ... fields="input_ids" + ... select_fields="input_ids" ... ) """ tq_client = _maybe_create_transferqueue_client() @@ -791,10 +932,10 @@ async def async_kv_batch_get( if batch_meta.size == 0: raise RuntimeError("keys or partition were not found!") - if fields is not None: - if isinstance(fields, str): - fields = [fields] - batch_meta = batch_meta.select_fields(fields) + if select_fields is not None: + if isinstance(select_fields, str): + fields_to_fetch = [select_fields] + batch_meta = batch_meta.select_fields(fields_to_fetch) if not batch_meta.is_ready: raise RuntimeError("Some fields are not ready in all the requested keys!") diff --git a/tutorial/02_kv_interface.py b/tutorial/02_kv_interface.py index 376ebfb5..778739c2 100644 --- a/tutorial/02_kv_interface.py +++ b/tutorial/02_kv_interface.py @@ -167,7 +167,7 @@ def demonstrate_kv_api(): print(" Fetching only 'input_ids' to save bandwidth (ignoring 'attention_mask' and 'response').") all_keys = list(partition_info[partition_id].keys()) - retrieved_input_ids = tq.kv_batch_get(keys=all_keys, partition_id=partition_id, fields="input_ids") + retrieved_input_ids = tq.kv_batch_get(keys=all_keys, partition_id=partition_id, select_fields="input_ids") print(f" ✓ Successfully retrieved only {list(retrieved_input_ids.keys())} field for all samples.") # # Step 7: Retrieve all fields using kv_batch_get From c00e878b81cdc02b9b361836e2c5ccdb73150a67 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 09:34:28 +0800 Subject: [PATCH 02/11] add UT Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 92 ++++++++++++++++++++++++++++++ transfer_queue/interface.py | 22 ++++--- 2 files changed, 105 insertions(+), 9 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 7a04ba71..386dc56b 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -301,6 +301,58 @@ def test_kv_put_partial_update(self, controller): tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_put_returns_cumulative_fields(self, controller): + """Test that kv_put returns KVBatchMeta with cumulative fields (previous + new).""" + partition_id = "test_partition" + key = "sample_cumulative" + + # First put: only input_ids + first_data = TensorDict( + { + "input_ids": torch.tensor([[1, 2, 3]]), + }, + batch_size=1, + ) + first_meta = tq.kv_put(key=key, partition_id=partition_id, fields=first_data, tag={"step": 1}) + + # Verify first meta contains only input_ids + assert first_meta.fields is not None + assert "input_ids" in first_meta.fields + assert len(first_meta.fields) == 1 + + # Second put: add attention_mask + second_data = TensorDict( + { + "attention_mask": torch.tensor([[1, 1, 1]]), + }, + batch_size=1, + ) + second_meta = tq.kv_put(key=key, partition_id=partition_id, fields=second_data, tag={"step": 2}) + + # Verify second meta contains BOTH previous (input_ids) and new (attention_mask) fields + assert second_meta.fields is not None + assert "input_ids" in second_meta.fields, "Previous field 'input_ids' should be in returned fields" + assert "attention_mask" in second_meta.fields, "New field 'attention_mask' should be in returned fields" + assert len(second_meta.fields) == 2, f"Expected 2 fields, got {second_meta.fields}" + + # Third put: add response field + third_data = TensorDict( + { + "response": torch.tensor([[10, 20]]), + }, + batch_size=1, + ) + third_meta = tq.kv_put(key=key, partition_id=partition_id, fields=third_data, tag={"step": 3}) + + # Verify third meta contains ALL three fields + assert third_meta.fields is not None + assert "input_ids" in third_meta.fields, "Previous field 'input_ids' should still be present" + assert "attention_mask" in third_meta.fields, "Previous field 'attention_mask' should still be present" + assert "response" in third_meta.fields, "New field 'response' should be present" + assert len(third_meta.fields) == 3, f"Expected 3 fields, got {third_meta.fields}" + + tq.kv_clear(keys=key, partition_id=partition_id) + class TestKVBatchPutE2E: """End-to-end tests for kv_batch_put functionality.""" @@ -393,6 +445,46 @@ def test_kv_batch_put_partial_update(self, controller): tq.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_batch_put_returns_cumulative_fields(self, controller): + """Test that kv_batch_put returns KVBatchMeta with cumulative fields (previous + new).""" + partition_id = "test_partition" + keys = ["batch_cumulative_0", "batch_cumulative_1"] + + # First batch put: only input_ids + first_data = TensorDict( + { + "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]), + }, + batch_size=2, + ) + first_meta = tq.kv_batch_put( + keys=keys, partition_id=partition_id, fields=first_data, tags=[{"step": 1}, {"step": 1}] + ) + + # Verify first meta contains only input_ids + assert first_meta.fields is not None + assert "input_ids" in first_meta.fields + assert len(first_meta.fields) == 1 + + # Second batch put: add attention_mask for both keys + second_data = TensorDict( + { + "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]), + }, + batch_size=2, + ) + second_meta = tq.kv_batch_put( + keys=keys, partition_id=partition_id, fields=second_data, tags=[{"step": 2}, {"step": 2}] + ) + + # Verify second meta contains BOTH previous (input_ids) and new (attention_mask) fields + assert second_meta.fields is not None + assert "input_ids" in second_meta.fields, "Previous field 'input_ids' should be in returned fields" + assert "attention_mask" in second_meta.fields, "New field 'attention_mask' should be in returned fields" + assert len(second_meta.fields) == 2, f"Expected 2 fields, got {second_meta.fields}" + + tq.kv_clear(keys=keys, partition_id=partition_id) + class TestKVGetE2E: """End-to-end tests for kv_batch_get functionality.""" diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index ce0fdb92..ce76305c 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -28,9 +28,9 @@ from tensordict import TensorDict from tensordict.tensorclass import NonTensorStack -from transfer_queue import KVBatchMeta from transfer_queue.client import TransferQueueClient from transfer_queue.controller import TransferQueueController +from transfer_queue.metadata import KVBatchMeta from transfer_queue.sampler import * # noqa: F401 from transfer_queue.sampler import BaseSampler from transfer_queue.storage.simple_backend import SimpleStorageUnit @@ -433,11 +433,11 @@ def kv_put( # custom_meta (tag) will be put to controller through the internal put process # After put, batch_meta.field_names() will include the new fields written by user batch_meta = tq_client.put(fields, batch_meta) - fields_to_return = batch_meta.field_names() + fields_to_return = batch_meta.field_names else: # directly update custom_meta (tag) to controller tq_client.set_custom_meta(batch_meta) - fields_to_return = batch_meta.field_names() if batch_meta.field_names() else None + fields_to_return = batch_meta.field_names if batch_meta.field_names else None return KVBatchMeta( keys=[key], @@ -516,11 +516,11 @@ def kv_batch_put( if fields is not None: # After put, batch_meta.field_names() will include the new fields written by user batch_meta = tq_client.put(fields, batch_meta) - fields_to_return = batch_meta.field_names() + fields_to_return = batch_meta.field_names else: # directly update custom_meta (tags) to controller tq_client.set_custom_meta(batch_meta) - fields_to_return = batch_meta.field_names() if batch_meta.field_names() else None + fields_to_return = batch_meta.field_names if batch_meta.field_names else None return KVBatchMeta( keys=keys, @@ -605,6 +605,8 @@ def kv_batch_get( if select_fields is not None: if isinstance(select_fields, str): fields_to_fetch = [select_fields] + else: + fields_to_fetch = select_fields batch_meta = batch_meta.select_fields(fields_to_fetch) if not batch_meta.is_ready: @@ -764,11 +766,11 @@ async def async_kv_put( # custom_meta (tag) will be put to controller through the put process # After put, batch_meta.field_names() will include the new fields written by user await tq_client.async_put(fields, batch_meta) - fields_to_return = batch_meta.field_names() + fields_to_return = batch_meta.field_names else: # directly update custom_meta (tag) to controller await tq_client.async_set_custom_meta(batch_meta) - fields_to_return = batch_meta.field_names() if batch_meta.field_names() else None + fields_to_return = batch_meta.field_names if batch_meta.field_names else None return KVBatchMeta( keys=[key], @@ -846,11 +848,11 @@ async def async_kv_batch_put( if fields is not None: # After put, batch_meta.field_names() will include the new fields written by user batch_meta = await tq_client.async_put(fields, batch_meta) - fields_to_return = batch_meta.field_names() + fields_to_return = batch_meta.field_names else: # directly update custom_meta (tags) to controller await tq_client.async_set_custom_meta(batch_meta) - fields_to_return = batch_meta.field_names() if batch_meta.field_names() else None + fields_to_return = batch_meta.field_names if batch_meta.field_names else None return KVBatchMeta( keys=keys, @@ -935,6 +937,8 @@ async def async_kv_batch_get( if select_fields is not None: if isinstance(select_fields, str): fields_to_fetch = [select_fields] + else: + fields_to_fetch = select_fields batch_meta = batch_meta.select_fields(fields_to_fetch) if not batch_meta.is_ready: From 3f725c342694297eb4e167bf0f8bfac0a26b225b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 09:51:02 +0800 Subject: [PATCH 03/11] provide async kv API test cover Signed-off-by: 0oshowero0 # Conflicts: # tests/e2e/test_kv_interface_e2e.py --- tests/e2e/test_kv_interface_e2e.py | 207 +++++++++++++++++------------ 1 file changed, 119 insertions(+), 88 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 386dc56b..c6ebfa9c 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -20,6 +20,7 @@ 2. Verifying correctness by calling TransferQueueController's internal methods directly """ +import asyncio import os import sys from pathlib import Path @@ -36,6 +37,34 @@ import transfer_queue as tq # noqa: E402 + +class TQAPIWrapper: + """Wrapper that routes kv_* calls to sync or async interface based on use_async flag.""" + + def __init__(self, use_async: bool): + self.use_async = use_async + + def __getattr__(self, name): + if name.startswith("kv_"): + if self.use_async: + async_func = getattr(tq, f"async_{name}") + return lambda *args, **kwargs: asyncio.run(async_func(*args, **kwargs)) + else: + return getattr(tq, name) + # For non-kv_ attributes (init, close), pass through directly + return getattr(tq, name) + + +@pytest.fixture(params=[False, True], ids=["sync", "async"]) +def tq_api(request): + """Returns a unified TQ API handle that routes to sync or async interface. + + When use_async=False (sync mode), calls tq.kv_* directly. + When use_async=True (async mode), calls tq.async_kv_* via asyncio.run(). + """ + return TQAPIWrapper(use_async=request.param) + + # Configure Ray for tests os.environ["RAY_DEDUP_LOGS"] = "0" @@ -152,25 +181,25 @@ def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""): class TestKVPutE2E: """End-to-end tests for kv_put functionality.""" - def test_kv_put_with_dict_fields(self, controller): + def test_kv_put_with_dict_fields(self, controller, tq_api): """Test kv_put with dict fields (auto-converted to TensorDict).""" partition_id = "test_partition" key = "sample_0" # Put with dict fields - will be auto-unsqueezed - tq.kv_put( + tq_api.kv_put( key=key, partition_id=partition_id, fields={"data": torch.tensor([1, 2, 3, 4])}, tag={"type": "dict_test"} ) # Verify - retrieved data will have batch dimension - retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + retrieved = tq_api.kv_batch_get(keys=key, partition_id=partition_id) expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed assert_tensor_equal(retrieved["data"], expected) # delete the key (MooncakeStore does not support updating existing key, so we need to clear it before next test) - tq.kv_clear(keys=key, partition_id=partition_id) + tq_api.kv_clear(keys=key, partition_id=partition_id) - def test_kv_put_with_tensordict_fields(self, controller): + def test_kv_put_with_tensordict_fields(self, controller, tq_api): """Test kv_put with tensordict fields.""" partition_id = "test_partition" key = "sample_1" @@ -182,16 +211,16 @@ def test_kv_put_with_tensordict_fields(self, controller): batch_size=1, ) # Put with dict fields - will be auto-unsqueezed - tq.kv_put(key=key, partition_id=partition_id, fields=tensordict_data, tag={"type": "tensordict_test"}) + tq_api.kv_put(key=key, partition_id=partition_id, fields=tensordict_data, tag={"type": "tensordict_test"}) # Verify - retrieved data will have batch dimension - retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + retrieved = tq_api.kv_batch_get(keys=key, partition_id=partition_id) expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed assert_tensor_equal(retrieved["input_ids"], expected) - tq.kv_clear(keys=key, partition_id=partition_id) + tq_api.kv_clear(keys=key, partition_id=partition_id) - def test_kv_put_single_sample_with_fields_and_tag(self, controller): + def test_kv_put_single_sample_with_fields_and_tag(self, controller, tq_api): """Test putting a single sample with fields and tag.""" partition_id = "test_partition" key = "sample_2" @@ -201,7 +230,7 @@ def test_kv_put_single_sample_with_fields_and_tag(self, controller): tag = {"global_steps": 0, "status": "running"} # Put data using interface - tq.kv_put( + tq_api.kv_put( key=key, partition_id=partition_id, fields={"input_ids": input_ids, "attention_mask": attention_mask}, @@ -229,7 +258,7 @@ def test_kv_put_single_sample_with_fields_and_tag(self, controller): assert partition.production_status[global_idx, input_ids_col_idx] == 1, "input_ids should be marked as produced" # Retrieve and verify data via kv_batch_get - tensors will have batch dimension - retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + retrieved = tq_api.kv_batch_get(keys=key, partition_id=partition_id) assert "input_ids" in retrieved.keys() assert "attention_mask" in retrieved.keys() # After unsqueeze, tensors become 2D [batch_size=1, original_size] @@ -238,20 +267,20 @@ def test_kv_put_single_sample_with_fields_and_tag(self, controller): assert_tensor_equal(retrieved["input_ids"], expected_input_ids) assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask) - tq.kv_clear(keys=key, partition_id=partition_id) + tq_api.kv_clear(keys=key, partition_id=partition_id) - def test_kv_put_update_tag_only(self, controller): + def test_kv_put_update_tag_only(self, controller, tq_api): """Test updating only tag without providing fields.""" partition_id = "test_partition" key = "sample_3" # First put with fields - use TensorDict as another example single_data = TensorDict({"value": torch.tensor([[10]])}, batch_size=1) - tq.kv_put(key=key, partition_id=partition_id, fields=single_data, tag={"version": 1}) + tq_api.kv_put(key=key, partition_id=partition_id, fields=single_data, tag={"version": 1}) # Update only tag new_tag = {"version": 2, "status": "updated"} - tq.kv_put(key=key, partition_id=partition_id, fields=None, tag=new_tag) + tq_api.kv_put(key=key, partition_id=partition_id, fields=None, tag=new_tag) # Verify via controller partition = get_controller_partition(controller, partition_id) @@ -260,12 +289,12 @@ def test_kv_put_update_tag_only(self, controller): assert partition.custom_meta[global_idx]["status"] == "updated" # Data should still be accessible - retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + retrieved = tq_api.kv_batch_get(keys=key, partition_id=partition_id) assert_tensor_equal(retrieved["value"], torch.tensor([[10]])) - tq.kv_clear(keys=key, partition_id=partition_id) + tq_api.kv_clear(keys=key, partition_id=partition_id) - def test_kv_put_partial_update(self, controller): + def test_kv_put_partial_update(self, controller, tq_api): """Test adding new fields to existing sample.""" partition_id = "test_partition" key = "sample_4" @@ -277,7 +306,7 @@ def test_kv_put_partial_update(self, controller): }, batch_size=1, ) - tq.kv_put(key=key, partition_id=partition_id, fields=initial_data, tag={"v": 1}) + tq_api.kv_put(key=key, partition_id=partition_id, fields=initial_data, tag={"v": 1}) # Add new fields to subset of keys new_fields = TensorDict( @@ -286,7 +315,7 @@ def test_kv_put_partial_update(self, controller): }, batch_size=1, ) - tq.kv_put(key=key, partition_id=partition_id, fields=new_fields, tag={"v": 2}) + tq_api.kv_put(key=key, partition_id=partition_id, fields=new_fields, tag={"v": 2}) # Verify via controller - only keys[1] should have response field partition = get_controller_partition(controller, partition_id) @@ -299,9 +328,9 @@ def test_kv_put_partial_update(self, controller): # key should have response marked as produced assert partition.production_status[global_idx, response_col_idx] == 1, "Key should have response" - tq.kv_clear(keys=key, partition_id=partition_id) + tq_api.kv_clear(keys=key, partition_id=partition_id) - def test_kv_put_returns_cumulative_fields(self, controller): + def test_kv_put_returns_cumulative_fields(self, controller, tq_api): """Test that kv_put returns KVBatchMeta with cumulative fields (previous + new).""" partition_id = "test_partition" key = "sample_cumulative" @@ -313,7 +342,7 @@ def test_kv_put_returns_cumulative_fields(self, controller): }, batch_size=1, ) - first_meta = tq.kv_put(key=key, partition_id=partition_id, fields=first_data, tag={"step": 1}) + first_meta = tq_api.kv_put(key=key, partition_id=partition_id, fields=first_data, tag={"step": 1}) # Verify first meta contains only input_ids assert first_meta.fields is not None @@ -327,7 +356,7 @@ def test_kv_put_returns_cumulative_fields(self, controller): }, batch_size=1, ) - second_meta = tq.kv_put(key=key, partition_id=partition_id, fields=second_data, tag={"step": 2}) + second_meta = tq_api.kv_put(key=key, partition_id=partition_id, fields=second_data, tag={"step": 2}) # Verify second meta contains BOTH previous (input_ids) and new (attention_mask) fields assert second_meta.fields is not None @@ -342,7 +371,7 @@ def test_kv_put_returns_cumulative_fields(self, controller): }, batch_size=1, ) - third_meta = tq.kv_put(key=key, partition_id=partition_id, fields=third_data, tag={"step": 3}) + third_meta = tq_api.kv_put(key=key, partition_id=partition_id, fields=third_data, tag={"step": 3}) # Verify third meta contains ALL three fields assert third_meta.fields is not None @@ -351,13 +380,13 @@ def test_kv_put_returns_cumulative_fields(self, controller): assert "response" in third_meta.fields, "New field 'response' should be present" assert len(third_meta.fields) == 3, f"Expected 3 fields, got {third_meta.fields}" - tq.kv_clear(keys=key, partition_id=partition_id) + tq_api.kv_clear(keys=key, partition_id=partition_id) class TestKVBatchPutE2E: """End-to-end tests for kv_batch_put functionality.""" - def test_kv_batch_put_multiple_samples(self, controller): + def test_kv_batch_put_multiple_samples(self, controller, tq_api): """Test batch putting multiple samples.""" partition_id = "test_partition" keys = ["batch_0", "batch_1", "batch_2", "batch_3"] @@ -382,7 +411,7 @@ def test_kv_batch_put_multiple_samples(self, controller): tags = [{"idx": i, "batch": True} for i in range(4)] # Batch put using interface - tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) + tq_api.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) # Verify via controller partition = get_controller_partition(controller, partition_id) @@ -399,13 +428,13 @@ def test_kv_batch_put_multiple_samples(self, controller): assert partition.custom_meta[global_idx]["batch"] is True # Verify all data via kv_batch_get - retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) + retrieved = tq_api.kv_batch_get(keys=keys, partition_id=partition_id) assert_tensor_equal(retrieved["input_ids"], batch_input_ids) assert_tensor_equal(retrieved["attention_mask"], batch_attention_mask) - tq.kv_clear(keys=keys, partition_id=partition_id) + tq_api.kv_clear(keys=keys, partition_id=partition_id) - def test_kv_batch_put_partial_update(self, controller): + def test_kv_batch_put_partial_update(self, controller, tq_api): """Test adding new fields to existing samples.""" partition_id = "test_partition" keys = ["partial_0", "partial_1"] @@ -417,7 +446,7 @@ def test_kv_batch_put_partial_update(self, controller): }, batch_size=2, ) - tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=initial_data, tags=[{"v": 1}, {"v": 1}]) + tq_api.kv_batch_put(keys=keys, partition_id=partition_id, fields=initial_data, tags=[{"v": 1}, {"v": 1}]) # Add new fields to subset of keys new_fields = TensorDict( @@ -426,7 +455,7 @@ def test_kv_batch_put_partial_update(self, controller): }, batch_size=1, ) - tq.kv_batch_put(keys=[keys[1]], partition_id=partition_id, fields=new_fields, tags=[{"v": 2}]) + tq_api.kv_batch_put(keys=[keys[1]], partition_id=partition_id, fields=new_fields, tags=[{"v": 2}]) # Verify via controller - only keys[1] should have response field partition = get_controller_partition(controller, partition_id) @@ -443,9 +472,9 @@ def test_kv_batch_put_partial_update(self, controller): # keys[1] should have response marked as produced assert partition.production_status[global_idx_1, response_col_idx] == 1, "Keys[1] should have response" - tq.kv_clear(keys=keys, partition_id=partition_id) + tq_api.kv_clear(keys=keys, partition_id=partition_id) - def test_kv_batch_put_returns_cumulative_fields(self, controller): + def test_kv_batch_put_returns_cumulative_fields(self, controller, tq_api): """Test that kv_batch_put returns KVBatchMeta with cumulative fields (previous + new).""" partition_id = "test_partition" keys = ["batch_cumulative_0", "batch_cumulative_1"] @@ -457,7 +486,7 @@ def test_kv_batch_put_returns_cumulative_fields(self, controller): }, batch_size=2, ) - first_meta = tq.kv_batch_put( + first_meta = tq_api.kv_batch_put( keys=keys, partition_id=partition_id, fields=first_data, tags=[{"step": 1}, {"step": 1}] ) @@ -473,7 +502,7 @@ def test_kv_batch_put_returns_cumulative_fields(self, controller): }, batch_size=2, ) - second_meta = tq.kv_batch_put( + second_meta = tq_api.kv_batch_put( keys=keys, partition_id=partition_id, fields=second_data, tags=[{"step": 2}, {"step": 2}] ) @@ -483,13 +512,13 @@ def test_kv_batch_put_returns_cumulative_fields(self, controller): assert "attention_mask" in second_meta.fields, "New field 'attention_mask' should be in returned fields" assert len(second_meta.fields) == 2, f"Expected 2 fields, got {second_meta.fields}" - tq.kv_clear(keys=keys, partition_id=partition_id) + tq_api.kv_clear(keys=keys, partition_id=partition_id) class TestKVGetE2E: """End-to-end tests for kv_batch_get functionality.""" - def test_kv_batch_get_single_key(self, controller): + def test_kv_batch_get_single_key(self, controller, tq_api): """Test getting data for a single key.""" partition_id = "test_partition" key = "get_single" @@ -497,28 +526,28 @@ def test_kv_batch_get_single_key(self, controller): expected_data = torch.tensor([[100, 200, 300]]) fields = TensorDict({"data": expected_data}, batch_size=1) - tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) + tq_api.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) - retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) + retrieved = tq_api.kv_batch_get(keys=key, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) - tq.kv_clear(keys=key, partition_id=partition_id) + tq_api.kv_clear(keys=key, partition_id=partition_id) - def test_kv_batch_get_multiple_keys(self, controller): + def test_kv_batch_get_multiple_keys(self, controller, tq_api): """Test getting data for multiple keys.""" partition_id = "test_partition" keys = ["get_multi_0", "get_multi_1", "get_multi_2"] expected_data = torch.tensor([[1, 2], [3, 4], [5, 6]]) fields = TensorDict({"data": expected_data}, batch_size=3) - tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) + tq_api.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) - retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) + retrieved = tq_api.kv_batch_get(keys=keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) - tq.kv_clear(keys=keys, partition_id=partition_id) + tq_api.kv_clear(keys=keys, partition_id=partition_id) - def test_kv_batch_get_partial_keys(self, controller): + def test_kv_batch_get_partial_keys(self, controller, tq_api): """Test getting data for partial keys.""" partition_id = "test_partition" keys = ["get_multi_3", "get_multi_4", "get_multi_5"] @@ -535,9 +564,9 @@ def test_kv_batch_get_partial_keys(self, controller): fields = TensorDict( {"data": input_data, "nested_data": nested_data, "three_d_nested_data": three_d_nested_data}, batch_size=3 ) - tq.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) + tq_api.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}, {}]) - retrieved = tq.kv_batch_get(keys=partial_keys, partition_id=partition_id) + retrieved = tq_api.kv_batch_get(keys=partial_keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) for actual, expected in zip(retrieved["nested_data"], expected_nested_data, strict=True): @@ -546,9 +575,9 @@ def test_kv_batch_get_partial_keys(self, controller): for actual, expected in zip(retrieved["three_d_nested_data"], expected_three_d_nested_data, strict=True): assert_tensor_equal(actual, expected) - tq.kv_clear(keys=keys, partition_id=partition_id) + tq_api.kv_clear(keys=keys, partition_id=partition_id) - def test_kv_batch_get_partial_fields(self, controller): + def test_kv_batch_get_partial_fields(self, controller, tq_api): """Test getting only partial fields.""" partition_id = "test_partition" key = "get_fields" @@ -562,32 +591,32 @@ def test_kv_batch_get_partial_fields(self, controller): ) # Put all fields - tq.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) + tq_api.kv_put(key=key, partition_id=partition_id, fields=fields, tag=None) # Get only input_ids - retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, select_fields="input_ids") + retrieved = tq_api.kv_batch_get(keys=key, partition_id=partition_id, select_fields="input_ids") assert "input_ids" in retrieved.keys() assert "attention_mask" not in retrieved.keys() assert "response" not in retrieved.keys() assert_tensor_equal(retrieved["input_ids"], input_ids) # Get multiple specific fields - retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id, select_fields=["input_ids", "response"]) + retrieved = tq_api.kv_batch_get(keys=key, partition_id=partition_id, select_fields=["input_ids", "response"]) assert "input_ids" in retrieved.keys() assert "response" in retrieved.keys() assert "attention_mask" not in retrieved.keys() assert_tensor_equal(retrieved["input_ids"], input_ids) assert_tensor_equal(retrieved["response"], response) - tq.kv_clear(keys=key, partition_id=partition_id) + tq_api.kv_clear(keys=key, partition_id=partition_id) - def test_kv_batch_get_nonexistent_key(self, controller): + def test_kv_batch_get_nonexistent_key(self, controller, tq_api): """Test that getting data for non-existent key returns empty result.""" partition_id = "test_partition" # Try to get data for a key that doesn't exist - should return empty or raise error try: - retrieved = tq.kv_batch_get(keys="nonexistent_key", partition_id=partition_id) + retrieved = tq_api.kv_batch_get(keys="nonexistent_key", partition_id=partition_id) # If it returns, it should be empty assert retrieved.batch_size[0] == 0 except RuntimeError as e: @@ -598,16 +627,16 @@ def test_kv_batch_get_nonexistent_key(self, controller): class TestKVListE2E: """End-to-end tests for kv_list functionality.""" - def test_kv_list_single_partition(self, controller): + def test_kv_list_single_partition(self, controller, tq_api): """Test listing all keys and tags in single partition.""" partition_id = "test_partition" keys = ["list_0", "list_1", "list_2"] for i, key in enumerate(keys): - tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag={"id": i}) + tq_api.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag={"id": i}) # List all keys - partition_info = tq.kv_list(partition_id=partition_id) + partition_info = tq_api.kv_list(partition_id=partition_id) assert len(partition_info.keys()) == 1 assert "test_partition" in partition_info.keys() @@ -619,9 +648,9 @@ def test_kv_list_single_partition(self, controller): for i, (key, tag) in enumerate(partition_info["test_partition"].items()): assert tag["id"] == i - tq.kv_clear(keys=keys, partition_id=partition_id) + tq_api.kv_clear(keys=keys, partition_id=partition_id) - def test_kv_list_all_partitions(self, controller): + def test_kv_list_all_partitions(self, controller, tq_api): """Test listing keys and tags in all partitions.""" partition_id = ["test_partition0", "test_partition1", "test_partition2"] @@ -638,18 +667,18 @@ def test_kv_list_all_partitions(self, controller): tags_partition2 = [{"id": i + 6} for i in range(4)] # Put to TQ - tq.kv_batch_put( + tq_api.kv_batch_put( keys=keys_partition0, partition_id=partition_id[0], fields=fields_partition0, tags=tags_partition0 ) - tq.kv_batch_put( + tq_api.kv_batch_put( keys=keys_partition1, partition_id=partition_id[1], fields=fields_partition1, tags=tags_partition1 ) - tq.kv_batch_put( + tq_api.kv_batch_put( keys=keys_partition2, partition_id=partition_id[2], fields=fields_partition2, tags=tags_partition2 ) # List all keys - partition_info = tq.kv_list() + partition_info = tq_api.kv_list() # Verify all partitions are exist assert len(partition_info.keys()) == 3 @@ -677,15 +706,15 @@ def test_kv_list_all_partitions(self, controller): for i, (key, tag) in enumerate(partition_info["test_partition2"].items()): assert tag["id"] == i + 6 - tq.kv_clear(keys=keys_partition0, partition_id=partition_id[0]) - tq.kv_clear(keys=keys_partition1, partition_id=partition_id[1]) - tq.kv_clear(keys=keys_partition2, partition_id=partition_id[2]) + tq_api.kv_clear(keys=keys_partition0, partition_id=partition_id[0]) + tq_api.kv_clear(keys=keys_partition1, partition_id=partition_id[1]) + tq_api.kv_clear(keys=keys_partition2, partition_id=partition_id[2]) - def test_kv_list_empty_partition(self): + def test_kv_list_empty_partition(self, tq_api): """Test listing empty partition.""" partition_id = "test_partition_empty" - partition_info = tq.kv_list(partition_id=partition_id) + partition_info = tq_api.kv_list(partition_id=partition_id) assert len(partition_info) == 0 @@ -693,20 +722,22 @@ def test_kv_list_empty_partition(self): class TestKVClearE2E: """End-to-end tests for kv_clear functionality.""" - def test_kv_clear_single_key(self, controller): + def test_kv_clear_single_key(self, controller, tq_api): """Test clearing a single key.""" partition_id = "test_partition" key = "clear_single" other_key = "clear_other" - tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[1]])}, tag={"id": "single"}) - tq.kv_put(key=other_key, partition_id=partition_id, fields={"data": torch.tensor([[2]])}, tag={"id": "other"}) + tq_api.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[1]])}, tag={"id": "single"}) + tq_api.kv_put( + key=other_key, partition_id=partition_id, fields={"data": torch.tensor([[2]])}, tag={"id": "other"} + ) # Clear single key - tq.kv_clear(keys=key, partition_id=partition_id) + tq_api.kv_clear(keys=key, partition_id=partition_id) # Verify via kv_list - partition_info = tq.kv_list(partition_id=partition_id) + partition_info = tq_api.kv_list(partition_id=partition_id) assert key not in partition_info[partition_id] assert other_key in partition_info[partition_id] @@ -715,46 +746,46 @@ def test_kv_clear_single_key(self, controller): assert key not in partition.keys_mapping assert other_key in partition.keys_mapping - tq.kv_clear(keys=other_key, partition_id=partition_id) + tq_api.kv_clear(keys=other_key, partition_id=partition_id) - def test_kv_clear_multiple_keys(self, controller): + def test_kv_clear_multiple_keys(self, controller, tq_api): """Test clearing multiple keys.""" partition_id = "test_partition" keys = ["clear_multi_0", "clear_multi_1", "clear_multi_2", "clear_multi_3"] for i, key in enumerate(keys): - tq.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag=None) + tq_api.kv_put(key=key, partition_id=partition_id, fields={"data": torch.tensor([[i]])}, tag=None) # Clear first 2 keys - tq.kv_clear(keys=keys[:2], partition_id=partition_id) + tq_api.kv_clear(keys=keys[:2], partition_id=partition_id) # Verify - partition_info = tq.kv_list(partition_id=partition_id) + partition_info = tq_api.kv_list(partition_id=partition_id) assert len(partition_info[partition_id]) == 2 assert keys[0] not in partition_info[partition_id] assert keys[1] not in partition_info[partition_id] assert keys[2] in partition_info[partition_id] assert keys[3] in partition_info[partition_id] - tq.kv_clear(keys=keys[2:], partition_id=partition_id) + tq_api.kv_clear(keys=keys[2:], partition_id=partition_id) class TestKVE2ECornerCases: """End-to-end tests for corner cases.""" - def test_field_expansion_across_samples(self, controller): + def test_field_expansion_across_samples(self, controller, tq_api): """Test that new fields can be added across samples.""" partition_id = "test_partition" keys = ["expand_0", "expand_1"] # Put initial fields - tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_a": torch.tensor([[1]])}, tag=None) + tq_api.kv_put(key=keys[0], partition_id=partition_id, fields={"field_a": torch.tensor([[1]])}, tag=None) # Add new field to first key - tq.kv_put(key=keys[0], partition_id=partition_id, fields={"field_b": torch.tensor([[2]])}, tag=None) + tq_api.kv_put(key=keys[0], partition_id=partition_id, fields={"field_b": torch.tensor([[2]])}, tag=None) # Add different field to second key - tq.kv_put( + tq_api.kv_put( key=keys[1], partition_id=partition_id, fields={"field_a": torch.tensor([[3]]), "field_c": torch.tensor([[4]])}, @@ -770,12 +801,12 @@ def test_field_expansion_across_samples(self, controller): assert "field_c" in partition.field_name_mapping # We can only fetch "field_a" because not all requested keys has other fields - data = tq.kv_batch_get(keys=keys, partition_id=partition_id) + data = tq_api.kv_batch_get(keys=keys, partition_id=partition_id) assert "field_a" in data assert "field_b" not in data assert "field_c" not in data - tq.kv_clear(keys=keys, partition_id=partition_id) + tq_api.kv_clear(keys=keys, partition_id=partition_id) def run_tests(): From aab02f48d1e305ca0bb7c7ce9168959c23e14b6c Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 10:50:38 +0800 Subject: [PATCH 04/11] fix comments Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 52 ++++++++++++++++++++++++++++++ transfer_queue/interface.py | 40 ++++++++++++----------- 2 files changed, 73 insertions(+), 19 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index c6ebfa9c..aa96e8bd 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -624,6 +624,58 @@ def test_kv_batch_get_nonexistent_key(self, controller, tq_api): assert "not found" in str(e).lower() or "empty" in str(e).lower() +class TestKVBatchGetByMetaE2E: + """End-to-end tests for kv_batch_get_by_meta functionality.""" + + def test_kv_batch_get_by_meta_from_kv_batch_put(self, controller, tq_api): + """Test kv_batch_get_by_meta using KVBatchMeta returned from kv_batch_put.""" + partition_id = "test_partition" + keys = ["meta_batch_0", "meta_batch_1", "meta_batch_2"] + expected_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + expected_attention_mask = torch.ones_like(expected_input_ids) + + fields = TensorDict( + { + "input_ids": expected_input_ids, + "attention_mask": expected_attention_mask, + }, + batch_size=3, + ) + tags = [{"idx": i} for i in range(3)] + + # Batch put and get KVBatchMeta + meta = tq_api.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) + + # Retrieve using kv_batch_get_by_meta + retrieved = tq_api.kv_batch_get_by_meta(meta) + assert_tensor_equal(retrieved["input_ids"], expected_input_ids) + assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask) + + tq_api.kv_clear(keys=keys, partition_id=partition_id) + + def test_kv_batch_get_by_meta_multiple_puts(self, controller, tq_api): + """Test kv_batch_get_by_meta with data from multiple sequential puts.""" + partition_id = "test_partition" + keys = ["meta_multi_0", "meta_multi_1"] + + # First put + first_data = TensorDict({"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]])}, batch_size=2) + tq_api.kv_batch_put(keys=keys, partition_id=partition_id, fields=first_data, tags=[{}, {}]) + + # Second put adds more fields + second_data = TensorDict({"attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]])}, batch_size=2) + second_meta = tq_api.kv_batch_put(keys=keys, partition_id=partition_id, fields=second_data, tags=[{}, {}]) + + # Use second meta (contains both fields) + retrieved = tq_api.kv_batch_get_by_meta(second_meta) + assert "input_ids" in retrieved.keys() + assert "attention_mask" in retrieved.keys() + assert_tensor_equal(retrieved["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 6]])) + assert_tensor_equal(retrieved["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 1]])) + + tq_api.kv_clear(keys=keys, partition_id=partition_id) + + class TestKVListE2E: """End-to-end tests for kv_list functionality.""" diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index ce76305c..d31beb37 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -430,14 +430,13 @@ def kv_put( elif not isinstance(fields, TensorDict): raise ValueError("field can only be dict or TensorDict") - # custom_meta (tag) will be put to controller through the internal put process - # After put, batch_meta.field_names() will include the new fields written by user + # After put, batch_meta.field_names will include the new fields written by user batch_meta = tq_client.put(fields, batch_meta) - fields_to_return = batch_meta.field_names else: - # directly update custom_meta (tag) to controller + # Directly update custom_meta (tag) to controller tq_client.set_custom_meta(batch_meta) - fields_to_return = batch_meta.field_names if batch_meta.field_names else None + + fields_to_return = batch_meta.field_names return KVBatchMeta( keys=[key], @@ -514,13 +513,13 @@ def kv_batch_put( # 3. put data if fields is not None: - # After put, batch_meta.field_names() will include the new fields written by user + # After put, batch_meta.field_names will include the new fields written by user batch_meta = tq_client.put(fields, batch_meta) - fields_to_return = batch_meta.field_names else: - # directly update custom_meta (tags) to controller + # Directly update custom_meta (tags) to controller tq_client.set_custom_meta(batch_meta) - fields_to_return = batch_meta.field_names if batch_meta.field_names else None + + fields_to_return = batch_meta.field_names return KVBatchMeta( keys=keys, @@ -561,6 +560,8 @@ def kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: >>> # Then retrieve it using the returned metadata >>> data = tq.kv_batch_get_by_meta(meta) """ + if meta.partition_id is None: + raise ValueError("Must provide partition_id in the input KVBatchMeta.") return kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=meta.fields) @@ -763,14 +764,13 @@ async def async_kv_put( elif not isinstance(fields, TensorDict): raise ValueError("field can only be dict or TensorDict") - # custom_meta (tag) will be put to controller through the put process - # After put, batch_meta.field_names() will include the new fields written by user - await tq_client.async_put(fields, batch_meta) - fields_to_return = batch_meta.field_names + # After put, batch_meta.field_names will include the new fields written by user + batch_meta = await tq_client.async_put(fields, batch_meta) else: - # directly update custom_meta (tag) to controller + # Directly update custom_meta (tag) to controller await tq_client.async_set_custom_meta(batch_meta) - fields_to_return = batch_meta.field_names if batch_meta.field_names else None + + fields_to_return = batch_meta.field_names return KVBatchMeta( keys=[key], @@ -846,13 +846,13 @@ async def async_kv_batch_put( # 3. put data if fields is not None: - # After put, batch_meta.field_names() will include the new fields written by user + # After put, batch_meta.field_names will include the new fields written by user batch_meta = await tq_client.async_put(fields, batch_meta) - fields_to_return = batch_meta.field_names else: - # directly update custom_meta (tags) to controller + # Directly update custom_meta (tags) to controller await tq_client.async_set_custom_meta(batch_meta) - fields_to_return = batch_meta.field_names if batch_meta.field_names else None + + fields_to_return = batch_meta.field_names return KVBatchMeta( keys=keys, @@ -893,6 +893,8 @@ async def async_kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: >>> # Then retrieve it using the returned metadata >>> data = await tq.async_kv_batch_get_by_meta(meta) """ + if meta.partition_id is None: + raise ValueError("Must provide partition_id in the input KVBatchMeta.") return await async_kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=meta.fields) From 44f765bd837694c384ea7ecfc0125ed50da7303a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 11:11:26 +0800 Subject: [PATCH 05/11] fix Signed-off-by: 0oshowero0 --- transfer_queue/__init__.py | 2 ++ transfer_queue/interface.py | 46 +++++++++++++++++++------------------ 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 4732e84a..4ec041f4 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -46,11 +46,13 @@ "kv_put", "kv_batch_put", "kv_batch_get", + "kv_batch_get_by_meta", "kv_list", "kv_clear", "async_kv_put", "async_kv_batch_put", "async_kv_batch_get", + "async_kv_batch_get_by_meta", "async_kv_list", "async_kv_clear", "KVBatchMeta", diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index d31beb37..8bddae9c 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -506,7 +506,7 @@ def kv_batch_put( ) # 2. register the user-specified tags to BatchMeta - if tags: + if tags is not None: if len(tags) != len(keys): raise ValueError(f"keys with length {len(keys)} does not match length of tags {len(tags)}") batch_meta.update_custom_meta(tags) @@ -545,18 +545,19 @@ def kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: TensorDict with the requested data Raises: - RuntimeError: If keys or partition are not found - RuntimeError: If empty fields exist in any key (sample) + ValueError: If keys or partition are not found + ValueError: If empty fields exist in any key (sample) Example: >>> import transfer_queue as tq >>> tq.init() >>> # First put some data - >>> meta = tq.kv_batch_put( - ... keys=["sample_1", "sample_2"], - ... partition_id="train", - ... fields={"input_ids": torch.randn(2, 10)}, - ... ) + >>> keys = ["sample_1", "sample_2", "sample_3"] + >>> fields = TensorDict({ + ... "input_ids": torch.randn(3, 10), + ... "attention_mask": torch.ones(3, 10), + ... }, batch_size=3) + >>> meta = tq.kv_batch_put(keys=keys, partition_id="train", fields=fields) >>> # Then retrieve it using the returned metadata >>> data = tq.kv_batch_get_by_meta(meta) """ @@ -581,8 +582,8 @@ def kv_batch_get( TensorDict with the requested data Raises: - RuntimeError: If keys or partition are not found - RuntimeError: If empty fields exist in any key (sample) + ValueError: If keys or partition are not found + ValueError: If empty fields exist in any key (sample) Example: >>> import transfer_queue as tq @@ -601,7 +602,7 @@ def kv_batch_get( batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size == 0: - raise RuntimeError("keys or partition were not found!") + raise ValueError("keys or partition were not found!") if select_fields is not None: if isinstance(select_fields, str): @@ -611,7 +612,7 @@ def kv_batch_get( batch_meta = batch_meta.select_fields(fields_to_fetch) if not batch_meta.is_ready: - raise RuntimeError("Some fields are not ready in all the requested keys!") + raise ValueError("Some fields are not ready in all the requested keys!") data = tq_client.get_data(batch_meta) return data @@ -745,7 +746,7 @@ async def async_kv_put( raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!") # 2. register the user-specified tag to BatchMeta - if tag: + if tag is not None: batch_meta.update_custom_meta([tag]) # 3. put data @@ -885,11 +886,12 @@ async def async_kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: >>> import transfer_queue as tq >>> tq.init() >>> # First put some data - >>> meta = await tq.async_kv_batch_put( - ... keys=["sample_1", "sample_2"], - ... partition_id="train", - ... fields={"input_ids": torch.randn(2, 10)}, - ... ) + >>> keys = ["sample_1", "sample_2", "sample_3"] + >>> fields = TensorDict({ + ... "input_ids": torch.randn(3, 10), + ... "attention_mask": torch.ones(3, 10), + ... }, batch_size=3) + >>> meta = await tq.async_kv_batch_put(keys=keys, partition_id="train", fields=fields) >>> # Then retrieve it using the returned metadata >>> data = await tq.async_kv_batch_get_by_meta(meta) """ @@ -914,8 +916,8 @@ async def async_kv_batch_get( TensorDict with the requested data Raises: - RuntimeError: If keys or partition are not found - RuntimeError: If empty fields exist in any key (sample) + ValueError: If keys or partition are not found + ValueError: If empty fields exist in any key (sample) Example: >>> import transfer_queue as tq @@ -934,7 +936,7 @@ async def async_kv_batch_get( batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size == 0: - raise RuntimeError("keys or partition were not found!") + raise ValueError("keys or partition were not found!") if select_fields is not None: if isinstance(select_fields, str): @@ -944,7 +946,7 @@ async def async_kv_batch_get( batch_meta = batch_meta.select_fields(fields_to_fetch) if not batch_meta.is_ready: - raise RuntimeError("Some fields are not ready in all the requested keys!") + raise ValueError("Some fields are not ready in all the requested keys!") data = await tq_client.async_get_data(batch_meta) return data From 888dfa7a5666147042fbb50197149b5145426d30 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 11:20:58 +0800 Subject: [PATCH 06/11] fix import Signed-off-by: 0oshowero0 --- transfer_queue/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 4ec041f4..2d242741 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -19,6 +19,7 @@ from .dataloader import StreamingDataLoader, StreamingDataset from .interface import ( async_kv_batch_get, + async_kv_batch_get_by_meta, async_kv_batch_put, async_kv_clear, async_kv_list, @@ -27,6 +28,7 @@ get_client, init, kv_batch_get, + kv_batch_get_by_meta, kv_batch_put, kv_clear, kv_list, From 177f1561e607d756e0652b15e42052174e15e271 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 11:32:48 +0800 Subject: [PATCH 07/11] fix Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 2 +- transfer_queue/interface.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index aa96e8bd..335c34d2 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -619,7 +619,7 @@ def test_kv_batch_get_nonexistent_key(self, controller, tq_api): retrieved = tq_api.kv_batch_get(keys="nonexistent_key", partition_id=partition_id) # If it returns, it should be empty assert retrieved.batch_size[0] == 0 - except RuntimeError as e: + except ValueError as e: # Or it might raise an error about keys not found assert "not found" in str(e).lower() or "empty" in str(e).lower() diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 8bddae9c..ecffb69c 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -879,8 +879,8 @@ async def async_kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: TensorDict with the requested data Raises: - RuntimeError: If keys or partition are not found - RuntimeError: If empty fields exist in any key (sample) + ValueError: If keys or partition are not found + ValueError: If empty fields exist in any key (sample) Example: >>> import transfer_queue as tq From 32b94f40a496073eabd80c4b8695a20cca755b07 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 16:06:05 +0800 Subject: [PATCH 08/11] enhance get by meta function Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 76 ++++++++++++++++++++++++++++++ transfer_queue/interface.py | 52 ++++++++++++++++---- 2 files changed, 120 insertions(+), 8 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 335c34d2..4fac9d23 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -627,6 +627,82 @@ def test_kv_batch_get_nonexistent_key(self, controller, tq_api): class TestKVBatchGetByMetaE2E: """End-to-end tests for kv_batch_get_by_meta functionality.""" + def test_kv_batch_get_by_meta_select_fields_override(self, controller, tq_api): + """Test kv_batch_get_by_meta with select_fields to override meta.fields.""" + partition_id = "test_partition" + keys = ["meta_override_0", "meta_override_1", "meta_override_2"] + expected_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + expected_attention_mask = torch.ones_like(expected_input_ids) + expected_response = torch.tensor([[10, 20], [30, 40], [50, 60]]) + + fields = TensorDict( + { + "input_ids": expected_input_ids, + "attention_mask": expected_attention_mask, + "response": expected_response, + }, + batch_size=3, + ) + tags = [{"idx": i} for i in range(3)] + + # Batch put all fields + meta = tq_api.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=tags) + + # Verify meta.fields contains all fields + assert "input_ids" in meta.fields + assert "attention_mask" in meta.fields + assert "response" in meta.fields + assert len(meta.fields) == 3 + + # Retrieve using kv_batch_get_by_meta with select_fields override - only input_ids + retrieved = tq_api.kv_batch_get_by_meta(meta, select_fields="input_ids") + assert "input_ids" in retrieved.keys() + assert "attention_mask" not in retrieved.keys() + assert "response" not in retrieved.keys() + assert_tensor_equal(retrieved["input_ids"], expected_input_ids) + + # Retrieve using kv_batch_get_by_meta with select_fields override - subset of fields + retrieved = tq_api.kv_batch_get_by_meta(meta, select_fields=["attention_mask", "response"]) + assert "input_ids" not in retrieved.keys() + assert "attention_mask" in retrieved.keys() + assert "response" in retrieved.keys() + assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask) + assert_tensor_equal(retrieved["response"], expected_response) + + # Retrieve without select_fields - should get all fields from meta + retrieved = tq_api.kv_batch_get_by_meta(meta) + assert "input_ids" in retrieved.keys() + assert "attention_mask" in retrieved.keys() + assert "response" in retrieved.keys() + assert_tensor_equal(retrieved["input_ids"], expected_input_ids) + assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask) + assert_tensor_equal(retrieved["response"], expected_response) + + tq_api.kv_clear(keys=keys, partition_id=partition_id) + + def test_kv_batch_get_by_meta_select_fields_invalid(self, controller, tq_api): + """Test kv_batch_get_by_meta raises error when select_fields contains invalid field.""" + partition_id = "test_partition" + keys = ["meta_invalid_0", "meta_invalid_1"] + fields = TensorDict( + { + "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]), + }, + batch_size=2, + ) + + meta = tq_api.kv_batch_put(keys=keys, partition_id=partition_id, fields=fields, tags=[{}, {}]) + + # Try to retrieve with a field that doesn't exist in meta + with pytest.raises(ValueError, match=r"select_fields.*not found"): + tq_api.kv_batch_get_by_meta(meta, select_fields="nonexistent_field") + + # Try to retrieve with mix of valid and invalid fields + with pytest.raises(ValueError, match=r"select_fields.*not found"): + tq_api.kv_batch_get_by_meta(meta, select_fields=["input_ids", "invalid_field"]) + + tq_api.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_batch_get_by_meta_from_kv_batch_put(self, controller, tq_api): """Test kv_batch_get_by_meta using KVBatchMeta returned from kv_batch_put.""" partition_id = "test_partition" diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index ecffb69c..3c3f0178 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -530,16 +530,20 @@ def kv_batch_put( ) -def kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: +def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | str] = None) -> TensorDict: """Get data from TransferQueue using KVBatchMeta. This is a convenience method for retrieving data using KVBatchMeta returned - from a previous put operation. It extracts the keys, partition_id, and fields - from the metadata to fetch the corresponding data. + from a previous put operation. It extracts the keys and partition_id from + the metadata to fetch the corresponding data. Args: meta: KVBatchMeta object returned from a previous put operation (e.g., kv_put, kv_batch_put). It contains keys, partition_id, and fields information. + select_fields: Optional field(s) to retrieve, which overrides the fields + recorded in the given KVBatchMeta. If None, uses all fields + from meta.fields. Can be a single field name (str) or a list + of field names. Returns: TensorDict with the requested data @@ -547,6 +551,7 @@ def kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: Raises: ValueError: If keys or partition are not found ValueError: If empty fields exist in any key (sample) + ValueError: If any field in select_fields doesn't exist in KVBatchMeta.fields Example: >>> import transfer_queue as tq @@ -563,7 +568,20 @@ def kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: """ if meta.partition_id is None: raise ValueError("Must provide partition_id in the input KVBatchMeta.") - return kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=meta.fields) + if select_fields is not None: + if isinstance(select_fields, str): + fields_to_fetch = [select_fields] + else: + fields_to_fetch = select_fields + + if any(f not in meta.fields for f in fields_to_fetch): + raise ValueError( + f"Some fields assigned in select_fields not found in the metadata. " + f"Assigned: {fields_to_fetch}; Fields in KVBatchMeta: {meta.fields}." + ) + else: + fields_to_fetch = meta.fields + return kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=fields_to_fetch) def kv_batch_get( @@ -864,16 +882,20 @@ async def async_kv_batch_put( ) -async def async_kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: +async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | str] = None) -> TensorDict: """Asynchronously get data from TransferQueue using KVBatchMeta. This is a convenience method for retrieving data using KVBatchMeta returned - from a previous put operation. It extracts the keys, partition_id, and fields - from the metadata to fetch the corresponding data. + from a previous put operation. It extracts the keys and partition_id from + the metadata to fetch the corresponding data. Args: meta: KVBatchMeta object returned from a previous put operation (e.g., async_kv_put, async_kv_batch_put). It contains keys, partition_id, and fields information. + select_fields: Optional field(s) to retrieve, which overrides the fields + recorded in the given KVBatchMeta. If None, uses all fields + from meta.fields. Can be a single field name (str) or a list + of field names. Returns: TensorDict with the requested data @@ -881,6 +903,7 @@ async def async_kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: Raises: ValueError: If keys or partition are not found ValueError: If empty fields exist in any key (sample) + ValueError: If any field in select_fields doesn't exist in KVBatchMeta.fields Example: >>> import transfer_queue as tq @@ -897,7 +920,20 @@ async def async_kv_batch_get_by_meta(meta: KVBatchMeta) -> TensorDict: """ if meta.partition_id is None: raise ValueError("Must provide partition_id in the input KVBatchMeta.") - return await async_kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=meta.fields) + if select_fields is not None: + if isinstance(select_fields, str): + fields_to_fetch = [select_fields] + else: + fields_to_fetch = select_fields + + if any(f not in meta.fields for f in fields_to_fetch): + raise ValueError( + f"Some fields assigned in select_fields not found in the metadata. " + f"Assigned: {fields_to_fetch}; Fields in KVBatchMeta: {meta.fields}." + ) + else: + fields_to_fetch = meta.fields + return await async_kv_batch_get(keys=meta.keys, partition_id=meta.partition_id, select_fields=fields_to_fetch) async def async_kv_batch_get( From 0725f8912276b67520f4a5bab3973076880f6627 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 16:23:50 +0800 Subject: [PATCH 09/11] fix pre commit Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 3c3f0178..8a951ed3 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -570,11 +570,12 @@ def kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[list[str] | raise ValueError("Must provide partition_id in the input KVBatchMeta.") if select_fields is not None: if isinstance(select_fields, str): - fields_to_fetch = [select_fields] + fields_to_fetch: Optional[list[str]] = [select_fields] else: fields_to_fetch = select_fields - if any(f not in meta.fields for f in fields_to_fetch): + assert fields_to_fetch is not None + if meta.fields is None or any(f not in meta.fields for f in fields_to_fetch): raise ValueError( f"Some fields assigned in select_fields not found in the metadata. " f"Assigned: {fields_to_fetch}; Fields in KVBatchMeta: {meta.fields}." @@ -622,6 +623,7 @@ def kv_batch_get( if batch_meta.size == 0: raise ValueError("keys or partition were not found!") + fields_to_fetch: list[str] | None if select_fields is not None: if isinstance(select_fields, str): fields_to_fetch = [select_fields] @@ -920,13 +922,16 @@ async def async_kv_batch_get_by_meta(meta: KVBatchMeta, select_fields: Optional[ """ if meta.partition_id is None: raise ValueError("Must provide partition_id in the input KVBatchMeta.") + + fields_to_fetch: list[str] | None if select_fields is not None: if isinstance(select_fields, str): fields_to_fetch = [select_fields] else: fields_to_fetch = select_fields - if any(f not in meta.fields for f in fields_to_fetch): + assert fields_to_fetch is not None + if meta.fields is None or any(f not in meta.fields for f in fields_to_fetch): raise ValueError( f"Some fields assigned in select_fields not found in the metadata. " f"Assigned: {fields_to_fetch}; Fields in KVBatchMeta: {meta.fields}." From 0e510043c8cbd4286cf3188865a7c2f52f4cbb5d Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 16:52:21 +0800 Subject: [PATCH 10/11] fix Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 8a951ed3..14fcd22c 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -428,7 +428,7 @@ def kv_put( batch[field_name] = NonTensorStack(value) fields = TensorDict(batch, batch_size=[1]) elif not isinstance(fields, TensorDict): - raise ValueError("field can only be dict or TensorDict") + raise ValueError("`fields` can only be dict or TensorDict") # After put, batch_meta.field_names will include the new fields written by user batch_meta = tq_client.put(fields, batch_meta) @@ -783,7 +783,7 @@ async def async_kv_put( batch[field_name] = NonTensorStack(value) fields = TensorDict(batch, batch_size=[1]) elif not isinstance(fields, TensorDict): - raise ValueError("field can only be dict or TensorDict") + raise ValueError("`fields` can only be dict or TensorDict") # After put, batch_meta.field_names will include the new fields written by user batch_meta = await tq_client.async_put(fields, batch_meta) From 34fb3004c73db8224ecae5a079af58cc2b196be6 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 25 Mar 2026 17:08:06 +0800 Subject: [PATCH 11/11] update demo Signed-off-by: 0oshowero0 --- .../simple_use_case/single_controller_demo.py | 47 +++++++++---------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/recipe/simple_use_case/single_controller_demo.py b/recipe/simple_use_case/single_controller_demo.py index f23f1ebe..5f5399fb 100644 --- a/recipe/simple_use_case/single_controller_demo.py +++ b/recipe/simple_use_case/single_controller_demo.py @@ -67,16 +67,15 @@ def train_mini_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta: assert self.role == "actor" # 1. Pull data from storage - data = tq.kv_batch_get(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=kv_meta.fields) + data = tq.kv_batch_get_by_meta(meta=kv_meta) logger.info(f"train_mini_batch: got data {data}") # 2. Compute loss output = compute_loss(data["old_log_prob"], data["ref_log_prob"]) output = TensorDict({"loss": output}, batch_size=output.size(0)) - kv_meta.fields.append("loss") # 3. Write back - tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output) + kv_meta = tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output) logger.info("train_mini_batch: put data done") return kv_meta @@ -84,22 +83,20 @@ def train_mini_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta: def infer_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta: """Simulate forward-only inference""" # 1. Pull data from storage - data = tq.kv_batch_get(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=kv_meta.fields) + data = tq.kv_batch_get_by_meta(meta=kv_meta) logger.info(f"compute_log_prob: got data {data}") # 2. Model forward output = compute_log_prob(data["input_ids"], data["generate_sequences_ids"]) if self.role == "actor": output = TensorDict({"old_log_prob": output}, batch_size=output.size(0)) - kv_meta.fields.append("old_log_prob") elif self.role == "ref": output = TensorDict({"ref_log_prob": output}, batch_size=output.size(0)) - kv_meta.fields.append("ref_log_prob") else: raise ValueError(f"Role {self.role} not supported.") # 3. Write back - tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output) + kv_meta = tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output) logger.info("infer_batch: put data done") return kv_meta @@ -134,7 +131,7 @@ def __init__(self, config): tq.init(config) async def generate(self, kv_meta: KVBatchMeta) -> KVBatchMeta: - data = tq.kv_batch_get(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=kv_meta.fields) + data = tq.kv_batch_get_by_meta(meta=kv_meta) logger.info(f"demo get data -> generate_sequences {data}") data = data["input_ids"] @@ -151,9 +148,8 @@ async def generate(self, kv_meta: KVBatchMeta) -> KVBatchMeta: }, batch_size=data.size(0), ) - kv_meta.fields.extend(["generate_sequences_ids", "non_tensor_data", "nested_tensor"]) - tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output) + kv_meta = tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output) logger.info("demo Async Server put data to storages done") return kv_meta @@ -240,47 +236,46 @@ def fit(self): time.sleep(5) # ========================= Sample generate KVBatchMeta ========================= - # TODO: Can be optimized by letting kv_batch_put returns KVBatchMeta directly sampled_keys = random.sample(batch_keys, self.config.global_batch_size) - gen_meta = KVBatchMeta( + meta = KVBatchMeta( keys=sampled_keys, tags=[{} for _ in sampled_keys], partition_id=f"train_{step}", fields=["input_ids", "attention_mask"], ) - logger.info(f"demo get gen KVBatchMeta {gen_meta}") + logger.info(f"demo get KVBatchMeta {meta}") # ========================= Rollout: generate sequences ========================= - gen_meta = self.async_rollout_manager.generate_sequences(gen_meta) - logger.info(f"demo get after gen KVBatchMeta {gen_meta}") + meta = self.async_rollout_manager.generate_sequences(meta) + logger.info(f"demo get after gen KVBatchMeta {meta}") # ========================= Compute ref log prob ========================= - gen_meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"] - ref_log_prob_meta = self.actor_rollout_wg.compute_ref_log_prob(gen_meta) - logger.info(f"demo get ref log prob KVBatchMeta: {ref_log_prob_meta}") + meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"] + meta = self.actor_rollout_wg.compute_ref_log_prob(meta) + logger.info(f"demo get ref log prob KVBatchMeta: {meta}") # ========================= Compute old log prob ========================= - gen_meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"] - old_log_prob_meta = self.actor_rollout_wg.compute_log_prob(gen_meta) - logger.info(f"demo get old log prob KVBatchMeta: {old_log_prob_meta}") + meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"] + meta = self.actor_rollout_wg.compute_log_prob(meta) + logger.info(f"demo get old log prob KVBatchMeta: {meta}") # ========================= Compute reward ========================= # Simulated inline; in real training this calls a reward model worker - gen_meta.fields = ["generate_sequences_ids", "ref_log_prob", "old_log_prob"] + meta.fields = ["generate_sequences_ids", "ref_log_prob", "old_log_prob"] logger.info("demo computing reward (simulated)") time.sleep(1) - logger.info(f"demo reward KVBatchMeta: {gen_meta}") + logger.info(f"demo reward KVBatchMeta: {meta}") # ========================= Update actor ========================= - gen_meta.fields = [ + meta.fields = [ "input_ids", "attention_mask", "generate_sequences_ids", "old_log_prob", "ref_log_prob", ] - train_meta = self.actor_rollout_wg.update_actor(gen_meta) - logger.info(f"demo get after update actor KVBatchMeta: {train_meta}") + meta = self.actor_rollout_wg.update_actor(meta) + logger.info(f"demo get after update actor KVBatchMeta: {meta}") # ========================= Sync weights to rollout ========================= asyncio.run(self.actor_rollout_wg.update_weights(global_steps=step))