diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 7cd24fc5..31c8e975 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -33,6 +33,7 @@ jobs: python -m pip install --upgrade pip pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -e ".[test,build,yuanrong]" + pip install mooncake-transfer-engine-non-cuda - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -43,11 +44,10 @@ jobs: run: | python -m build --wheel pip install dist/*.whl --force-reinstall - - name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=False) + - name: Test with pytest run: | pytest tests - - name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=True) - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=True - pytest tests \ No newline at end of file + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py + pkill -f "mooncake_master" + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py + pkill -f "mooncake_master" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 35d65242..1fba227f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,9 @@ test = [ yuanrong = [ "openyuanrong-datasystem" ] +mooncake = [ + "mooncake-transfer-engine" +] # If you need to mimic `package_dir={'': '.'}`: [tool.setuptools.package-dir] diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index 625bef9d..39b0b91e 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""E2E lifecycle consistency tests for TransferQueue.""" - +import os import sys import time from pathlib import Path @@ -23,6 +22,7 @@ import pytest import ray import torch +from omegaconf import OmegaConf from tensordict import TensorDict from tensordict.tensorclass import NonTensorData @@ -48,6 +48,38 @@ "non_tensor_stack", ] +# Backend configurations for E2E tests +BACKEND_CONFIGS = { + "SimpleStorage": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": { + "total_storage_size": 200, + "num_data_storage_units": 2, + }, + }, + }, + "MooncakeStore": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "MooncakeStore", + "MooncakeStore": { + "global_segment_size": 134217728, # 128MB + "local_buffer_size": 134217728, # 128MB + "metadata_server": "localhost:50050", + "master_server_address": "localhost:50051", + "protocol": "tcp", + "device_name": "", + }, + }, + }, +} + @pytest.fixture(scope="module") def ray_cluster(): @@ -59,24 +91,33 @@ def ray_cluster(): @pytest.fixture(scope="module") -def e2e_client(ray_cluster): - """Create a client using transfer_queue.init() for lifecycle testing.""" - from omegaconf import OmegaConf +def backend_name(): + """Get the backend name from environment variable. + + Environment variables: + TQ_TEST_BACKEND: Backend name (SimpleStorage or MooncakeStore) + + To run tests for a specific backend: + TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_e2e_lifecycle_consistency.py + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py + """ + return os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") + + +@pytest.fixture(scope="module") +def e2e_client(ray_cluster, backend_name): + """Create a client using transfer_queue.init() for lifecycle testing. 
+ Args: + ray_cluster: Ray cluster fixture + backend_name: Backend name from TQ_TEST_BACKEND env var + """ import transfer_queue - config = { - "controller": { - "polling_mode": True, - }, - "backend": { - "storage_backend": "SimpleStorage", - "SimpleStorage": { - "total_storage_size": 200, - "num_data_storage_units": 2, - }, - }, - } + if backend_name not in BACKEND_CONFIGS: + raise ValueError(f"Unknown backend: {backend_name}. Available: {list(BACKEND_CONFIGS.keys())}") + + config = BACKEND_CONFIGS[backend_name] transfer_queue.init(OmegaConf.create(config)) client = transfer_queue.get_client() yield client @@ -244,7 +285,7 @@ def verify_list_equal(retrieved, expected) -> bool: if isinstance(retrieved, NonTensorStack): retrieved = retrieved.tolist() elif isinstance(retrieved, torch.Tensor): - retrieved = retrieved.tolist() + retrieved = retrieved.reshape(-1).tolist() # may get 2D tensor back using key-value based backend if isinstance(expected, NonTensorStack): expected = expected.tolist() elif isinstance(expected, torch.Tensor): @@ -283,9 +324,21 @@ def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict: return TensorDict(reordered, batch_size=td.batch_size) +def recover_local_index(global_index_order, new_global_index_order): + value_to_new_index = {} + for idx, val in enumerate(new_global_index_order): + value_to_new_index[val] = idx + + local_index_order_to_recover = [] + for val in global_index_order: + local_index_order_to_recover.append(value_to_new_index[val]) + + return local_index_order_to_recover + + # Scenario One: Core Read/Write Consistency def test_core_consistency(e2e_client): - """Put full complex data then get — verify all field types are correctly round-tripped.""" + """Put full complex data then get - verify all field types are correctly round-tripped.""" client = e2e_client partition_id = "test_core_consistency" batch_size = 20 @@ -362,6 +415,12 @@ def test_core_consistency(e2e_client): # Scenario Two: Cross-Shard Update def 
test_cross_shard_complex_update(e2e_client): """Cross-shard update: put A + put B, update overlapping region, verify all regions.""" + + # FIXME: Add data update test to MooncakeStore after Upsert function is ready + # https://github.com/kvcache-ai/Mooncake/issues/1645 + if os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") == "MooncakeStore": + return + client = e2e_client partition_id = "test_cross_shard_update" task_name = "cross_shard_task" @@ -744,12 +803,19 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client): indices = list(range(batch_size)) original_data = generate_complex_data(indices) - client.put(data=original_data, partition_id=partition_id) + original_meta = client.put(data=original_data, partition_id=partition_id) + global_index_order = original_meta.global_indexes try: # === Phase 1: Retrieve and verify writability === meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch") assert meta is not None and meta.size == batch_size + + # the global_index_order in retrieved meta is different from the original one. + # we need to reorder first. + local_index_order = recover_local_index(global_index_order, meta.global_indexes) + meta = meta.select_samples(local_index_order) + retrieved = client.get_data(meta) # 1. tensor_f32: writable @@ -793,6 +859,12 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client): # Re-retrieve the same data — modifications above should NOT have affected storage meta2 = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch") assert meta2 is not None and meta2.size == batch_size + + # the global_index_order in retrieved meta is different from the original one. + # we need to reorder first. 
+ local_index_order = recover_local_index(global_index_order, meta2.global_indexes) + meta2 = meta2.select_samples(local_index_order) + retrieved2 = client.get_data(meta2) # tensor_f32[0,0] should be the original value, not 99999.0 diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 71e9833a..0a248ea7 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -27,6 +27,7 @@ import pytest import ray import torch +from omegaconf import OmegaConf from tensordict import TensorDict # Add parent directory to path @@ -38,6 +39,40 @@ # Configure Ray for tests os.environ["RAY_DEDUP_LOGS"] = "0" +# Backend configurations for E2E tests +# Adjust values for GitHub CI environment (smaller memory footprint) +BACKEND_CONFIGS = { + "SimpleStorage": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": { + "total_storage_size": 200, + "num_data_storage_units": 2, + }, + }, + }, + "MooncakeStore": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "MooncakeStore", + "MooncakeStore": { + # Reduced memory sizes for CI/testing environment + "global_segment_size": 134217728, # 128MB + "local_buffer_size": 134217728, # 128MB + "metadata_server": os.environ.get("TQ_MOONCAKE_METADATA_SERVER", "localhost:50050"), + "master_server_address": os.environ.get("TQ_MOONCAKE_MASTER_SERVER", "localhost:50051"), + "protocol": "tcp", + "device_name": "", + }, + }, + }, +} + @pytest.fixture(scope="module") def ray_init(): @@ -50,9 +85,32 @@ def ray_init(): @pytest.fixture(scope="module") -def tq_system(ray_init): - """Initialize TransferQueue system for the test module.""" - tq.init() +def backend_name(): + """Get the backend name from environment variable. 
+ + Environment variables: + TQ_TEST_BACKEND: Backend name (SimpleStorage or MooncakeStore) + + To run tests for a specific backend: + TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_kv_interface_e2e.py + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py + """ + return os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") + + +@pytest.fixture(scope="module") +def tq_system(ray_init, backend_name): + """Initialize TransferQueue system for the test module. + + Args: + ray_init: Ray cluster fixture + backend_name: Backend name from TQ_TEST_BACKEND env var + """ + if backend_name not in BACKEND_CONFIGS: + raise ValueError(f"Unknown backend: {backend_name}. Available: {list(BACKEND_CONFIGS.keys())}") + + config = BACKEND_CONFIGS[backend_name] + tq.init(OmegaConf.create(config)) yield tq.close() @@ -109,6 +167,9 @@ def test_kv_put_with_dict_fields(self, controller): expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed assert_tensor_equal(retrieved["data"], expected) + # delete the key (MooncakeStore does not support updating existing key, so we need to clear it before next test) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_put_with_tensordict_fields(self, controller): """Test kv_put with tensordict fields.""" partition_id = "test_partition" @@ -128,6 +189,8 @@ def test_kv_put_with_tensordict_fields(self, controller): expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed assert_tensor_equal(retrieved["input_ids"], expected) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_put_single_sample_with_fields_and_tag(self, controller): """Test putting a single sample with fields and tag.""" partition_id = "test_partition" @@ -175,6 +238,8 @@ def test_kv_put_single_sample_with_fields_and_tag(self, controller): assert_tensor_equal(retrieved["input_ids"], expected_input_ids) assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask) + tq.kv_clear(keys=key, partition_id=partition_id) + def 
test_kv_put_update_tag_only(self, controller): """Test updating only tag without providing fields.""" partition_id = "test_partition" @@ -198,6 +263,8 @@ def test_kv_put_update_tag_only(self, controller): retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) assert_tensor_equal(retrieved["value"], torch.tensor([[10]])) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_put_partial_update(self, controller): """Test adding new fields to existing sample.""" partition_id = "test_partition" @@ -232,6 +299,8 @@ def test_kv_put_partial_update(self, controller): # key should have response marked as produced assert partition.production_status[global_idx, response_col_idx] == 1, "Key should have response" + tq.kv_clear(keys=key, partition_id=partition_id) + class TestKVBatchPutE2E: """End-to-end tests for kv_batch_put functionality.""" @@ -282,6 +351,8 @@ def test_kv_batch_put_multiple_samples(self, controller): assert_tensor_equal(retrieved["input_ids"], batch_input_ids) assert_tensor_equal(retrieved["attention_mask"], batch_attention_mask) + tq.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_batch_put_partial_update(self, controller): """Test adding new fields to existing samples.""" partition_id = "test_partition" @@ -320,6 +391,8 @@ def test_kv_batch_put_partial_update(self, controller): # keys[1] should have response marked as produced assert partition.production_status[global_idx_1, response_col_idx] == 1, "Keys[1] should have response" + tq.kv_clear(keys=keys, partition_id=partition_id) + class TestKVGetE2E: """End-to-end tests for kv_batch_get functionality.""" @@ -337,6 +410,8 @@ def test_kv_batch_get_single_key(self, controller): retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_batch_get_multiple_keys(self, controller): """Test getting data for multiple keys.""" partition_id = "test_partition" 
@@ -349,6 +424,8 @@ def test_kv_batch_get_multiple_keys(self, controller): retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) + tq.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_batch_get_partial_keys(self, controller): """Test getting data for partial keys.""" partition_id = "test_partition" @@ -363,6 +440,8 @@ def test_kv_batch_get_partial_keys(self, controller): retrieved = tq.kv_batch_get(keys=partial_keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) + tq.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_batch_get_partial_fields(self, controller): """Test getting only partial fields.""" partition_id = "test_partition" @@ -394,6 +473,8 @@ def test_kv_batch_get_partial_fields(self, controller): assert_tensor_equal(retrieved["input_ids"], input_ids) assert_tensor_equal(retrieved["response"], response) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_batch_get_nonexistent_key(self, controller): """Test that getting data for non-existent key returns empty result.""" partition_id = "test_partition" @@ -432,6 +513,8 @@ def test_kv_list_single_partition(self, controller): for i, (key, tag) in enumerate(partition_info["test_partition"].items()): assert tag["id"] == i + tq.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_list_all_partitions(self, controller): """Test listing keys and tags in all partitions.""" partition_id = ["test_partition0", "test_partition1", "test_partition2"] @@ -488,6 +571,10 @@ def test_kv_list_all_partitions(self, controller): for i, (key, tag) in enumerate(partition_info["test_partition2"].items()): assert tag["id"] == i + 6 + tq.kv_clear(keys=keys_partition0, partition_id=partition_id[0]) + tq.kv_clear(keys=keys_partition1, partition_id=partition_id[1]) + tq.kv_clear(keys=keys_partition2, partition_id=partition_id[2]) + def test_kv_list_empty_partition(self): """Test listing empty 
partition.""" partition_id = "test_partition_empty" @@ -522,6 +609,8 @@ def test_kv_clear_single_key(self, controller): assert key not in partition.keys_mapping assert other_key in partition.keys_mapping + tq.kv_clear(keys=other_key, partition_id=partition_id) + def test_kv_clear_multiple_keys(self, controller): """Test clearing multiple keys.""" partition_id = "test_partition" @@ -541,6 +630,8 @@ def test_kv_clear_multiple_keys(self, controller): assert keys[2] in partition_info[partition_id] assert keys[3] in partition_info[partition_id] + tq.kv_clear(keys=keys[2:], partition_id=partition_id) + class TestKVE2ECornerCases: """End-to-end tests for corner cases.""" @@ -578,6 +669,8 @@ def test_field_expansion_across_samples(self, controller): assert "field_b" not in data assert "field_c" not in data + tq.kv_clear(keys=keys, partition_id=partition_id) + def run_tests(): """Run all e2e tests manually for debugging.""" diff --git a/tests/test_controller.py b/tests/test_controller.py index 74793bd4..2e559600 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -176,14 +176,14 @@ def test_controller_with_single_partition(self, ray_setup): # Test get clear meta clear_meta = ray.get( tq_controller.get_metadata.remote( - data_fields=[], + data_fields=gen_meta.field_names, partition_id=partition_id, - mode="insert", + mode="force_fetch", ) ) assert clear_meta.global_indexes == list(range(gbs * num_n_samples)) # In insert mode with no fields, field_schema should be empty - assert clear_meta.field_schema == {} or clear_meta.field_names == [] + assert clear_meta.field_names == gen_meta.field_names print("✓ Clear metadata correct") # Test clear_partition @@ -431,9 +431,9 @@ def test_controller_with_multi_partitions(self, ray_setup): # Test get clear meta clear_meta = ray.get( tq_controller.get_metadata.remote( - data_fields=[], + data_fields=gen_meta.field_names, partition_id=partition_id_1, - mode="insert", + mode="force_fetch", ) ) assert clear_meta diff 
--git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index a5776924..904f4320 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -1057,10 +1057,11 @@ def test_remove_samples(self): fm = FieldMeta(is_nested=True) fm.per_sample_shapes = {0: (3,), 1: (5,), 2: (7,)} fm.remove_samples([0, 2]) - assert fm.per_sample_shapes == {1: (5,)} + assert fm.per_sample_shapes == {} + assert fm.shape == (5,) + assert not fm.is_nested # Removing non-existent index should not raise fm.remove_samples([99]) - assert fm.per_sample_shapes == {1: (5,)} def test_to_batch_schema_regular(self): from transfer_queue.controller import FieldMeta diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 199ceeda..235c9b07 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -231,7 +231,7 @@ async def async_get_meta( ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get_meta response: {response_msg} from controller {self._controller.id}" @@ -307,7 +307,7 @@ async def async_set_custom_meta( ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client set_custom_meta response: {response_msg} from controller {self._controller.id}" @@ -387,6 +387,13 @@ async def async_put( "Call initialize_storage_manager() before performing storage operations." 
) + for field_name, field_data in data.items(): + if isinstance(field_data, torch.Tensor) and field_data.ndim == 1: + logger.warning( + f"[{self.client_id}]: Data field '{field_name}' is a tensor with only one dimension. " + f"You may receive 2D tensors in key-value based backend." + ) + if metadata is None: if partition_id is None: raise ValueError("partition_id must be provided if metadata is not given") @@ -480,6 +487,10 @@ async def async_clear_partition(self, partition_id: str): metadata = await self._get_partition_meta(partition_id) + if not metadata: + logger.warning(f"Try to clear an non-exist partition {partition_id}. No action will be taken.") + return + # Clear the controller metadata await self._clear_partition_in_controller(partition_id) @@ -543,7 +554,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE: @@ -571,7 +582,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) if response_msg.request_type != ZMQRequestType.GET_PARTITION_META_RESPONSE: @@ -599,7 +610,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) if response_msg.request_type != ZMQRequestType.CLEAR_PARTITION_RESPONSE: @@ -650,7 +661,7 @@ async def 
async_get_consumption_status( try: await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get consumption response: {response_msg} " @@ -712,7 +723,7 @@ async def async_get_production_status( try: await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get production response: {response_msg} " @@ -844,7 +855,7 @@ async def async_reset_consumption( ) try: await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client reset consumption response: {response_msg} " @@ -890,7 +901,7 @@ async def async_get_partition_list( try: assert socket is not None await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get partition list response: {response_msg} " @@ -957,7 +968,7 @@ async def async_kv_retrieve_meta( try: assert socket is not None await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get kv_retrieve_keys response: {response_msg} " @@ -1018,7 +1029,7 @@ async def async_kv_retrieve_keys( try: assert socket is 
not None await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get kv_retrieve_indexes response: {response_msg} " @@ -1079,7 +1090,7 @@ async def async_kv_list( try: assert socket is not None await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get kv_list response: {response_msg} from controller {self._controller.id}" diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml index c0ddfe76..98819edd 100644 --- a/transfer_queue/config.yaml +++ b/transfer_queue/config.yaml @@ -24,10 +24,27 @@ backend: # ZMQ Server IP & Ports (automatically generated during init) zmq_info: null + # For MooncakeStore: + MooncakeStore: + # Whether to let TQ automatically init metadata_server. + auto_init: true + # Address of the HTTP metadata server + metadata_server: localhost:50050 + # Address of master server + master_server_address: localhost:50051 + # Address of local host. Set to "" to use Ray IP as local host address + local_hostname: "" + # Protocol for transmission. Choose from: tcp, rdma. (default: tcp) + protocol: tcp + # Memory segment size in bytes for mounting (default: 4GB) + global_segment_size: 4294967296 + # Local buffer size in bytes (default: 1GB) + local_buffer_size: 1073741824 + # Network device name. 
Set to "" to let Mooncake auto-pick devices + device_name: "" + + # For RayStore: RayStore: # For Yuanrong: - # TODO - - # For MooncakeStore: # TODO \ No newline at end of file diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index c74e515a..b0ef7572 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -251,6 +251,23 @@ def remove_samples(self, indexes: list[int]): for idx in indexes: self.per_sample_shapes.pop(idx, None) + # After removing samples, check if we can update is_nested and shape + # If per_sample_shapes is empty or all remaining shapes are the same, + # we should reset is_nested to False and update shape accordingly + if not self.per_sample_shapes: + # All samples removed - reset to non-nested state + self.is_nested = False + self.shape = None + else: + # Check if all remaining shapes are the same + remaining_shapes = set(self.per_sample_shapes.values()) + if len(remaining_shapes) == 1: + # All remaining samples have the same shape - update to non-nested + self.is_nested = False + self.shape = next(iter(remaining_shapes)) + # Clear per-sample shapes since we are no longer nested + self.per_sample_shapes.clear() + def to_batch_schema(self, batch_global_indexes: list[int]) -> dict[str, Any]: """Export as a BatchMeta.field_schema-compatible dict for generate_batch_meta.""" schema = { @@ -529,7 +546,24 @@ def _update_field_metadata( is_non_tensor=meta.get("is_non_tensor", False), ) else: + # Track if is_nested changed from False to True during update + was_not_nested = not self.field_metadata[field_name].is_nested + # Save old shape before update (for filling per_sample_shapes of existing samples) + old_shape = self.field_metadata[field_name].shape self.field_metadata[field_name].update(meta) + # If is_nested became True due to shape mismatch, capture shapes for all samples + if was_not_nested and self.field_metadata[field_name].is_nested: + col_meta = self.field_metadata[field_name] + new_shape = 
meta.get("shape") + # Fill new samples with new shape + if new_shape is not None: + for gi in global_indexes: + col_meta.per_sample_shapes[gi] = new_shape + # Fill existing samples with old shape + if old_shape is not None: + for gi in self.global_indexes: + if gi not in col_meta.per_sample_shapes: + col_meta.per_sample_shapes[gi] = old_shape # nested per-sample shapes per_sample_shapes = meta.get("per_sample_shapes") @@ -1214,31 +1248,31 @@ def get_metadata( self.create_partition(partition_id) partition = self._get_partition(partition_id) - if data_fields: - # This is called during put_data call without providing metadata. - # try to use pre-allocated global index first + if data_fields is None: + raise RuntimeError("Must provide data_fields for inserting new data") - if batch_size is None: - raise ValueError("must provide batch_size for inserting new data") + # This is called during put_data call without providing metadata. + # try to use pre-allocated global index first - assert partition is not None - batch_global_indexes = partition.activate_pre_allocated_indexes(batch_size) + if batch_size is None: + raise ValueError("must provide batch_size for inserting new data") - if len(batch_global_indexes) < batch_size: - new_global_indexes = self.index_manager.allocate_indexes( - partition_id, count=(batch_size - len(batch_global_indexes)) - ) - batch_global_indexes.extend(new_global_indexes) + assert partition is not None + batch_global_indexes = partition.activate_pre_allocated_indexes(batch_size) - # register global_indexes in partition - partition.global_indexes.update(batch_global_indexes) + if len(batch_global_indexes) < batch_size: + new_global_indexes = self.index_manager.allocate_indexes( + partition_id, count=(batch_size - len(batch_global_indexes)) + ) + batch_global_indexes.extend(new_global_indexes) + + # register global_indexes in partition + partition.global_indexes.update(batch_global_indexes) - else: - batch_global_indexes = 
self.index_manager.get_indexes_for_partition(partition_id) return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) - assert task_name is not None if mode == "fetch": + assert task_name is not None # Find ready samples within current data partition and package into BatchMeta when reading if batch_size is None: @@ -1288,6 +1322,11 @@ def get_metadata( f"after sampling: {len(batch_global_indexes)}" ) + # Mark samples as consumed if in fetch mode + if consumed_indexes: + partition = self.partitions[partition_id] + partition.mark_consumed(task_name, consumed_indexes) + elif mode == "force_fetch": batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) consumed_indexes = [] @@ -1295,11 +1334,6 @@ def get_metadata( # Package into metadata metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) - # Mark samples as consumed if in fetch mode - if mode == "fetch" and consumed_indexes: - partition = self.partitions[partition_id] - partition.mark_consumed(task_name, consumed_indexes) - return metadata def scan_data_status( @@ -1779,12 +1813,18 @@ def _process_request(self): with perf_monitor.measure(op_type="GET_PARTITION_META"): params = request_msg.body partition_id = params["partition_id"] + partition = self._get_partition(partition_id) + if partition is not None: + partition_data_fields = list(partition.field_name_mapping.keys()) + + metadata = self.get_metadata( + data_fields=partition_data_fields, + partition_id=partition_id, + mode="force_fetch", + ) + else: + metadata = None - metadata = self.get_metadata( - data_fields=[], - partition_id=partition_id, - mode="insert", - ) response_msg = ZMQMessage.create( request_type=ZMQRequestType.GET_PARTITION_META_RESPONSE, sender_id=self.controller_id, diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 97f25758..d0fd2f7f 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -13,12 +13,14 @@ 
# See the License for the specific language governing permissions and # limitations under the License. -import importlib.resources as pkg_resources import logging import math import os +import subprocess import time +from importlib import resources from typing import Any, Optional +from urllib.parse import urlparse import ray import torch @@ -73,6 +75,7 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: _TRANSFER_QUEUE_STORAGE = {} if conf.backend.storage_backend == "SimpleStorage": # initialize SimpleStorageUnit + simple_storage_handles = {} num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units total_storage_size = conf.backend.SimpleStorage.total_storage_size storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) @@ -86,13 +89,84 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: ).remote( storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), ) - _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node + simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") - storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_STORAGE) + storage_zmq_info = process_zmq_server_info(simple_storage_handles) backend_name = conf.backend.storage_backend conf.backend[backend_name].zmq_info = storage_zmq_info - + _TRANSFER_QUEUE_STORAGE["SimpleStorage"] = simple_storage_handles + if conf.backend.storage_backend == "MooncakeStore": + if conf.backend.MooncakeStore.auto_init: + # Try to kill existing mooncake_master processes before starting a new one to avoid potential conflicts + check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True) + if check.returncode == 0: + pids = check.stdout.strip().replace("\n", ", ") + logging.info(f"Find existing mooncake_master (PID: {pids}), try to 
kill first...") + + result = os.system('pkill -f "[m]ooncake_master"') + if result == 0: + logging.info("Successfully killed existing mooncake_master processes.") + else: + raise RuntimeError(f"Failed to kill existing mooncake_master processes (exit code: {result}).") + + raw_address = conf.backend.MooncakeStore.metadata_server + if "://" not in raw_address: + raw_address = "//" + raw_address + + parsed = urlparse(raw_address) + + if not parsed.hostname or parsed.port is None: + raise ValueError( + f"Invalid metadata_server '{conf.backend.MooncakeStore.metadata_server}'. " + f"Host and port are required (e.g., host:port)." + ) + + metadata_server_host = parsed.hostname + metadata_server_port = str(parsed.port) + + cmd = [ + "mooncake_master", + "-default_kv_lease_ttl=999999", + "-default_kv_soft_pin_ttl=999999", + "--eviction_high_watermark_ratio=1.0", + "--eviction_ratio=0.0", + "--enable_http_metadata_server=true", + "--allow_evict_soft_pinned_objects=false", + f"--http_metadata_server_host={metadata_server_host}", + f"--http_metadata_server_port={metadata_server_port}", + ] + + log_file_path = "/tmp/mooncake_master.log" + with open(log_file_path, "w") as log_file: + process = subprocess.Popen( + cmd, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True, + start_new_session=True, + ) + time.sleep(3) + + if process.poll() is None: + logger.info( + f"mooncake_master started, PID: {process.pid}. Logs are at: {os.path.abspath(log_file_path)}" + ) + else: + error_msg = "" + try: + with open(log_file_path) as f: + error_msg = f.read() + except Exception as e: + error_msg = f"Failed to read log file: {e}" + + raise RuntimeError( + f"mooncake_master exited with error. Check {log_file_path} for detailed logs. 
" + f"Output:\n{error_msg}" + ) + _TRANSFER_QUEUE_STORAGE["MooncakeStore"] = process return conf @@ -167,8 +241,7 @@ def init(conf: Optional[DictConfig] = None) -> None: # create config final_conf = OmegaConf.create({}, flags={"allow_objects": True}) - with pkg_resources.path("transfer_queue", "config.yaml") as p: - default_conf = OmegaConf.load(p) + default_conf = OmegaConf.load(resources.files("transfer_queue") / "config.yaml") final_conf = OmegaConf.merge(final_conf, default_conf) if conf: final_conf = OmegaConf.merge(final_conf, conf) @@ -229,19 +302,43 @@ def close(): global _TRANSFER_QUEUE_CLIENT global _TRANSFER_QUEUE_STORAGE global _TRANSFER_QUEUE_CONTROLLER - if _TRANSFER_QUEUE_CLIENT: - _TRANSFER_QUEUE_CLIENT.close() - _TRANSFER_QUEUE_CLIENT = None try: if _TRANSFER_QUEUE_STORAGE: - # only the process that do first-time init can clean the distributed storage - for storage in _TRANSFER_QUEUE_STORAGE.values(): - ray.kill(storage) + for key, value in _TRANSFER_QUEUE_STORAGE.items(): + if key == "SimpleStorage": + # only the process that do first-time init can clean the distributed storage + for storage in value.values(): + ray.kill(storage) + elif key == "MooncakeStore": + check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True) + if check.returncode == 0: + pids = check.stdout.strip().replace("\n", ", ") + logger.warning( + f"TransferQueue will not stop mooncake_master process with PID: {pids}. " + f"Consider manually killing the mooncake_master." 
+ ) + + if _TRANSFER_QUEUE_CLIENT: + try: + ret = _TRANSFER_QUEUE_CLIENT.storage_manager.storage_client._store.remove_all() + if ret < 0: + logger.error("Failed to remove existing keys in mooncake_master.") + else: + logger.info("Successfully removed all existing keys in mooncake_master.") + except Exception: + pass + else: + logger.warning(f"close for _TRANSFER_QUEUE_STORAGE with key {key} is not supported for now.") + _TRANSFER_QUEUE_STORAGE = None except Exception: pass + if _TRANSFER_QUEUE_CLIENT: + _TRANSFER_QUEUE_CLIENT.close() + _TRANSFER_QUEUE_CLIENT = None + if _TRANSFER_QUEUE_CONTROLLER: try: ray.kill(_TRANSFER_QUEUE_CONTROLLER) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index dd601362..24248e97 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -141,6 +141,63 @@ def __iter__(self): return (_SampleView(self._batch, i) for i in range(len(self))) +def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: + """Extract field-level schema from TensorDict.""" + field_schema: dict[str, dict[str, Any]] = {} + batch_size = data.batch_size[0] + + for field_name, value in data.items(): + is_tensor = isinstance(value, torch.Tensor) + is_nested = is_tensor and value.is_nested + + first_item = None + if is_nested: + unbound = value.unbind() + first_item = unbound[0] if unbound else None + elif is_tensor: + first_item = value[0] if value.shape[0] > 0 else None + else: + first_item = value[0] if len(value) > 0 else None + + # Determine is_non_tensor: when first_item is None (empty field), cannot determine type + if first_item is None: + is_non_tensor = None + else: + is_non_tensor = not is_tensor + + # Determine the shape of each sample (excluding batch dimension) + # When TensorDict converts a Python list to tensor, the first dimension equals batch_size + # We need to strip this batch dimension to get per-sample shape + if isinstance(value, torch.Tensor) and not is_nested and value.shape[0] > 0: + if 
value.shape[0] != batch_size: + raise ValueError( + f"Inconsistent batch dimension for field '{field_name}': " + f"expected batch_size[0]={batch_size}, got value.shape[0]={value.shape[0]}" + ) + if len(value.shape) > 1: + sample_shape = value.shape[1:] + else: + # When input is 1D tensor, manually set to torch.Size([1]). + sample_shape = torch.Size([1]) + else: + sample_shape = getattr(first_item, "shape", None) if first_item is not None else None + + field_meta = { + "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), + "shape": sample_shape, + "is_nested": is_nested, + "is_non_tensor": is_non_tensor, + } + + # For nested tensors, record per-sample shapes + if is_nested: + field_meta["per_sample_shapes"] = [tuple(t.shape) for t in value.unbind()] + + field_schema[field_name] = field_meta + + return field_schema + + @dataclass class BatchMeta: """Records the metadata of a batch of data samples with optimized field-level schema. @@ -160,9 +217,9 @@ class BatchMeta: global_indexes: list[int] partition_ids: list[str] - # O(F) field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}} + # field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}} field_schema: dict[str, dict[str, Any]] = dataclasses.field(default_factory=dict) - # O(B) vectorized production status; always np.ndarray after __post_init__ (never None) + # vectorized production status matrix production_status: np.ndarray = dataclasses.field(default=None, repr=False) # type: ignore[assignment] extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) # user-defined meta for each sample (sample-level), list aligned with global_indexes @@ -387,33 +444,10 @@ def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "Ba if batch_size != self.size: raise ValueError(f"add_fields batch size mismatch: self.size={self.size} vs tensor_dict={batch_size}") - for name, value in tensor_dict.items(): - # Determine if 
this is a nested tensor - is_nested = isinstance(value, torch.Tensor) and value.is_nested - - first_item = None - if is_nested: - unbound = value.unbind() - first_item = unbound[0] if unbound else None - else: - first_item = value[0] if len(value) > 0 else None - - # Determine if this is non-tensor data. - # When first_item is None (empty field), we cannot determine type—leave as None. - is_non_tensor = not isinstance(first_item, torch.Tensor) if first_item is not None else None - - field_meta = { - "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), - "shape": getattr(first_item, "shape", None) if not is_nested else None, - "is_nested": is_nested, - "is_non_tensor": is_non_tensor, - } - - # For nested tensors, record per-sample shapes - if is_nested: - field_meta["per_sample_shapes"] = [tuple(t.shape) for t in value.unbind()] + field_schema = extract_field_schema(tensor_dict) - self.field_schema[name] = field_meta + for key, value in field_schema.items(): + self.field_schema[key] = value if set_all_ready: self.production_status[:] = 1 diff --git a/transfer_queue/storage/clients/__init__.py b/transfer_queue/storage/clients/__init__.py index 93e81114..2b861166 100644 --- a/transfer_queue/storage/clients/__init__.py +++ b/transfer_queue/storage/clients/__init__.py @@ -16,7 +16,7 @@ # This module is currently empty but reserved for future client implementations from .base import TransferQueueStorageKVClient from .factory import StorageClientFactory -from .mooncake_client import MooncakeStorageClient +from .mooncake_client import MooncakeStoreClient from .ray_storage_client import RayStorageClient from .yuanrong_client import YuanrongStorageClient @@ -24,6 +24,6 @@ "TransferQueueStorageKVClient", "StorageClientFactory", "RayStorageClient", - "MooncakeStorageClient", + "MooncakeStoreClient", "YuanrongStorageClient", ] diff --git a/transfer_queue/storage/clients/mooncake_client.py 
b/transfer_queue/storage/clients/mooncake_client.py index 4b4d9a3b..a6273210 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -36,8 +36,8 @@ BATCH_SIZE_LIMIT: int = 500 -@StorageClientFactory.register("MooncakeStorageClient") -class MooncakeStorageClient(TransferQueueStorageKVClient): +@StorageClientFactory.register("MooncakeStoreClient") +class MooncakeStoreClient(TransferQueueStorageKVClient): """ Storage client for MooncakeStore. """ @@ -46,13 +46,36 @@ def __init__(self, config: dict[str, Any]): if not MOONCAKE_STORE_IMPORTED: raise ImportError("Mooncake Store not installed. Please install via: pip install mooncake-transfer-engine") - self.local_hostname = config.get("local_hostname", "localhost") - self.metadata_server = config.get("metadata_server") - self.global_segment_size = config.get("global_segment_size", 512 * 1024 * 1024) - self.local_buffer_size = config.get("local_buffer_size", 128 * 1024 * 1024) + # Required: Address of local host + self.local_hostname = config.get("local_hostname", "") + # Required: Address of the HTTP metadata server (e.g., "localhost:8080") + self.metadata_server = config.get("metadata_server", None) + # Required: Address of the master server RPC endpoint (e.g., "localhost:8081") + self.master_server_address = config.get("master_server_address") + + self.global_segment_size = int(config.get("global_segment_size", 4096 * 1024 * 1024)) + self.local_buffer_size = int(config.get("local_buffer_size", 1024 * 1024 * 1024)) self.protocol = config.get("protocol", "tcp") self.device_name = config.get("device_name", "") - self.master_server_address = config.get("master_server_address") + if self.device_name is None: + self.device_name = "" + + if self.local_hostname is None or self.local_hostname == "": + from transfer_queue.utils.zmq_utils import get_node_ip_address_raw + + ip = get_node_ip_address_raw() + logger.info(f"Try to use Ray IP ({ip}) as local hostname for 
MooncakeStore.") + self.local_hostname = ip + + if self.metadata_server is None or not isinstance(self.metadata_server, str): + raise ValueError("Missing or invalid 'metadata_server' in config") + if self.master_server_address is None or not isinstance(self.master_server_address, str): + raise ValueError("Missing or invalid 'master_server_address' in config") + + if not self.metadata_server.startswith("http://") and not self.metadata_server.startswith("etcd://"): + self.metadata_server = f"http://{self.metadata_server}" + if not self.metadata_server.startswith("etcd://") and not self.metadata_server.endswith("/metadata"): + self.metadata_server = self.metadata_server + "/metadata" if self.metadata_server is None: raise ValueError("Missing 'metadata_server' in config") @@ -146,7 +169,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non """ if shapes is None or dtypes is None: - raise ValueError("MooncakeStorageClient needs shapes and dtypes") + raise ValueError("MooncakeStoreClient needs shapes and dtypes") if not (len(keys) == len(shapes) == len(dtypes)): raise ValueError("Lengths of keys, shapes, dtypes must match") @@ -219,14 +242,23 @@ def _batch_get_bytes(self, keys: list[str]) -> list[bytes]: def clear(self, keys: list[str], custom_backend_meta=None): """Deletes multiple keys from MooncakeStore. + Args: keys (List[str]): List of keys to remove. custom_backend_meta (List[Any], optional): ... 
""" - for key in keys: - ret = self._store.remove(key) - if ret != 0: - logger.warning(f"remove failed for key '{key}' with error code: {ret}") + global_indexes_patterns = {key.split("@")[0] + "@.*" for key in keys} + for p in global_indexes_patterns: + ret = self._store.remove_by_regex(p, force=True) + if ret < 0: + logger.warning(f"remove failed for key '{p}' with error code: {ret}") + + # FIXME: controller returned BatchMeta may have mismatched fields in some case, preventing + # key-value based backends to accurately clear all existing keys.. + # for key in keys: + # ret = self._store.remove(key) + # if not (ret == 0 or ret == -704): + # logger.warning(f"remove failed for key '{key}' with error code: {ret}") def close(self): """Closes MooncakeStore.""" diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 28b35b36..d33c591f 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -32,7 +32,7 @@ from tensordict import NonTensorStack, TensorDict from torch import Tensor -from transfer_queue.metadata import BatchMeta +from transfer_queue.metadata import BatchMeta, extract_field_schema from transfer_queue.storage.clients.factory import StorageClientFactory from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket @@ -243,7 +243,7 @@ async def notify_data_update( while not response_received and timeout > 0: try: poll_interval = min(TQ_STORAGE_POLLER_TIMEOUT, timeout) - messages = await asyncio.wait_for(sock.recv_multipart(), timeout=poll_interval) + messages = await asyncio.wait_for(sock.recv_multipart(copy=False), timeout=poll_interval) response_msg = ZMQMessage.deserialize(messages) if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: # type: ignore[arg-type] @@ -313,41 +313,6 @@ async def clear_data(self, metadata: BatchMeta) -> None: """ raise NotImplementedError("Subclasses must implement clear_data") - 
@staticmethod - def _extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: - """Extract field-level schema from TensorDict. O(F) complexity.""" - field_schema: dict[str, dict[str, Any]] = {} - - for field_name in data.keys(): - field_data = data[field_name] - - is_tensor = isinstance(field_data, torch.Tensor) - is_nested = is_tensor and field_data.is_nested - - if is_nested: - unbound = field_data.unbind() - first_item = unbound[0] if unbound else None - elif is_tensor: - first_item = field_data[0] if field_data.shape[0] > 0 else None - else: - first_item = field_data[0] if len(field_data) > 0 else None - - is_non_tensor = not isinstance(first_item, torch.Tensor) if first_item is not None else False - - field_meta = { - "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), - "shape": getattr(first_item, "shape", None) if is_tensor and not is_nested else None, - "is_nested": is_nested, - "is_non_tensor": is_non_tensor, - } - - if is_nested: - field_meta["per_sample_shapes"] = [tuple(t.shape) for t in unbound] - - field_schema[field_name] = field_meta - - return field_schema - def close(self) -> None: """Close all ZMQ sockets and context to prevent resource leaks.""" # Close handshake socket if it exists @@ -557,31 +522,23 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): shapes = [] dtypes = [] custom_backend_meta_list = [] - num_samples = len(metadata) for field_name in sorted(metadata.field_names): - field_meta = metadata.field_schema.get(field_name, {}) - field_shape = field_meta.get("shape") - field_dtype = field_meta.get("dtype") - per_sample_shapes = field_meta.get("per_sample_shapes") - - for index in range(num_samples): - if per_sample_shapes is not None: - shapes.append(per_sample_shapes[index]) - else: - shapes.append(field_shape) - dtypes.append(field_dtype) - custom_backend_meta_list.append(metadata._custom_backend_meta[index].get(field_name, None)) + field_shape = 
metadata.get_shapes(field_name) + field_dtype = metadata.get_dtypes(field_name) + + shapes.extend(field_shape) + dtypes.extend(field_dtype) + + custom_backend_meta_list.extend( + [metadata._custom_backend_meta[i].get(field_name, None) for i in range(metadata.size)] + ) return shapes, dtypes, custom_backend_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ Store tensor data in the backend storage and notify the controller. """ - if not metadata.field_names: - logger.warning("Attempted to put data, but metadata contains no fields.") - return - num_samples = len(metadata.global_indexes) if num_samples == 0: return @@ -592,7 +549,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) - field_schema = self._extract_field_schema(data) + field_schema = extract_field_schema(data) per_field_custom_backend_meta: dict[int, dict[str, Any]] = {} if custom_backend_meta: @@ -645,9 +602,12 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: async def clear_data(self, metadata: BatchMeta) -> None: """Remove stored data associated with the given metadata.""" + if not metadata.field_names: - logger.warning("Attempted to clear data, but metadata contains no fields.") - return + raise RuntimeError( + "Fail to clear_data for key-value based backends due to lack of `field_names` in BatchMeta" + ) + keys = self._generate_keys(metadata.field_names, metadata.global_indexes) _, _, custom_meta = self._get_shape_type_custom_backend_meta_list(metadata) self.storage_client.clear(keys=keys, custom_backend_meta=custom_meta) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index 9f6f93a6..a24ffafd 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -30,20 +30,11 @@ class 
MooncakeStorageManager(KVStorageManager): """Storage manager for MooncakeStorage backend.""" def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): - # Required: Address of the HTTP metadata server (e.g., "localhost:8080") - metadata_server = config.get("metadata_server", None) - # Required: Address of the master server RPC endpoint (e.g., "localhost:8081") - master_server_address = config.get("master_server_address", None) - # Optional: Name of the storage client, defaults to "MooncakeStorageClient" if not provided - client_name = config.get("client_name", None) + logger.warning( + "MooncakeStore backend doesn't support key update (upsert) for now. " + "You must delete the key before updating it. " + "Refer to https://github.com/kvcache-ai/Mooncake/issues/1645 for details." + ) - if metadata_server is None or not isinstance(metadata_server, str): - raise ValueError("Missing or invalid 'metadata_server' in config") - if master_server_address is None or not isinstance(master_server_address, str): - raise ValueError("Missing or invalid 'master_server_address' in config") - if client_name is None: - logger.info("Missing 'client_name' in config, using default value('MooncakeStorageClient')") - config["client_name"] = "MooncakeStorageClient" - elif client_name != "MooncakeStorageClient": - raise ValueError(f"Invalid 'client_name': {client_name} in config. 
Expecting 'MooncakeStorageClient'") + config["client_name"] = "MooncakeStoreClient" super().__init__(controller_info, config) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index f757cbf2..3d66dcb8 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -30,7 +30,7 @@ from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict -from transfer_queue.metadata import BatchMeta +from transfer_queue.metadata import BatchMeta, extract_field_schema from transfer_queue.storage.managers.base import TransferQueueStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory from transfer_queue.utils.zmq_utils import ( @@ -252,7 +252,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: if batch_size == 0: return - field_schema = self._extract_field_schema(data) + field_schema = extract_field_schema(data) routing = self._group_by_hash(metadata.global_indexes) tasks = [ @@ -305,7 +305,7 @@ async def _put_to_single_storage_unit( try: data = request_msg.serialize() await socket.send_multipart(data, copy=False) - messages = await socket.recv_multipart() + messages = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(messages) if response_msg.request_type != ZMQRequestType.PUT_DATA_RESPONSE: diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index d139f6b2..e6908e53 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -280,7 +280,7 @@ def _worker_routine(self) -> None: if worker_socket in socks: # Messages received from proxy: [identity, serialized_msg_frame1, ...] 
- messages = worker_socket.recv_multipart() + messages = worker_socket.recv_multipart(copy=False) identity = messages[0] serialized_msg = messages[1:] diff --git a/transfer_queue/utils/common.py b/transfer_queue/utils/common.py index e25f6b09..a9d2b935 100644 --- a/transfer_queue/utils/common.py +++ b/transfer_queue/utils/common.py @@ -65,7 +65,7 @@ def limit_pytorch_auto_parallel_threads(target_num_threads: Optional[int] = None target_num_threads = physical_cores if target_num_threads > physical_cores: - logger.error( + logger.warning( f"target_num_threads {target_num_threads} should not exceed total " f"physical CPU cores {physical_cores}. Setting to {physical_cores}." )