diff --git a/tests/test_client.py b/tests/test_client.py index 5d308d83..6729196f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -144,6 +144,9 @@ def _handle_requests(self): "message": "Consumption reset successfully", } response_type = ZMQRequestType.RESET_CONSUMPTION_RESPONSE + elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_META: + response_body = self._mock_kv_retrieve_meta(request_msg.body) + response_type = ZMQRequestType.KV_RETRIEVE_META_RESPONSE elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: response_body = self._mock_kv_retrieve_keys(request_msg.body) response_type = ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE @@ -193,7 +196,7 @@ def _mock_batch_meta(self, request_body): return {"metadata": metadata} - def _mock_kv_retrieve_keys(self, request_body): + def _mock_kv_retrieve_meta(self, request_body): """Mock KV retrieve keys response.""" keys = request_body.get("keys", []) create = request_body.get("create", False) @@ -250,6 +253,42 @@ def _mock_kv_list(self, request_body): return {"partition_info": {partition_id: {k: {} for k in keys}}, "message": "success"} + def _mock_kv_retrieve_keys(self, request_body): + """Mock KV retrieve indexes response.""" + global_indexes = request_body.get("global_indexes", []) + partition_id = request_body.get("partition_id", "") + + # Initialize key tracking if not exists + if not hasattr(self, "_kv_partition_keys"): + self._kv_partition_keys = {} + + # Initialize index to key mapping if not exists + if not hasattr(self, "_kv_index_to_key"): + self._kv_index_to_key = {} + + # Get keys for this partition + partition_keys = self._kv_partition_keys.get(partition_id, []) + + # Build reverse mapping from index to key if needed + if not hasattr(self, "_kv_partition_index_map"): + self._kv_partition_index_map = {} + + if partition_id not in self._kv_partition_index_map: + # Build the mapping from stored keys + start_idx = self._get_next_kv_index(partition_id) - len(partition_keys) + 
self._kv_partition_index_map[partition_id] = {} + for i, key in enumerate(partition_keys): + self._kv_partition_index_map[partition_id][start_idx + i] = key + + index_map = self._kv_partition_index_map.get(partition_id, {}) + + # Retrieve keys for the given global_indexes + keys = [] + for idx in global_indexes: + keys.append(index_map.get(idx, None)) + + return {"keys": keys} + def _get_next_kv_index(self, partition_id): """Get next available index for KV keys in partition.""" if not hasattr(self, "_kv_index_map"): @@ -970,12 +1009,12 @@ class TestClientKVInterface: """Tests for client KV interface methods.""" @pytest.mark.asyncio - async def test_async_kv_retrieve_keys_single(self, client_setup): - """Test async_kv_retrieve_keys with single key.""" + async def test_async_kv_retrieve_meta_single(self, client_setup): + """Test async_kv_retrieve_meta with single key.""" client, _, _ = client_setup - # Test async_kv_retrieve_keys with single key - metadata = await client.async_kv_retrieve_keys( + # Test async_kv_retrieve_meta with single key + metadata = await client.async_kv_retrieve_meta( keys="test_key_1", partition_id="test_partition", create=True, @@ -988,13 +1027,13 @@ async def test_async_kv_retrieve_keys_single(self, client_setup): assert metadata.size == 1 @pytest.mark.asyncio - async def test_async_kv_retrieve_keys_multiple(self, client_setup): - """Test async_kv_retrieve_keys with multiple keys.""" + async def test_async_kv_retrieve_meta_multiple(self, client_setup): + """Test async_kv_retrieve_meta with multiple keys.""" client, _, _ = client_setup - # Test async_kv_retrieve_keys with multiple keys + # Test async_kv_retrieve_meta with multiple keys keys = ["key_a", "key_b", "key_c"] - metadata = await client.async_kv_retrieve_keys( + metadata = await client.async_kv_retrieve_meta( keys=keys, partition_id="test_partition", create=True, @@ -1007,19 +1046,19 @@ async def test_async_kv_retrieve_keys_multiple(self, client_setup): assert metadata.size == 3 
@pytest.mark.asyncio - async def test_async_kv_retrieve_keys_create_false(self, client_setup): - """Test async_kv_retrieve_keys with create=False (retrieve existing keys).""" + async def test_async_kv_retrieve_meta_create_false(self, client_setup): + """Test async_kv_retrieve_meta with create=False (retrieve existing keys).""" client, _, _ = client_setup # create some keys - await client.async_kv_retrieve_keys( + await client.async_kv_retrieve_meta( keys="existing_key", partition_id="existing_partition", create=True, ) # Then retrieve them with create=False - metadata = await client.async_kv_retrieve_keys( + metadata = await client.async_kv_retrieve_meta( keys="existing_key", partition_id="existing_partition", create=False, @@ -1030,13 +1069,13 @@ async def test_async_kv_retrieve_keys_create_false(self, client_setup): assert metadata.size == 1 @pytest.mark.asyncio - async def test_async_kv_retrieve_keys_invalid_keys_type(self, client_setup): - """Test async_kv_retrieve_keys raises error with invalid keys type.""" + async def test_async_kv_retrieve_meta_invalid_keys_type(self, client_setup): + """Test async_kv_retrieve_meta raises error with invalid keys type.""" client, _, _ = client_setup # Test with invalid keys type (not string or list) with pytest.raises(TypeError): - await client.async_kv_retrieve_keys( + await client.async_kv_retrieve_meta( keys=123, # Invalid type partition_id="test_partition", create=True, @@ -1048,7 +1087,7 @@ async def test_async_kv_list_with_keys(self, client_setup): client, mock_controller, _ = client_setup # First register some keys - await client.async_kv_retrieve_keys( + await client.async_kv_retrieve_meta( keys=["key_1", "key_2"], partition_id="kv_partition", create=True, @@ -1069,12 +1108,12 @@ async def test_async_kv_list_multiple_partitions(self, client_setup): client, _, _ = client_setup # Create keys in different partitions - await client.async_kv_retrieve_keys( + await client.async_kv_retrieve_meta( keys="partition_a_key", 
partition_id="partition_a", create=True, ) - await client.async_kv_retrieve_keys( + await client.async_kv_retrieve_meta( keys="partition_b_key", partition_id="partition_b", create=True, @@ -1096,8 +1135,8 @@ async def test_async_kv_list_multiple_partitions(self, client_setup): assert list(partition_a["partition_a"].values()) == [{}] assert list(partition_b["partition_b"].values()) == [{}] - def test_kv_retrieve_keys_type_validation(self, client_setup): - """Test synchronous kv_retrieve_keys type validation.""" + def test_kv_retrieve_meta_type_validation(self, client_setup): + """Test synchronous kv_retrieve_meta type validation.""" import asyncio client, _, _ = client_setup @@ -1105,10 +1144,203 @@ def test_kv_retrieve_keys_type_validation(self, client_setup): # Test with non-string element in list async def test_invalid_list(): with pytest.raises(TypeError): - await client.async_kv_retrieve_keys( + await client.async_kv_retrieve_meta( keys=["valid_key", 123], # Invalid: 123 is not a string partition_id="test_partition", create=True, ) asyncio.run(test_invalid_list()) + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_single(self, client_setup): + """Test async_kv_retrieve_keys with single global_index.""" + client, _, _ = client_setup + partition_id = "test_partition_idx" + + # First create a key using kv_retrieve_meta + await client.async_kv_retrieve_meta( + keys=["test_key"], + partition_id=partition_id, + create=True, + ) + + # Now retrieve the key using global_index 0 + keys = await client.async_kv_retrieve_keys( + global_indexes=[0], + partition_id=partition_id, + ) + + assert keys == ["test_key"] + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_multiple(self, client_setup): + """Test async_kv_retrieve_keys with multiple global_indexes.""" + client, _, _ = client_setup + partition_id = "test_partition_idx" + + # First create keys using kv_retrieve_meta + keys_to_create = ["key_a", "key_b", "key_c"] + await 
client.async_kv_retrieve_meta( + keys=keys_to_create, + partition_id=partition_id, + create=True, + ) + + # Retrieve keys using global_indexes [0, 1, 2] + keys = await client.async_kv_retrieve_keys( + global_indexes=[0, 1, 2], + partition_id=partition_id, + ) + + assert keys == ["key_a", "key_b", "key_c"] + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_partial(self, client_setup): + """Test async_kv_retrieve_keys with subset of global_indexes.""" + client, _, _ = client_setup + partition_id = "test_partition_idx" + + # First create keys using kv_retrieve_meta + await client.async_kv_retrieve_meta( + keys=["first_key", "second_key", "third_key"], + partition_id=partition_id, + create=True, + ) + + # Retrieve only first and third keys + keys = await client.async_kv_retrieve_keys( + global_indexes=[0, 2], + partition_id=partition_id, + ) + + assert keys == ["first_key", "third_key"] + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_single_int(self, client_setup): + """Test async_kv_retrieve_keys accepts a single int.""" + client, _, _ = client_setup + partition_id = "test_partition_idx" + + # First create a key using kv_retrieve_meta + await client.async_kv_retrieve_meta( + keys=["single_key"], + partition_id=partition_id, + create=True, + ) + + # Now retrieve the key using a single int (not a list) + keys = await client.async_kv_retrieve_keys( + global_indexes=0, + partition_id=partition_id, + ) + + assert keys == ["single_key"] + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_invalid_type(self, client_setup): + """Test async_kv_retrieve_keys raises error with invalid global_indexes type.""" + client, _, _ = client_setup + + # Test with invalid type (string instead of int) + with pytest.raises(TypeError): + await client.async_kv_retrieve_keys( + global_indexes=["not_an_int"], + partition_id="test_partition", + ) + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_empty_list(self, client_setup): + """Test 
async_kv_retrieve_keys raises error with empty list.""" + client, _, _ = client_setup + + with pytest.raises(ValueError): + await client.async_kv_retrieve_keys( + global_indexes=[], + partition_id="test_partition", + ) + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_non_existent(self, client_setup): + """Test async_kv_retrieve_keys returns None for non-existent global_indexes.""" + client, _, _ = client_setup + partition_id = "test_partition_idx" + + # First create a key using kv_retrieve_meta + await client.async_kv_retrieve_meta( + keys=["existing_key"], + partition_id=partition_id, + create=True, + ) + + # Try to retrieve a non-existent global_index + keys = await client.async_kv_retrieve_keys( + global_indexes=[99], + partition_id=partition_id, + ) + assert keys == [None] + + @pytest.mark.asyncio + async def test_async_kv_retrieve_keys_multiple_partitions(self, client_setup): + """Test async_kv_retrieve_keys returns keys from the correct partition.""" + client, _, _ = client_setup + partition_1 = "partition_1" + partition_2 = "partition_2" + + # Create keys in both partitions + await client.async_kv_retrieve_meta( + keys=["key_1"], + partition_id=partition_1, + create=True, + ) + await client.async_kv_retrieve_meta( + keys=["key_2"], + partition_id=partition_2, + create=True, + ) + + # Retrieve key from partition_1 (global_index 0) + keys_1 = await client.async_kv_retrieve_keys( + global_indexes=[0], + partition_id=partition_1, + ) + + # Retrieve key from partition_2 (global_index 0) + keys_2 = await client.async_kv_retrieve_keys( + global_indexes=[0], + partition_id=partition_2, + ) + + assert keys_1 == ["key_1"] + assert keys_2 == ["key_2"] + + def test_kv_retrieve_keys_sync(self, client_setup): + """Test synchronous kv_retrieve_keys.""" + client, _, _ = client_setup + partition_id = "test_partition_sync" + + # First create a key using kv_retrieve_meta + client.kv_retrieve_meta( + keys=["sync_key"], + partition_id=partition_id, + create=True, 
+ ) + + # Now retrieve the key using global_index + keys = client.kv_retrieve_keys( + global_indexes=[0], + partition_id=partition_id, + ) + + assert keys == ["sync_key"] + + def test_kv_retrieve_keys_type_validation(self, client_setup): + """Test synchronous kv_retrieve_keys type validation.""" + client, _, _ = client_setup + + # Test with non-int element in list + with pytest.raises(TypeError): + client.kv_retrieve_keys( + global_indexes=[0, "invalid"], + partition_id="test_partition", + ) diff --git a/tests/test_controller.py b/tests/test_controller.py index 3528d1aa..77565bce 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -765,18 +765,18 @@ def test_controller_with_custom_meta(self, ray_setup): class TestTransferQueueControllerKvInterface: """End-to-end tests for TransferQueueController KV interface functionality. - Tests for kv_retrieve_keys method that supports key-value interface operations + Tests for kv_retrieve_meta method that supports key-value interface operations across the controller and partition layers. 
""" - def test_controller_kv_retrieve_keys_create_mode(self, ray_setup): - """Test kv_retrieve_keys with create=True creates new keys in partition.""" + def test_controller_kv_retrieve_meta_create_mode(self, ray_setup): + """Test kv_retrieve_meta with create=True creates new keys in partition.""" tq_controller = TransferQueueController.remote() partition_id = "kv_test_partition" # Retrieve keys with create=True - should create partition and keys keys = ["key_a", "key_b", "key_c"] - metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + metadata = ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True)) # Verify partition was created partitions = ray.get(tq_controller.list_partitions.remote()) @@ -797,72 +797,72 @@ def test_controller_kv_retrieve_keys_create_mode(self, ray_setup): assert partition.revert_keys_mapping[metadata.global_indexes[1]] == "key_b" assert partition.revert_keys_mapping[metadata.global_indexes[2]] == "key_c" - print("✓ kv_retrieve_keys with create=True creates keys correctly") + print("✓ kv_retrieve_meta with create=True creates keys correctly") # Clean up ray.get(tq_controller.clear_partition.remote(partition_id)) - def test_controller_kv_retrieve_keys_existing_keys(self, ray_setup): - """Test kv_retrieve_keys retrieves existing keys correctly.""" + def test_controller_kv_retrieve_meta_existing_keys(self, ray_setup): + """Test kv_retrieve_meta retrieves existing keys correctly.""" tq_controller = TransferQueueController.remote() partition_id = "kv_existing_test" # First, create some keys keys = ["existing_key_1", "existing_key_2"] - ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True)) # Retrieve the same keys again (should return existing) retrieved_metadata = ray.get( - 
tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) + tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=False) ) # Verify the same global_indexes are returned assert len(retrieved_metadata.global_indexes) == len(keys) - print("✓ kv_retrieve_keys retrieves existing keys correctly") + print("✓ kv_retrieve_meta retrieves existing keys correctly") # Clean up ray.get(tq_controller.clear_partition.remote(partition_id)) - def test_controller_kv_retrieve_keys_non_existent_without_create(self, ray_setup): - """Test kv_retrieve_keys raises error for non-existent keys without create.""" + def test_controller_kv_retrieve_meta_non_existent_without_create(self, ray_setup): + """Test kv_retrieve_meta raises error for non-existent keys without create.""" tq_controller = TransferQueueController.remote() partition_id = "kv_nonexistent_test" # Create partition first - ray.get(tq_controller.kv_retrieve_keys.remote(keys=["initial_key"], partition_id=partition_id, create=True)) + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["initial_key"], partition_id=partition_id, create=True)) # Try to retrieve non-existent key without create batch_meta = ray.get( - tq_controller.kv_retrieve_keys.remote(keys=["nonexistent_key"], partition_id=partition_id, create=False) + tq_controller.kv_retrieve_meta.remote(keys=["nonexistent_key"], partition_id=partition_id, create=False) ) assert batch_meta.size == 0 - print("✓ kv_retrieve_keys return an empty BatchMeta for non-existent keys without create") + print("✓ kv_retrieve_meta return an empty BatchMeta for non-existent keys without create") # Clean up ray.get(tq_controller.clear_partition.remote(partition_id)) - def test_controller_kv_retrieve_keys_empty_partition_without_create(self, ray_setup): - """Test kv_retrieve_keys raises error for non-existent partition without create.""" + def test_controller_kv_retrieve_meta_empty_partition_without_create(self, ray_setup): + """Test 
kv_retrieve_meta raises error for non-existent partition without create.""" tq_controller = TransferQueueController.remote() partition_id = "nonexistent_partition" batch_meta = ray.get( - tq_controller.kv_retrieve_keys.remote(keys=["key_1"], partition_id=partition_id, create=False) + tq_controller.kv_retrieve_meta.remote(keys=["key_1"], partition_id=partition_id, create=False) ) assert batch_meta.size == 0 - print("✓ kv_retrieve_keys return an empty BatchMeta for non-existent partition_id without create") + print("✓ kv_retrieve_meta return an empty BatchMeta for non-existent partition_id without create") - def test_controller_kv_retrieve_keys_with_production_status(self, ray_setup): - """Test kv_retrieve_keys works with production status update.""" + def test_controller_kv_retrieve_meta_with_production_status(self, ray_setup): + """Test kv_retrieve_meta works with production status update.""" tq_controller = TransferQueueController.remote() partition_id = "kv_production_test" # Create keys keys = ["sample_1", "sample_2", "sample_3"] - metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + metadata = ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True)) global_indexes = metadata.global_indexes # Update production status @@ -881,7 +881,7 @@ def test_controller_kv_retrieve_keys_with_production_status(self, ray_setup): # Retrieve keys again (should include production info) retrieved_metadata = ray.get( - tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) + tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=False) ) # Verify production status is available @@ -891,19 +891,19 @@ def test_controller_kv_retrieve_keys_with_production_status(self, ray_setup): assert sample.fields["data"].dtype == "torch.float32" assert sample.fields["data"].shape == (64,) - print("✓ kv_retrieve_keys works with production 
status") + print("✓ kv_retrieve_meta works with production status") # Clean up ray.get(tq_controller.clear_partition.remote(partition_id)) - def test_controller_kv_retrieve_keys_with_custom_meta(self, ray_setup): - """Test kv_retrieve_keys preserves custom_meta through retrieve.""" + def test_controller_kv_retrieve_meta_with_custom_meta(self, ray_setup): + """Test kv_retrieve_meta preserves custom_meta through retrieve.""" tq_controller = TransferQueueController.remote() partition_id = "kv_custom_meta_test" # Create keys keys = ["key_1", "key_2"] - metadata = ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=True)) + metadata = ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True)) # Set custom_meta custom_meta = { @@ -916,7 +916,7 @@ def test_controller_kv_retrieve_keys_with_custom_meta(self, ray_setup): # Retrieve keys and verify custom_meta retrieved_metadata = ray.get( - tq_controller.kv_retrieve_keys.remote(keys=keys, partition_id=partition_id, create=False) + tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=False) ) # Verify custom_meta is preserved @@ -925,7 +925,7 @@ def test_controller_kv_retrieve_keys_with_custom_meta(self, ray_setup): assert all_custom_meta[0]["score"] == 0.9 assert all_custom_meta[1]["tag"] == "B" - print("✓ kv_retrieve_keys preserves custom_meta") + print("✓ kv_retrieve_meta preserves custom_meta") # Clean up ray.get(tq_controller.clear_partition.remote(partition_id)) @@ -937,12 +937,12 @@ def test_controller_kv_interface_multiple_partitions(self, ray_setup): # Create keys in partition 1 partition_1 = "partition_kv_1" keys_1 = ["p1_key_a", "p1_key_b"] - ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_1, partition_id=partition_1, create=True)) + ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys_1, partition_id=partition_1, create=True)) # Create keys in partition 2 partition_2 = "partition_kv_2" keys_2 = 
["p2_key_x", "p2_key_y", "p2_key_z"] - ray.get(tq_controller.kv_retrieve_keys.remote(keys=keys_2, partition_id=partition_2, create=True)) + ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys_2, partition_id=partition_2, create=True)) # Verify partitions are isolated partition_1_snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_1)) @@ -962,3 +962,103 @@ def test_controller_kv_interface_multiple_partitions(self, ray_setup): # Clean up ray.get(tq_controller.clear_partition.remote(partition_1)) ray.get(tq_controller.clear_partition.remote(partition_2)) + + def test_controller_kv_retrieve_keys_basic(self, ray_setup): + """Test kv_retrieve_keys retrieves keys from global_indexes.""" + tq_controller = TransferQueueController.remote() + partition_id = "partition_retrieve_idx" + keys = ["test_key_a", "test_key_b", "test_key_c"] + + # First create keys using kv_retrieve_meta + ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True)) + + # Now retrieve keys using global_indexes [0, 1, 2] + retrieved_keys = ray.get( + tq_controller.kv_retrieve_keys.remote(global_indexes=[0, 1, 2], partition_id=partition_id) + ) + + assert retrieved_keys == ["test_key_a", "test_key_b", "test_key_c"] + print("✓ kv_retrieve_keys retrieves keys correctly") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_partial(self, ray_setup): + """Test kv_retrieve_keys retrieves subset of keys.""" + tq_controller = TransferQueueController.remote() + partition_id = "partition_retrieve_partial" + + # Create keys using kv_retrieve_meta + keys = ["key_0", "key_1", "key_2", "key_3", "key_4"] + ray.get(tq_controller.kv_retrieve_meta.remote(keys=keys, partition_id=partition_id, create=True)) + + # Retrieve only first and last keys + retrieved_keys = ray.get( + tq_controller.kv_retrieve_keys.remote(global_indexes=[0, 4], partition_id=partition_id) + ) + + assert retrieved_keys == 
["key_0", "key_4"] + print("✓ kv_retrieve_keys retrieves subset correctly") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_single_int(self, ray_setup): + """Test kv_retrieve_keys with list containing single element.""" + tq_controller = TransferQueueController.remote() + partition_id = "partition_single_int" + + # Create key using kv_retrieve_meta + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["single_key"], partition_id=partition_id, create=True)) + + # Retrieve using list with single int + retrieved_keys = ray.get(tq_controller.kv_retrieve_keys.remote(global_indexes=[0], partition_id=partition_id)) + + assert retrieved_keys == ["single_key"] + print("✓ kv_retrieve_keys works with list containing single element") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_nonexistent(self, ray_setup): + """Test kv_retrieve_keys handles non-existent global_indexes.""" + tq_controller = TransferQueueController.remote() + partition_id = "partition_nonexistent" + + # Create keys using kv_retrieve_meta + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["existing_key"], partition_id=partition_id, create=True)) + + # Try to retrieve non-existent global_index + result = ray.get(tq_controller.kv_retrieve_keys.remote(global_indexes=[99], partition_id=partition_id)) + + # Should return list with None when global_index doesn't exist + assert result == [None] + print("✓ kv_retrieve_keys handles non-existent indexes") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_id)) + + def test_controller_kv_retrieve_keys_multiple_partitions(self, ray_setup): + """Test kv_retrieve_keys respects partition isolation.""" + tq_controller = TransferQueueController.remote() + partition_1 = "partition_idx_1" + partition_2 = "partition_idx_2" + + # Create keys in both partitions + # Note: global_index is global across partitions, so 
p2_key will have global_index=1 + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["p1_key"], partition_id=partition_1, create=True)) + ray.get(tq_controller.kv_retrieve_meta.remote(keys=["p2_key"], partition_id=partition_2, create=True)) + + # Retrieve from partition_1 (global_index=0) + keys_1 = ray.get(tq_controller.kv_retrieve_keys.remote(global_indexes=[0], partition_id=partition_1)) + + # Retrieve from partition_2 (global_index=1) + keys_2 = ray.get(tq_controller.kv_retrieve_keys.remote(global_indexes=[1], partition_id=partition_2)) + + assert keys_1 == ["p1_key"] + assert keys_2 == ["p2_key"] + print("✓ kv_retrieve_keys maintains partition isolation") + + # Clean up + ray.get(tq_controller.clear_partition.remote(partition_1)) + ray.get(tq_controller.clear_partition.remote(partition_2)) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 6eba7a08..1cac3fef 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -946,12 +946,12 @@ def test_custom_meta_cleared_with_data(self): class TestDataPartitionStatusKvInterface: """Unit tests for DataPartitionStatus KV interface functionality. - Tests for the keys_mapping and kv_retrieve_keys methods that support + Tests for the keys_mapping and kv_retrieve_meta methods that support key-value interface operations within a partition. 
""" - def test_kv_retrieve_keys_with_existing_keys(self): - """Test kv_retrieve_keys returns correct global_indexes for existing keys.""" + def test_kv_retrieve_meta_with_existing_keys(self): + """Test kv_retrieve_meta returns correct global_indexes for existing keys.""" from transfer_queue.controller import DataPartitionStatus partition = DataPartitionStatus(partition_id="kv_test_partition") @@ -960,12 +960,12 @@ def test_kv_retrieve_keys_with_existing_keys(self): partition.keys_mapping = {"key_a": 0, "key_b": 1, "key_c": 2} # Retrieve keys - global_indexes = partition.kv_retrieve_keys(["key_a", "key_b", "key_c"]) + global_indexes = partition.kv_retrieve_indexes(["key_a", "key_b", "key_c"]) assert global_indexes == [0, 1, 2] - def test_kv_retrieve_keys_with_nonexistent_keys(self): - """Test kv_retrieve_keys returns None for keys that don't exist.""" + def test_kv_retrieve_meta_with_nonexistent_keys(self): + """Test kv_retrieve_meta returns None for keys that don't exist.""" from transfer_queue.controller import DataPartitionStatus partition = DataPartitionStatus(partition_id="kv_test_partition") @@ -974,22 +974,22 @@ def test_kv_retrieve_keys_with_nonexistent_keys(self): partition.keys_mapping = {"existing_key": 5} # Retrieve mixed existing and non-existing keys - global_indexes = partition.kv_retrieve_keys(["existing_key", "nonexistent_key"]) + global_indexes = partition.kv_retrieve_indexes(["existing_key", "nonexistent_key"]) assert global_indexes == [5, None] - def test_kv_retrieve_keys_empty_list(self): - """Test kv_retrieve_keys handles empty key list.""" + def test_kv_retrieve_meta_empty_list(self): + """Test kv_retrieve_meta handles empty key list.""" from transfer_queue.controller import DataPartitionStatus partition = DataPartitionStatus(partition_id="kv_test_partition") - global_indexes = partition.kv_retrieve_keys([]) + global_indexes = partition.kv_retrieve_indexes([]) assert global_indexes == [] - def test_kv_retrieve_keys_partial_match(self): - 
"""Test kv_retrieve_keys with partial key matches.""" + def test_kv_retrieve_meta_partial_match(self): + """Test kv_retrieve_meta with partial key matches.""" from transfer_queue.controller import DataPartitionStatus partition = DataPartitionStatus(partition_id="kv_test_partition") @@ -997,6 +997,61 @@ def test_kv_retrieve_keys_partial_match(self): partition.keys_mapping = {"key_1": 10, "key_2": 20, "key_3": 30} # Request only some of the keys - global_indexes = partition.kv_retrieve_keys(["key_1", "key_3"]) + global_indexes = partition.kv_retrieve_indexes(["key_1", "key_3"]) assert global_indexes == [10, 30] + + def test_kv_retrieve_keys_with_existing_indexes(self): + """Test kv_retrieve_keys returns correct keys for existing global_indexes.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + # Simulate reverse mapping (key -> global_index) + partition.keys_mapping = {"key_a": 0, "key_b": 1, "key_c": 2} + # Build reverse mapping + partition.revert_keys_mapping = {0: "key_a", 1: "key_b", 2: "key_c"} + + # Retrieve keys using global_indexes + keys = partition.kv_retrieve_keys([0, 1, 2]) + + assert keys == ["key_a", "key_b", "key_c"] + + def test_kv_retrieve_keys_with_nonexistent_indexes(self): + """Test kv_retrieve_keys returns None for global_indexes that don't exist.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + # Simulate some indexes being registered + partition.keys_mapping = {"existing_key": 5} + partition.revert_keys_mapping = {5: "existing_key"} + + # Retrieve mixed existing and non-existing global_indexes + keys = partition.kv_retrieve_keys([5, 99]) + + assert keys == ["existing_key", None] + + def test_kv_retrieve_keys_empty_list(self): + """Test kv_retrieve_keys handles empty global_index list.""" + from transfer_queue.controller import DataPartitionStatus + + partition = 
DataPartitionStatus(partition_id="kv_test_partition") + + keys = partition.kv_retrieve_keys([]) + + assert keys == [] + + def test_kv_retrieve_keys_partial_match(self): + """Test kv_retrieve_keys with partial global_index matches.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="kv_test_partition") + + partition.keys_mapping = {"key_1": 10, "key_2": 20, "key_3": 30} + partition.revert_keys_mapping = {10: "key_1", 20: "key_2", 30: "key_3"} + + # Request only some of the global_indexes + keys = partition.kv_retrieve_keys([10, 30]) + + assert keys == ["key_1", "key_3"] diff --git a/transfer_queue/client.py b/transfer_queue/client.py index b06f90c7..05fac154 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -914,7 +914,7 @@ async def async_get_partition_list( # ==================== KV Interface API ==================== @dynamic_socket(socket_name="request_handle_socket") - async def async_kv_retrieve_keys( + async def async_kv_retrieve_meta( self, keys: list[str] | str, partition_id: str, @@ -948,7 +948,7 @@ async def async_kv_retrieve_keys( raise TypeError("Only string or list of strings are allowed as `keys`.") request_msg = ZMQMessage.create( - request_type=ZMQRequestType.KV_RETRIEVE_KEYS, # type: ignore[arg-type] + request_type=ZMQRequestType.KV_RETRIEVE_META, # type: ignore[arg-type] sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -968,7 +968,7 @@ async def async_kv_retrieve_keys( f"from controller {self._controller.id}" ) - if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE: + if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_META_RESPONSE: metadata = response_msg.body.get("metadata", BatchMeta.empty()) metadata = BatchMeta.from_dict(metadata) if isinstance(metadata, dict) else metadata return metadata @@ -980,6 +980,69 @@ async def async_kv_retrieve_keys( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in 
kv_retrieve_keys: {str(e)}") from e + @dynamic_socket(socket_name="request_handle_socket") + async def async_kv_retrieve_keys( + self, + global_indexes: list[int] | int, + partition_id: str, + socket: Optional[zmq.asyncio.Socket] = None, + ) -> list[str]: + """Asynchronously retrieve keys according to global_indexes from the controller. + + Args: + global_indexes: List of global_indexes to retrieve from the controller + partition_id: The ID of the logical partition to search for global_indexes. + socket: ZMQ socket (injected by decorator) + + Returns: + keys: list of keys of the corresponding global_indexes + + Raises: + TypeError: If `global_indexes` is not a list of int or an int + RuntimeError: If some indexes in `global_indexes` do not have corresponding keys + """ + + if isinstance(global_indexes, int): + global_indexes = [global_indexes] + elif isinstance(global_indexes, list): + if len(global_indexes) < 1: + raise ValueError("Received an empty list as `global_indexes`.") + # validate all the elements are int + if not all(isinstance(idx, int) for idx in global_indexes): + raise TypeError("Not all elements in `global_indexes` are int.") + else: + raise TypeError("Only int or list of int are allowed as `global_indexes`.") + + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_RETRIEVE_KEYS, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={"global_indexes": global_indexes, "partition_id": partition_id}, + ) + + try: + assert socket is not None + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart() + response_msg = ZMQMessage.deserialize(response_serialized) + logger.debug( + f"[{self.client_id}]: Client get kv_retrieve_indexes response: {response_msg} " + f"from controller {self._controller.id}" + ) + + if response_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE: + keys = response_msg.body.get("keys", []) + if len(keys) != 
len(global_indexes): + raise RuntimeError("Some global_indexes have no corresponding keys!") + return keys + else: + raise RuntimeError( + f"[{self.client_id}]: Failed to retrieve keys from controller {self._controller.id}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e + @dynamic_socket(socket_name="request_handle_socket") async def async_kv_list( self, @@ -1112,6 +1175,7 @@ def wrapper(*args, **kwargs): self._get_partition_list = _make_sync(self.async_get_partition_list) self._set_custom_meta = _make_sync(self.async_set_custom_meta) self._reset_consumption = _make_sync(self.async_reset_consumption) + self._kv_retrieve_meta = _make_sync(self.async_kv_retrieve_meta) self._kv_retrieve_keys = _make_sync(self.async_kv_retrieve_keys) self._kv_list = _make_sync(self.async_kv_list) @@ -1458,7 +1522,7 @@ def get_partition_list( return self._get_partition_list() # ==================== KV Interface API ==================== - def kv_retrieve_keys( + def kv_retrieve_meta( self, keys: list[str] | str, partition_id: str, @@ -1478,7 +1542,28 @@ def kv_retrieve_keys( TypeError: If `keys` is not a list of string or a string """ - return self._kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create) + return self._kv_retrieve_meta(keys=keys, partition_id=partition_id, create=create) + + def kv_retrieve_keys( + self, + global_indexes: list[int] | int, + partition_id: str, + ) -> list[str]: + """Synchronously retrieve keys according to global_indexes from the controller. + + Args: + global_indexes: List of global_indexes to retrieve from the controller + partition_id: The ID of the logical partition to search for global_indexes. 
+ + Returns: + keys: list of keys of the corresponding global_indexes + + Raises: + TypeError: If `global_indexes` is not a list of int or an int + RuntimeError: If some indexes in `global_indexes` do not have corresponding keys + """ + + return self._kv_retrieve_keys(global_indexes=global_indexes, partition_id=partition_id) def kv_list( self, diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index fdc0840e..15ff873c 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -845,11 +845,16 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr f"Attempted to clear global_indexes: {indexes_to_release}" ) - def kv_retrieve_keys(self, keys: list[str]) -> list[int | None]: + def kv_retrieve_indexes(self, keys: list[str]) -> list[int | None]: """Translate the user-specified keys to global_indexes""" global_indexes = [self.keys_mapping.get(k, None) for k in keys] return global_indexes + def kv_retrieve_keys(self, global_indexes: list[int]) -> list[str | None]: + """Translate the global_indexes to keys""" + keys = [self.revert_keys_mapping.get(idx, None) for idx in global_indexes] + return keys + @ray.remote(num_cpus=1) class TransferQueueController: @@ -1454,7 +1459,7 @@ def clear_meta( # Release the specific indexes from index manager self.index_manager.release_indexes(partition_id, global_indexes_to_clear) - def kv_retrieve_keys( + def kv_retrieve_meta( self, keys: list[str], partition_id: str, @@ -1485,12 +1490,12 @@ def kv_retrieve_keys( partition = self._get_partition(partition_id) assert partition is not None - global_indexes = partition.kv_retrieve_keys(keys) + global_indexes = partition.kv_retrieve_indexes(keys) none_indexes = [idx for idx, value in enumerate(global_indexes) if value is None] if len(none_indexes) > 0: if not create: - logger.warning(f"Keys {[keys[i] for i in none_indexes]} were not found in partition {partition_id}!") + logger.error(f"Keys {[keys[i] for i in none_indexes]} 
were not found in partition {partition_id}!") return BatchMeta.empty() else: # create non-exist keys @@ -1530,6 +1535,42 @@ def kv_retrieve_keys( return metadata + def kv_retrieve_keys( + self, + global_indexes: list[int], + partition_id: str, + ) -> list[Optional[str]]: + """ + Retrieve keys from the controller using a list of global_indexes. + + Args: + global_indexes: List of global_indexes to retrieve keys from the controller + partition_id: Partition id to retrieve from the controller + + Returns: + keys: list of keys of the corresponding global_indexes; empty if the partition or any global_index is not found + """ + + logger.debug(f"[{self.controller_id}]: Retrieve global_indexes {global_indexes} in partition {partition_id}") + + partition = self._get_partition(partition_id) + + if partition is None: + logger.warning(f"Partition {partition_id} were not found in controller!") + return [] + + assert partition is not None + keys = partition.kv_retrieve_keys(global_indexes) + + none_indexes = [idx for idx, value in enumerate(keys) if value is None] + if len(none_indexes) > 0: + logger.error( + f"Key for global_index {[global_indexes[i] for i in none_indexes]} were not found in partition {partition_id}!" 
+ ) + return [] + + return keys + def _init_zmq_socket(self): """Initialize ZMQ sockets for communication.""" self.zmq_context = zmq.Context() @@ -1825,21 +1866,35 @@ def _process_request(self): body={"partition_ids": partition_ids}, ) - elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: - with perf_monitor.measure(op_type="KV_RETRIEVE_KEYS"): + elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_META: + with perf_monitor.measure(op_type="KV_RETRIEVE_META"): params = request_msg.body keys = params["keys"] partition_id = params["partition_id"] create = params["create"] - metadata = self.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=create) + metadata = self.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=create) response_msg = ZMQMessage.create( - request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE, + request_type=ZMQRequestType.KV_RETRIEVE_META_RESPONSE, sender_id=self.controller_id, receiver_id=request_msg.sender_id, body={"metadata": metadata}, ) + elif request_msg.request_type == ZMQRequestType.KV_RETRIEVE_KEYS: + with perf_monitor.measure(op_type="KV_RETRIEVE_KEYS"): + params = request_msg.body + global_indexes = params["global_indexes"] + partition_id = params["partition_id"] + + keys = self.kv_retrieve_keys(global_indexes=global_indexes, partition_id=partition_id) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.KV_RETRIEVE_KEYS_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={"keys": keys}, + ) + elif request_msg.request_type == ZMQRequestType.KV_LIST: with perf_monitor.measure(op_type="KV_LIST"): params = request_msg.body diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 23d0bc9f..0bfc5916 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -288,7 +288,7 @@ def kv_put( tq_client = _maybe_create_transferqueue_client() # 1. 
translate user-specified key to BatchMeta - batch_meta = tq_client.kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True) + batch_meta = tq_client.kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True) if batch_meta.size != 1: raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!") @@ -364,7 +364,7 @@ def kv_batch_put( tq_client = _maybe_create_transferqueue_client() # 1. translate user-specified key to BatchMeta - batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True) + batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True) if batch_meta.size != len(keys): raise RuntimeError( @@ -416,7 +416,7 @@ def kv_batch_get(keys: list[str] | str, partition_id: str, fields: Optional[list """ tq_client = _maybe_create_transferqueue_client() - batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size == 0: raise RuntimeError("keys or partition were not found!") @@ -497,7 +497,7 @@ def kv_clear(keys: list[str] | str, partition_id: str) -> None: keys = [keys] tq_client = _maybe_create_transferqueue_client() - batch_meta = tq_client.kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + batch_meta = tq_client.kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size > 0: tq_client.clear_samples(batch_meta) @@ -548,7 +548,7 @@ async def async_kv_put( tq_client = _maybe_create_transferqueue_client() # 1. 
translate user-specified key to BatchMeta - batch_meta = await tq_client.async_kv_retrieve_keys(keys=[key], partition_id=partition_id, create=True) + batch_meta = await tq_client.async_kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True) if batch_meta.size != 1: raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!") @@ -623,7 +623,7 @@ async def async_kv_batch_put( tq_client = _maybe_create_transferqueue_client() # 1. translate user-specified key to BatchMeta - batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=True) + batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=True) if batch_meta.size != len(keys): raise RuntimeError( @@ -677,7 +677,7 @@ async def async_kv_batch_get( """ tq_client = _maybe_create_transferqueue_client() - batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size == 0: raise RuntimeError("keys or partition were not found!") @@ -759,7 +759,7 @@ async def async_kv_clear(keys: list[str] | str, partition_id: str) -> None: keys = [keys] tq_client = _maybe_create_transferqueue_client() - batch_meta = await tq_client.async_kv_retrieve_keys(keys=keys, partition_id=partition_id, create=False) + batch_meta = await tq_client.async_kv_retrieve_meta(keys=keys, partition_id=partition_id, create=False) if batch_meta.size > 0: await tq_client.async_clear_samples(batch_meta) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index eaaf65e4..7d571a5e 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -102,6 +102,8 @@ class ZMQRequestType(ExplicitEnum): NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR" # KV_INTERFACE + KV_RETRIEVE_META = "KV_RETRIEVE_META" + 
KV_RETRIEVE_META_RESPONSE = "KV_RETRIEVE_META_RESPONSE" KV_RETRIEVE_KEYS = "KV_RETRIEVE_KEYS" KV_RETRIEVE_KEYS_RESPONSE = "KV_RETRIEVE_KEYS_RESPONSE" KV_LIST = "KV_LIST"