From 6eb511537ff6dacb21053351cdac25962354d434 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 12 Mar 2026 10:41:41 +0800 Subject: [PATCH 01/13] support one-click init of mooncakestore Signed-off-by: 0oshowero0 fix pre commit Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 --- .github/workflows/python-package.yml | 10 +- pyproject.toml | 3 + tests/e2e/test_e2e_lifecycle_consistency.py | 77 ++++++++++---- tests/e2e/test_kv_interface_e2e.py | 100 +++++++++++++++++- transfer_queue/config.yaml | 23 +++- transfer_queue/interface.py | 87 +++++++++++++-- transfer_queue/storage/clients/__init__.py | 4 +- .../storage/clients/mooncake_client.py | 30 ++++-- transfer_queue/storage/managers/base.py | 4 - .../storage/managers/mooncake_manager.py | 18 +--- 10 files changed, 294 insertions(+), 62 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 7cd24fc5..091ab99d 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -33,6 +33,7 @@ jobs: python -m pip install --upgrade pip pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -e ".[test,build,yuanrong]" + pip install mooncake-transfer-engine-non-cuda - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -43,11 +44,8 @@ jobs: run: | python -m build --wheel pip install dist/*.whl --force-reinstall - - name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=False) + - name: Test with pytest run: | pytest tests - - name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=True) - run: | - ray stop --force - export TQ_ZERO_COPY_SERIALIZATION=True - pytest tests \ No newline at end of file + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 35d65242..1fba227f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,9 @@ test = [ yuanrong = [ "openyuanrong-datasystem" ] +mooncake = [ + "mooncake-transfer-engine" +] # If you need to mimic `package_dir={'': '.'}`: [tool.setuptools.package-dir] diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index 22b45c5c..0da25731 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""E2E lifecycle consistency tests for TransferQueue.""" - +import os import sys import time from pathlib import Path @@ -23,6 +22,7 @@ import pytest import ray import torch +from omegaconf import OmegaConf from tensordict import TensorDict from tensordict.tensorclass import NonTensorData @@ -46,6 +46,38 @@ "non_tensor_stack", ] +# Backend configurations for E2E tests +BACKEND_CONFIGS = { + "SimpleStorage": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": { + "total_storage_size": 200, + "num_data_storage_units": 2, + }, + }, + }, + "MooncakeStore": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "MooncakeStore", + "MooncakeStore": { + "global_segment_size": 134217728, # 128MB + "local_buffer_size": 134217728, # 128MB + "metadata_server": "localhost:50050", + "master_server_address": "localhost:50051", + "protocol": "tcp", + "device_name": "", + }, + }, + }, +} + @pytest.fixture(scope="module") def ray_cluster(): @@ -57,24 +89,33 @@ def ray_cluster(): @pytest.fixture(scope="module") -def e2e_client(ray_cluster): - """Create a client using transfer_queue.init() for lifecycle testing.""" - from omegaconf import OmegaConf +def backend_name(): + """Get the backend name from environment variable. + + Environment variables: + TQ_TEST_BACKEND: Backend name (SimpleStorage or MooncakeStore) + To run tests for a specific backend: + TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_e2e_lifecycle_consistency.py + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py + """ + return os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") + + +@pytest.fixture(scope="module") +def e2e_client(ray_cluster, backend_name): + """Create a client using transfer_queue.init() for lifecycle testing. + + Args: + ray_cluster: Ray cluster fixture + backend_name: Backend name from TQ_TEST_BACKEND env var + """ import transfer_queue - config = { - "controller": { - "polling_mode": True, - }, - "backend": { - "storage_backend": "SimpleStorage", - "SimpleStorage": { - "total_storage_size": 200, - "num_data_storage_units": 2, - }, - }, - } + if backend_name not in BACKEND_CONFIGS: + raise ValueError(f"Unknown backend: {backend_name}. Available: {list(BACKEND_CONFIGS.keys())}") + + config = BACKEND_CONFIGS[backend_name] transfer_queue.init(OmegaConf.create(config)) client = transfer_queue.get_client() yield client @@ -277,7 +318,7 @@ def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict: # Scenario One: Core Read/Write Consistency def test_core_consistency(e2e_client): - """Put full complex data then get — verify all field types are correctly round-tripped.""" + """Put full complex data then get - verify all field types are correctly round-tripped.""" client = e2e_client partition_id = "test_core_consistency" batch_size = 20 diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index 71e9833a..a110c882 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -27,6 +27,7 @@ import pytest import ray import torch +from omegaconf import OmegaConf from tensordict import TensorDict # Add parent directory to path @@ -38,6 +39,40 @@ # Configure Ray for tests os.environ["RAY_DEDUP_LOGS"] = "0" +# Backend configurations for E2E tests +# Adjust values for GitHub CI environment (smaller memory footprint) +BACKEND_CONFIGS = { + "SimpleStorage": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": { + "total_storage_size": 200, + "num_data_storage_units": 2, + }, + }, + }, + "MooncakeStore": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "MooncakeStore", + "MooncakeStore": { + # Reduced memory sizes for CI/testing environment + "global_segment_size": 134217728, # 128MB + "local_buffer_size": 134217728, # 128MB + "metadata_server": os.environ.get("TQ_MOONCAKE_METADATA_SERVER", "localhost:50050"), + "master_server_address": os.environ.get("TQ_MOONCAKE_MASTER_SERVER", "localhost:50051"), + "protocol": "tcp", + "device_name": "", + }, + }, + }, +} + @pytest.fixture(scope="module") def ray_init(): @@ -50,9 +85,32 @@ def ray_init(): @pytest.fixture(scope="module") -def tq_system(ray_init): - """Initialize TransferQueue system for the test module.""" - tq.init() +def backend_name(): + """Get the backend name from environment variable. + + Environment variables: + TQ_TEST_BACKEND: Backend name (SimpleStorage or MooncakeStore) + + To run tests for a specific backend: + TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_kv_interface_e2e.py + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py + """ + return os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") + + +@pytest.fixture(scope="module") +def tq_system(ray_init, backend_name): + """Initialize TransferQueue system for the test module. + + Args: + ray_init: Ray cluster fixture + backend_name: Backend name from TQ_TEST_BACKEND env var + """ + if backend_name not in BACKEND_CONFIGS: + raise ValueError(f"Unknown backend: {backend_name}. Available: {list(BACKEND_CONFIGS.keys())}") + + config = BACKEND_CONFIGS[backend_name] + tq.init(OmegaConf.create(config)) yield tq.close() @@ -109,6 +167,9 @@ def test_kv_put_with_dict_fields(self, controller): expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed assert_tensor_equal(retrieved["data"], expected) + # delete the key (MooncakeStore does not support updating existing key, so we need to clear it before next test) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_put_with_tensordict_fields(self, controller): """Test kv_put with tensordict fields.""" partition_id = "test_partition" @@ -128,6 +189,8 @@ def test_kv_put_with_tensordict_fields(self, controller): expected = torch.tensor([[1, 2, 3, 4]]) # unsqueezed assert_tensor_equal(retrieved["input_ids"], expected) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_put_single_sample_with_fields_and_tag(self, controller): """Test putting a single sample with fields and tag.""" partition_id = "test_partition" @@ -175,6 +238,8 @@ def test_kv_put_single_sample_with_fields_and_tag(self, controller): assert_tensor_equal(retrieved["input_ids"], expected_input_ids) assert_tensor_equal(retrieved["attention_mask"], expected_attention_mask) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_put_update_tag_only(self, controller): """Test updating only tag without providing fields.""" partition_id = "test_partition" @@ -198,6 +263,8 @@ def test_kv_put_update_tag_only(self, controller): retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) assert_tensor_equal(retrieved["value"], torch.tensor([[10]])) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_put_partial_update(self, controller): """Test adding new fields to existing sample.""" partition_id = "test_partition" @@ -232,6 +299,8 @@ def test_kv_put_partial_update(self, controller): # key should have response marked as produced assert partition.production_status[global_idx, response_col_idx] == 1, "Key should have response" + tq.kv_clear(keys=key, partition_id=partition_id) + class TestKVBatchPutE2E: """End-to-end tests for kv_batch_put functionality.""" @@ -282,6 +351,8 @@ def test_kv_batch_put_multiple_samples(self, controller): assert_tensor_equal(retrieved["input_ids"], batch_input_ids) assert_tensor_equal(retrieved["attention_mask"], batch_attention_mask) + tq.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_batch_put_partial_update(self, controller): """Test adding new fields to existing samples.""" partition_id = "test_partition" @@ -320,6 +391,8 @@ def test_kv_batch_put_partial_update(self, controller): # keys[1] should have response marked as produced assert partition.production_status[global_idx_1, response_col_idx] == 1, "Keys[1] should have response" + tq.kv_clear(keys=keys, partition_id=partition_id) + class TestKVGetE2E: """End-to-end tests for kv_batch_get functionality.""" @@ -337,6 +410,8 @@ def test_kv_batch_get_single_key(self, controller): retrieved = tq.kv_batch_get(keys=key, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_batch_get_multiple_keys(self, controller): """Test getting data for multiple keys.""" partition_id = "test_partition" @@ -349,6 +424,8 @@ def test_kv_batch_get_multiple_keys(self, controller): retrieved = tq.kv_batch_get(keys=keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) + tq.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_batch_get_partial_keys(self, controller): """Test getting data for partial keys.""" partition_id = "test_partition" @@ -363,6 +440,8 @@ def test_kv_batch_get_partial_keys(self, controller): retrieved = tq.kv_batch_get(keys=partial_keys, partition_id=partition_id) assert_tensor_equal(retrieved["data"], expected_data) + tq.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_batch_get_partial_fields(self, controller): """Test getting only partial fields.""" partition_id = "test_partition" @@ -394,6 +473,8 @@ def test_kv_batch_get_partial_fields(self, controller): assert_tensor_equal(retrieved["input_ids"], input_ids) assert_tensor_equal(retrieved["response"], response) + tq.kv_clear(keys=key, partition_id=partition_id) + def test_kv_batch_get_nonexistent_key(self, controller): """Test that getting data for non-existent key returns empty result.""" partition_id = "test_partition" @@ -432,6 +513,8 @@ def test_kv_list_single_partition(self, controller): for i, (key, tag) in enumerate(partition_info["test_partition"].items()): assert tag["id"] == i + tq.kv_clear(keys=keys, partition_id=partition_id) + def test_kv_list_all_partitions(self, controller): """Test listing keys and tags in all partitions.""" partition_id = ["test_partition0", "test_partition1", "test_partition2"] @@ -488,6 +571,10 @@ def test_kv_list_all_partitions(self, controller): for i, (key, tag) in enumerate(partition_info["test_partition2"].items()): assert tag["id"] == i + 6 + tq.kv_clear(keys=keys_partition0, partition_id=partition_id[0]) + tq.kv_clear(keys=keys_partition1, partition_id=partition_id[1]) + tq.kv_clear(keys=keys_partition2, partition_id=partition_id[2]) + def test_kv_list_empty_partition(self): """Test listing empty partition.""" partition_id = "test_partition_empty" @@ -522,6 +609,9 @@ def test_kv_clear_single_key(self, controller): assert key not in partition.keys_mapping assert other_key in partition.keys_mapping + tq.kv_clear(keys=key, partition_id=partition_id) + tq.kv_clear(keys=other_key, partition_id=partition_id) + def test_kv_clear_multiple_keys(self, controller): """Test clearing multiple keys.""" partition_id = "test_partition" @@ -541,6 +631,8 @@ def test_kv_clear_multiple_keys(self, controller): assert keys[2] in partition_info[partition_id] assert keys[3] in partition_info[partition_id] + tq.kv_clear(keys=keys[2:], partition_id=partition_id) + class TestKVE2ECornerCases: """End-to-end tests for corner cases.""" @@ -578,6 +670,8 @@ def test_field_expansion_across_samples(self, controller): assert "field_b" not in data assert "field_c" not in data + tq.kv_clear(keys=keys, partition_id=partition_id) + def run_tests(): """Run all e2e tests manually for debugging.""" diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml index c0ddfe76..85440560 100644 --- a/transfer_queue/config.yaml +++ b/transfer_queue/config.yaml @@ -24,10 +24,27 @@ backend: # ZMQ Server IP & Ports (automatically generated during init) zmq_info: null + # For MooncakeStore: + MooncakeStore: + # Auto init metadata_server + auto_init: true + # Address of the HTTP metadata server + metadata_server: localhost:50050 + # Address of master server + master_server_address: localhost:50051 + # Address of local host + local_hostname: localhost + # Protocol for transmission. Choose from: tcp, rdma. (default: tcp) + protocol: tcp + # Memory segment size in bytes for mounting (default: 4GB) + global_segment_size: 4294967296 + # Local buffer size in bytes (default: 1GB) + local_buffer_size: 1073741824 + # Network device name. Set to "" to let Mooncake to auto-picks devices + device_name: "" + + # For RayStore: RayStore: # For Yuanrong: - # TODO - - # For MooncakeStore: # TODO \ No newline at end of file diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 97f25758..1eac34f5 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -17,8 +17,10 @@ import logging import math import os +import subprocess import time from typing import Any, Optional +from urllib.parse import urlparse import ray import torch @@ -73,6 +75,7 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: _TRANSFER_QUEUE_STORAGE = {} if conf.backend.storage_backend == "SimpleStorage": # initialize SimpleStorageUnit + simple_storage_handles = {} num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units total_storage_size = conf.backend.SimpleStorage.total_storage_size storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) @@ -86,13 +89,72 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: ).remote( storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), ) - _TRANSFER_QUEUE_STORAGE[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node + simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") - storage_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_STORAGE) + storage_zmq_info = process_zmq_server_info(simple_storage_handles) backend_name = conf.backend.storage_backend conf.backend[backend_name].zmq_info = storage_zmq_info - + _TRANSFER_QUEUE_STORAGE["SimpleStorage"] = simple_storage_handles + if conf.backend.storage_backend == "MooncakeStore": + if conf.backend.MooncakeStore.auto_init: + logger.info("Try to initialize mooncake_master automatically.") + raw_address = conf.backend.MooncakeStore.metadata_server + + if raw_address is None or not isinstance(raw_address, str): + raise ValueError("Missing or invalid 'metadata_server' in config") + + if "://" not in raw_address: + raw_address = "//" + raw_address + + parsed = urlparse(raw_address) + + if not parsed.hostname or parsed.port is None: + raise ValueError( + f"Invalid metadata_server '{conf.backend.MooncakeStore.metadata_server}'. " + f"Host and port are required (e.g., host:port)." + ) + + metadata_server_host = parsed.hostname + metadata_server_port = str(parsed.port) + + cmd = [ + "mooncake_master", + "-default_kv_lease_ttl=0", + "-default_kv_soft_pin_ttl=0", + "--enable_http_metadata_server=true", + "--allow_evict_soft_pinned_objects=false", + f"--http_metadata_server_host={metadata_server_host}", + f"--http_metadata_server_port={metadata_server_port}", + ] + + log_file_path = "/tmp/mooncake_master.log" + + with open(log_file_path, "w") as log_file: + process = subprocess.Popen( + cmd, stdout=log_file, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True + ) + + time.sleep(3) + + if process.poll() is None: + logger.info( + f"mooncake_master started, PID: {process.pid}. Logs are at: {os.path.abspath(log_file_path)}" + ) + else: + error_msg = "" + try: + with open(log_file_path) as f: + error_msg = f.read() + except Exception as e: + error_msg = f"Failed to read log file: {e}" + + raise RuntimeError( + f"mooncake_master exited with error. Check {log_file_path} for detailed logs. " + f"Output:\n{error_msg}" + ) + + _TRANSFER_QUEUE_STORAGE["MooncakeStore"] = process return conf @@ -235,9 +297,22 @@ def close(): try: if _TRANSFER_QUEUE_STORAGE: - # only the process that do first-time init can clean the distributed storage - for storage in _TRANSFER_QUEUE_STORAGE.values(): - ray.kill(storage) + for key, value in _TRANSFER_QUEUE_STORAGE.items(): + if key == "SimpleStorage": + # only the process that do first-time init can clean the distributed storage + for storage in value.values(): + ray.kill(storage) + elif key == "MooncakeStore": + process = value + if process and process.poll() is None: + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + else: + logger.warning(f"close for _TRANSFER_QUEUE_STORAGE with key {key} is not supported for now.") + _TRANSFER_QUEUE_STORAGE = None except Exception: pass diff --git a/transfer_queue/storage/clients/__init__.py b/transfer_queue/storage/clients/__init__.py index 93e81114..2b861166 100644 --- a/transfer_queue/storage/clients/__init__.py +++ b/transfer_queue/storage/clients/__init__.py @@ -16,7 +16,7 @@ # This module is currently empty but reserved for future client implementations from .base import TransferQueueStorageKVClient from .factory import StorageClientFactory -from .mooncake_client import MooncakeStorageClient +from .mooncake_client import MooncakeStoreClient from .ray_storage_client import RayStorageClient from .yuanrong_client import YuanrongStorageClient @@ -24,6 +24,6 @@ "TransferQueueStorageKVClient", "StorageClientFactory", "RayStorageClient", - "MooncakeStorageClient", + "MooncakeStoreClient", "YuanrongStorageClient", ] diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 4b4d9a3b..b9a9b8a6 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -36,8 +36,8 @@ BATCH_SIZE_LIMIT: int = 500 -@StorageClientFactory.register("MooncakeStorageClient") -class MooncakeStorageClient(TransferQueueStorageKVClient): +@StorageClientFactory.register("MooncakeStoreClient") +class MooncakeStoreClient(TransferQueueStorageKVClient): """ Storage client for MooncakeStore. """ @@ -46,13 +46,29 @@ def __init__(self, config: dict[str, Any]): if not MOONCAKE_STORE_IMPORTED: raise ImportError("Mooncake Store not installed. Please install via: pip install mooncake-transfer-engine") + # Required: Address of local host self.local_hostname = config.get("local_hostname", "localhost") - self.metadata_server = config.get("metadata_server") - self.global_segment_size = config.get("global_segment_size", 512 * 1024 * 1024) - self.local_buffer_size = config.get("local_buffer_size", 128 * 1024 * 1024) + # Required: Address of the HTTP metadata server (e.g., "localhost:8080") + self.metadata_server = config.get("metadata_server", None) + # Required: Address of the master server RPC endpoint (e.g., "localhost:8081") + self.master_server_address = config.get("master_server_address") + + self.global_segment_size = int(config.get("global_segment_size", 4096 * 1024 * 1024)) + self.local_buffer_size = int(config.get("local_buffer_size", 1024 * 1024 * 1024)) self.protocol = config.get("protocol", "tcp") self.device_name = config.get("device_name", "") - self.master_server_address = config.get("master_server_address") + if self.device_name is None: + self.device_name = "" + + if self.metadata_server is None or not isinstance(self.metadata_server, str): + raise ValueError("Missing or invalid 'metadata_server' in config") + if self.master_server_address is None or not isinstance(self.master_server_address, str): + raise ValueError("Missing or invalid 'master_server_address' in config") + + if not self.metadata_server.startswith("http://") and not self.metadata_server.startswith("etcd://"): + self.metadata_server = f"http://{self.metadata_server}" + if not self.metadata_server.startswith("etcd://") and not self.metadata_server.endswith("/metadata"): + self.metadata_server = self.metadata_server + "/metadata" if self.metadata_server is None: raise ValueError("Missing 'metadata_server' in config") @@ -146,7 +162,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non """ if shapes is None or dtypes is None: - raise ValueError("MooncakeStorageClient needs shapes and dtypes") + raise ValueError("MooncakeStoreClient needs shapes and dtypes") if not (len(keys) == len(shapes) == len(dtypes)): raise ValueError("Lengths of keys, shapes, dtypes must match") diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 37b4b3fb..d5177151 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -584,10 +584,6 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ Store tensor data in the backend storage and notify the controller. """ - if not metadata.field_names: - logger.warning("Attempted to put data, but metadata contains no fields.") - return - num_samples = len(metadata.global_indexes) if num_samples == 0: return diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index 9f6f93a6..7b8219c4 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -30,20 +30,12 @@ class MooncakeStorageManager(KVStorageManager): """Storage manager for MooncakeStorage backend.""" def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): - # Required: Address of the HTTP metadata server (e.g., "localhost:8080") - metadata_server = config.get("metadata_server", None) - # Required: Address of the master server RPC endpoint (e.g., "localhost:8081") - master_server_address = config.get("master_server_address", None) - # Optional: Name of the storage client, defaults to "MooncakeStorageClient" if not provided client_name = config.get("client_name", None) - if metadata_server is None or not isinstance(metadata_server, str): - raise ValueError("Missing or invalid 'metadata_server' in config") - if master_server_address is None or not isinstance(master_server_address, str): - raise ValueError("Missing or invalid 'master_server_address' in config") if client_name is None: - logger.info("Missing 'client_name' in config, using default value('MooncakeStorageClient')") - config["client_name"] = "MooncakeStorageClient" - elif client_name != "MooncakeStorageClient": - raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'MooncakeStorageClient'") + logger.info("Missing 'client_name' in config, using default value('MooncakeStoreClient')") + config["client_name"] = "MooncakeStoreClient" + elif client_name != "MooncakeStoreClient": + raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'MooncakeStoreClient'") + super().__init__(controller_info, config) From 888c788f6ffa0f0931d9192d06d365b0b584297e Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 13 Mar 2026 10:36:24 +0800 Subject: [PATCH 02/13] fix e2e kv test Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 55 +++++++++++++------ .../storage/clients/mooncake_client.py | 9 +-- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 1eac34f5..00f76ab6 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -98,12 +98,19 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: _TRANSFER_QUEUE_STORAGE["SimpleStorage"] = simple_storage_handles if conf.backend.storage_backend == "MooncakeStore": if conf.backend.MooncakeStore.auto_init: - logger.info("Try to initialize mooncake_master automatically.") - raw_address = conf.backend.MooncakeStore.metadata_server - - if raw_address is None or not isinstance(raw_address, str): - raise ValueError("Missing or invalid 'metadata_server' in config") + # Try to kill existing mooncake_master processes before starting a new one to avoid potential conflicts + check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True) + if check.returncode == 0: + pids = check.stdout.strip().replace("\n", ", ") + logging.info(f"Find existing mooncake_master (PID: {pids}), try to kill first...") + + result = os.system('pkill -f "[m]ooncake_master"') + if result == 0: + logging.info("Successfully killed existing mooncake_master processes.") + else: + raise RuntimeError(f"Failed to kill existing mooncake_master processes (exit code: {result}).") + raw_address = conf.backend.MooncakeStore.metadata_server if "://" not in raw_address: raw_address = "//" + raw_address @@ -120,8 +127,10 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: cmd = [ "mooncake_master", - "-default_kv_lease_ttl=0", - "-default_kv_soft_pin_ttl=0", + "-default_kv_lease_ttl=999", + "-default_kv_soft_pin_ttl=999", + "--eviction_high_watermark_ratio=1.0", + "--eviction_ratio=0.0", "--enable_http_metadata_server=true", "--allow_evict_soft_pinned_objects=false", f"--http_metadata_server_host={metadata_server_host}", @@ -129,11 +138,17 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: ] log_file_path = "/tmp/mooncake_master.log" - - with open(log_file_path, "w") as log_file: - process = subprocess.Popen( - cmd, stdout=log_file, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True - ) + log_file = open(log_file_path, "w") + + process = subprocess.Popen( + cmd, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True, + start_new_session=True, + ) time.sleep(3) @@ -153,7 +168,6 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: f"mooncake_master exited with error. Check {log_file_path} for detailed logs. " f"Output:\n{error_msg}" ) - _TRANSFER_QUEUE_STORAGE["MooncakeStore"] = process return conf @@ -303,13 +317,20 @@ def close(): for storage in value.values(): ray.kill(storage) elif key == "MooncakeStore": + import signal + process = value if process and process.poll() is None: - process.terminate() try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - process.kill() + pgid = os.getpgid(process.pid) + os.killpg(pgid, signal.SIGTERM) + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + os.killpg(pgid, signal.SIGKILL) + process.wait(timeout=5) + except ProcessLookupError: + logger.warning(f"MooncakeStore process already exited: pid={process.pid}") else: logger.warning(f"close for _TRANSFER_QUEUE_STORAGE with key {key} is not supported for now.") diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index b9a9b8a6..672693ef 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -239,10 +239,11 @@ def clear(self, keys: list[str], custom_backend_meta=None): keys (List[str]): List of keys to remove. custom_backend_meta (List[Any], optional): ... """ - for key in keys: - ret = self._store.remove(key) - if ret != 0: - logger.warning(f"remove failed for key '{key}' with error code: {ret}") + global_indexes = [key.split("@")[0] + "@*" for key in keys] + for gid in global_indexes: + ret = self._store.remove_by_regex(gid, force=True) + if ret < 0: + logger.warning(f"remove failed for key '{gid}' with error code: {ret}") def close(self): """Closes MooncakeStore.""" From c801e89216035bd5907b831c1eb252eb900a012c Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 13 Mar 2026 10:57:19 +0800 Subject: [PATCH 03/13] modify tq.close() Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 00f76ab6..bf8d9cec 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -317,20 +317,27 @@ def close(): for storage in value.values(): ray.kill(storage) elif key == "MooncakeStore": - import signal - - process = value - if process and process.poll() is None: - try: - pgid = os.getpgid(process.pid) - os.killpg(pgid, signal.SIGTERM) - try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - os.killpg(pgid, signal.SIGKILL) - process.wait(timeout=5) - except ProcessLookupError: - logger.warning(f"MooncakeStore process already exited: pid={process.pid}") + check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True) + if check.returncode == 0: + pids = check.stdout.strip().replace("\n", ", ") + logger.warning( + f"mooncake_master process still exists with PID: {pids}. " + f"Consider manually killing mooncake_master." + ) + # os.system('pkill -f "TransferQueue"') + # process = value + # if process and process.poll() is None: + # try: + # import signal + # pgid = os.getpgid(process.pid) + # os.killpg(pgid, signal.SIGTERM) + # try: + # process.wait(timeout=5) + # except subprocess.TimeoutExpired: + # os.killpg(pgid, signal.SIGKILL) + # process.wait(timeout=5) + # except ProcessLookupError: + # logger.warning(f"MooncakeStore process already exited: pid={process.pid}") else: logger.warning(f"close for _TRANSFER_QUEUE_STORAGE with key {key} is not supported for now.") From c5d792978a1bc8f07126f2a73572711c6ede93f9 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 13 Mar 2026 11:48:07 +0800 Subject: [PATCH 04/13] add fixme Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index d5177151..39a6facc 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -565,6 +565,8 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): custom_backend_meta_list = [] num_samples = len(metadata) + # FIXME: Use BatchMeta.get_dtype and .get_shape instead + for field_name in sorted(metadata.field_names): field_meta = metadata.field_schema.get(field_name, {}) field_shape = field_meta.get("shape") From a9f8c5ba46e71b5b4c3287fa8f3e760233d6cb95 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 13 Mar 2026 17:15:13 +0800 Subject: [PATCH 05/13] auto set local host Signed-off-by: 0oshowero0 --- transfer_queue/config.yaml | 6 +++--- transfer_queue/interface.py | 4 ++-- transfer_queue/storage/clients/mooncake_client.py | 9 ++++++++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/transfer_queue/config.yaml b/transfer_queue/config.yaml index 85440560..98819edd 100644 --- a/transfer_queue/config.yaml +++ b/transfer_queue/config.yaml @@ -26,14 +26,14 @@ backend: # For MooncakeStore: MooncakeStore: - # Auto init metadata_server + # Whether to let TQ automatically init metadata_server. auto_init: true # Address of the HTTP metadata server metadata_server: localhost:50050 # Address of master server master_server_address: localhost:50051 - # Address of local host - local_hostname: localhost + # Address of local host. Set to "" to use Ray IP as local host address + local_hostname: "" # Protocol for transmission. Choose from: tcp, rdma. (default: tcp) protocol: tcp # Memory segment size in bytes for mounting (default: 4GB) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index bf8d9cec..c6db10c0 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -127,8 +127,8 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: cmd = [ "mooncake_master", - "-default_kv_lease_ttl=999", - "-default_kv_soft_pin_ttl=999", + "-default_kv_lease_ttl=999999", + "-default_kv_soft_pin_ttl=999999", "--eviction_high_watermark_ratio=1.0", "--eviction_ratio=0.0", "--enable_http_metadata_server=true", diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 672693ef..dd98653a 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -47,7 +47,7 @@ def __init__(self, config: dict[str, Any]): raise ImportError("Mooncake Store not installed. Please install via: pip install mooncake-transfer-engine") # Required: Address of local host - self.local_hostname = config.get("local_hostname", "localhost") + self.local_hostname = config.get("local_hostname", "") # Required: Address of the HTTP metadata server (e.g., "localhost:8080") self.metadata_server = config.get("metadata_server", None) # Required: Address of the master server RPC endpoint (e.g., "localhost:8081") @@ -60,6 +60,13 @@ def __init__(self, config: dict[str, Any]): if self.device_name is None: self.device_name = "" + if self.local_hostname is None or self.local_hostname == "": + from transfer_queue.utils.zmq_utils import get_node_ip_address_raw + + ip = get_node_ip_address_raw() + logger.info(f"Try to use Ray IP ({ip}) as local hostname for MooncakeStore.") + self.local_hostname = ip + if self.metadata_server is None or not isinstance(self.metadata_server, str): raise ValueError("Missing or invalid 'metadata_server' in config") if self.master_server_address is None or not isinstance(self.master_server_address, str): From 901765c5c221e6904d0a007305c23ee82905c456 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 13 Mar 2026 20:53:57 +0800 Subject: [PATCH 06/13] use new batchmeta api Signed-off-by: 0oshowero0 try: unify extract_field_schema Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 partially fix bugs Signed-off-by: 0oshowero0 fix remove_samples Signed-off-by: 0oshowero0 fix all ut Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 --- .github/workflows/python-package.yml | 4 +- tests/e2e/test_e2e_lifecycle_consistency.py | 35 ++++++- tests/test_controller.py | 10 +- transfer_queue/client.py | 11 +++ transfer_queue/controller.py | 92 +++++++++++++------ transfer_queue/interface.py | 2 +- transfer_queue/metadata.py | 87 ++++++++++++------ .../storage/clients/mooncake_client.py | 8 ++ transfer_queue/storage/managers/base.py | 70 ++++---------- .../managers/simple_backend_manager.py | 4 +- 10 files changed, 202 insertions(+), 121 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 091ab99d..9629df75 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -47,5 +47,5 @@ jobs: - name: Test with pytest run: | pytest tests - TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py - TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py \ No newline at end of file + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/ + pkill -f "mooncake_master" \ No newline at end of file diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index 8774ff9a..39b0b91e 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -285,7 +285,7 @@ def verify_list_equal(retrieved, expected) -> bool: if isinstance(retrieved, NonTensorStack): retrieved = retrieved.tolist() elif isinstance(retrieved, torch.Tensor): - retrieved = retrieved.tolist() + retrieved = retrieved.reshape(-1).tolist() # may get 2D tensor back using key-value based backend if isinstance(expected, NonTensorStack): expected = expected.tolist() elif isinstance(expected, torch.Tensor): @@ -324,6 +324,18 @@ def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict: return TensorDict(reordered, batch_size=td.batch_size) +def recover_local_index(global_index_order, new_global_index_order): + value_to_new_index = {} + for idx, val in enumerate(new_global_index_order): + value_to_new_index[val] = idx + + local_index_order_to_recover = [] + for val in global_index_order: + local_index_order_to_recover.append(value_to_new_index[val]) + + return local_index_order_to_recover + + # Scenario One: Core Read/Write Consistency def test_core_consistency(e2e_client): """Put full complex data then get - verify all field types are correctly round-tripped.""" @@ -403,6 +415,12 @@ def test_core_consistency(e2e_client): # Scenario Two: Cross-Shard Update def test_cross_shard_complex_update(e2e_client): """Cross-shard update: put A + put B, update overlapping region, verify all regions.""" + + # FIXME: Add data update test to MooncakeStore after Upsert function is ready + # https://github.com/kvcache-ai/Mooncake/issues/1645 + if os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") == "MooncakeStore": + return + client = e2e_client partition_id = "test_cross_shard_update" task_name = "cross_shard_task" @@ -785,12 +803,19 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client): indices = list(range(batch_size)) original_data = generate_complex_data(indices) - client.put(data=original_data, partition_id=partition_id) + original_meta = client.put(data=original_data, partition_id=partition_id) + global_index_order = original_meta.global_indexes try: # === Phase 1: Retrieve and verify writability === meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch") assert meta is not None and meta.size == batch_size + + # the global_index_order in retrieved meta is different from the original one. + # we need to reorder first. + local_index_order = recover_local_index(global_index_order, meta.global_indexes) + meta = meta.select_samples(local_index_order) + retrieved = client.get_data(meta) # 1. tensor_f32: writable @@ -834,6 +859,12 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client): # Re-retrieve the same data — modifications above should NOT have affected storage meta2 = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch") assert meta2 is not None and meta2.size == batch_size + + # the global_index_order in retrieved meta is different from the original one. + # we need to reorder first. + local_index_order = recover_local_index(global_index_order, meta2.global_indexes) + meta2 = meta2.select_samples(local_index_order) + retrieved2 = client.get_data(meta2) # tensor_f32[0,0] should be the original value, not 99999.0 diff --git a/tests/test_controller.py b/tests/test_controller.py index 74793bd4..2e559600 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -176,14 +176,14 @@ def test_controller_with_single_partition(self, ray_setup): # Test get clear meta clear_meta = ray.get( tq_controller.get_metadata.remote( - data_fields=[], + data_fields=gen_meta.field_names, partition_id=partition_id, - mode="insert", + mode="force_fetch", ) ) assert clear_meta.global_indexes == list(range(gbs * num_n_samples)) # In insert mode with no fields, field_schema should be empty - assert clear_meta.field_schema == {} or clear_meta.field_names == [] + assert clear_meta.field_names == gen_meta.field_names print("✓ Clear metadata correct") # Test clear_partition @@ -431,9 +431,9 @@ def test_controller_with_multi_partitions(self, ray_setup): # Test get clear meta clear_meta = ray.get( tq_controller.get_metadata.remote( - data_fields=[], + data_fields=gen_meta.field_names, partition_id=partition_id_1, - mode="insert", + mode="force_fetch", ) ) assert clear_meta diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 199ceeda..a1154c0b 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -387,6 +387,13 @@ async def async_put( "Call initialize_storage_manager() before performing storage operations." ) + for field_name, field_data in data.items(): + if isinstance(field_data, torch.Tensor) and field_data.ndim == 1: + logger.warning( + f"[{self.client_id}]: Data field '{field_name}' is a tensor with only one dimension. " + f"You may receive 2D tensors in key-value based backend." + ) + if metadata is None: if partition_id is None: raise ValueError("partition_id must be provided if metadata is not given") @@ -480,6 +487,10 @@ async def async_clear_partition(self, partition_id: str): metadata = await self._get_partition_meta(partition_id) + if not metadata: + logger.warning(f"Try to clear an non-exist partition {partition_id}. No action will be taken.") + return + # Clear the controller metadata await self._clear_partition_in_controller(partition_id) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index c74e515a..6d003dc4 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -251,6 +251,21 @@ def remove_samples(self, indexes: list[int]): for idx in indexes: self.per_sample_shapes.pop(idx, None) + # After removing samples, check if we can update is_nested and shape + # If per_sample_shapes is empty or all remaining shapes are the same, + # we should reset is_nested to False and update shape accordingly + if not self.per_sample_shapes: + # All samples removed - reset to non-nested state + self.is_nested = False + self.shape = None + else: + # Check if all remaining shapes are the same + remaining_shapes = set(self.per_sample_shapes.values()) + if len(remaining_shapes) == 1: + # All remaining samples have the same shape - update to non-nested + self.is_nested = False + self.shape = next(iter(remaining_shapes)) + def to_batch_schema(self, batch_global_indexes: list[int]) -> dict[str, Any]: """Export as a BatchMeta.field_schema-compatible dict for generate_batch_meta.""" schema = { @@ -529,7 +544,24 @@ def _update_field_metadata( is_non_tensor=meta.get("is_non_tensor", False), ) else: + # Track if is_nested changed from False to True during update + was_not_nested = not self.field_metadata[field_name].is_nested + # Save old shape before update (for filling per_sample_shapes of existing samples) + old_shape = self.field_metadata[field_name].shape self.field_metadata[field_name].update(meta) + # If is_nested became True due to shape mismatch, capture shapes for all samples + if was_not_nested and self.field_metadata[field_name].is_nested: + col_meta = self.field_metadata[field_name] + new_shape = meta.get("shape") + # Fill new samples with new shape + if new_shape is not None: + for gi in global_indexes: + col_meta.per_sample_shapes[gi] = new_shape + # Fill existing samples with old shape + if old_shape is not None: + for gi in self.global_indexes: + if gi not in col_meta.per_sample_shapes: + col_meta.per_sample_shapes[gi] = old_shape # nested per-sample shapes per_sample_shapes = meta.get("per_sample_shapes") @@ -1214,31 +1246,31 @@ def get_metadata( self.create_partition(partition_id) partition = self._get_partition(partition_id) - if data_fields: - # This is called during put_data call without providing metadata. - # try to use pre-allocated global index first + if data_fields is None: + raise RuntimeError("Must provide data_fields for inserting new data") - if batch_size is None: - raise ValueError("must provide batch_size for inserting new data") + # This is called during put_data call without providing metadata. + # try to use pre-allocated global index first - assert partition is not None - batch_global_indexes = partition.activate_pre_allocated_indexes(batch_size) + if batch_size is None: + raise ValueError("must provide batch_size for inserting new data") - if len(batch_global_indexes) < batch_size: - new_global_indexes = self.index_manager.allocate_indexes( - partition_id, count=(batch_size - len(batch_global_indexes)) - ) - batch_global_indexes.extend(new_global_indexes) + assert partition is not None + batch_global_indexes = partition.activate_pre_allocated_indexes(batch_size) - # register global_indexes in partition - partition.global_indexes.update(batch_global_indexes) + if len(batch_global_indexes) < batch_size: + new_global_indexes = self.index_manager.allocate_indexes( + partition_id, count=(batch_size - len(batch_global_indexes)) + ) + batch_global_indexes.extend(new_global_indexes) + + # register global_indexes in partition + partition.global_indexes.update(batch_global_indexes) - else: - batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) - assert task_name is not None if mode == "fetch": + assert task_name is not None # Find ready samples within current data partition and package into BatchMeta when reading if batch_size is None: @@ -1288,6 +1320,11 @@ def get_metadata( f"after sampling: {len(batch_global_indexes)}" ) + # Mark samples as consumed if in fetch mode + if consumed_indexes: + partition = self.partitions[partition_id] + partition.mark_consumed(task_name, consumed_indexes) + elif mode == "force_fetch": batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id) consumed_indexes = [] @@ -1295,11 +1332,6 @@ def get_metadata( # Package into metadata metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) - # Mark samples as consumed if in fetch mode - if mode == "fetch" and consumed_indexes: - partition = self.partitions[partition_id] - partition.mark_consumed(task_name, consumed_indexes) - return metadata def scan_data_status( @@ -1779,12 +1811,18 @@ def _process_request(self): with perf_monitor.measure(op_type="GET_PARTITION_META"): params = request_msg.body partition_id = params["partition_id"] + partition = self._get_partition(partition_id) + if partition is not None: + partition_data_fields = list(partition.field_name_mapping.keys()) + + metadata = self.get_metadata( + data_fields=partition_data_fields, + partition_id=partition_id, + mode="force_fetch", + ) + else: + metadata = None - metadata = self.get_metadata( - data_fields=[], - partition_id=partition_id, - mode="insert", - ) response_msg = ZMQMessage.create( request_type=ZMQRequestType.GET_PARTITION_META_RESPONSE, sender_id=self.controller_id, diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index c6db10c0..386d6073 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -324,7 +324,7 @@ def close(): f"mooncake_master process still exists with PID: {pids}. " f"Consider manually killing mooncake_master." ) - # os.system('pkill -f "TransferQueue"') + # os.system('pkill -f "mooncake_master"') # process = value # if process and process.poll() is None: # try: diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index dd601362..42d24d25 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -141,6 +141,60 @@ def __iter__(self): return (_SampleView(self._batch, i) for i in range(len(self))) +def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: + """Extract field-level schema from TensorDict.""" + field_schema: dict[str, dict[str, Any]] = {} + batch_size = data.batch_size[0] + + for field_name, value in data.items(): + is_tensor = isinstance(value, torch.Tensor) + is_nested = is_tensor and value.is_nested + + first_item = None + if is_nested: + unbound = value.unbind() + first_item = unbound[0] if unbound else None + elif is_tensor: + first_item = value[0] if value.shape[0] > 0 else None + else: + first_item = value[0] if len(value) > 0 else None + + # Determine is_non_tensor: when first_item is None (empty field), cannot determine type + if first_item is None: + is_non_tensor = None + else: + is_non_tensor = not is_tensor + + # Determine the shape of each sample (excluding batch dimension) + # When TensorDict converts a Python list to tensor, the first dimension equals batch_size + # We need to strip this batch dimension to get per-sample shape + if isinstance(value, torch.Tensor) and not is_nested and value.shape[0] > 0: + # Check if first dimension is batch dimension + assert value.shape[0] == batch_size + if len(value.shape) > 1: + # Multi-dim tensor: shape = value.shape[1:] + sample_shape = value.shape[1:] + else: + sample_shape = torch.Size([1]) + else: + sample_shape = getattr(first_item, "shape", None) if first_item is not None else None + + field_meta = { + "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), + "shape": sample_shape, + "is_nested": is_nested, + "is_non_tensor": is_non_tensor, + } + + # For nested tensors, record per-sample shapes + if is_nested: + field_meta["per_sample_shapes"] = [tuple(t.shape) for t in value.unbind()] + + field_schema[field_name] = field_meta + + return field_schema + + @dataclass class BatchMeta: """Records the metadata of a batch of data samples with optimized field-level schema. @@ -160,9 +214,9 @@ class BatchMeta: global_indexes: list[int] partition_ids: list[str] - # O(F) field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}} + # field-level metadata: {field_name: {dtype, shape, is_nested, is_non_tensor}} field_schema: dict[str, dict[str, Any]] = dataclasses.field(default_factory=dict) - # O(B) vectorized production status; always np.ndarray after __post_init__ (never None) + # vectorized production status matrix production_status: np.ndarray = dataclasses.field(default=None, repr=False) # type: ignore[assignment] extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) # user-defined meta for each sample (sample-level), list aligned with global_indexes @@ -387,33 +441,10 @@ def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "Ba if batch_size != self.size: raise ValueError(f"add_fields batch size mismatch: self.size={self.size} vs tensor_dict={batch_size}") - for name, value in tensor_dict.items(): - # Determine if this is a nested tensor - is_nested = isinstance(value, torch.Tensor) and value.is_nested - - first_item = None - if is_nested: - unbound = value.unbind() - first_item = unbound[0] if unbound else None - else: - first_item = value[0] if len(value) > 0 else None - - # Determine if this is non-tensor data. - # When first_item is None (empty field), we cannot determine type—leave as None. - is_non_tensor = not isinstance(first_item, torch.Tensor) if first_item is not None else None - - field_meta = { - "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), - "shape": getattr(first_item, "shape", None) if not is_nested else None, - "is_nested": is_nested, - "is_non_tensor": is_non_tensor, - } - - # For nested tensors, record per-sample shapes - if is_nested: - field_meta["per_sample_shapes"] = [tuple(t.shape) for t in value.unbind()] + field_schema = extract_field_schema(tensor_dict) - self.field_schema[name] = field_meta + for key, value in field_schema.items(): + self.field_schema[key] = value if set_all_ready: self.production_status[:] = 1 diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index dd98653a..be4fe675 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -242,6 +242,7 @@ def _batch_get_bytes(self, keys: list[str]) -> list[bytes]: def clear(self, keys: list[str], custom_backend_meta=None): """Deletes multiple keys from MooncakeStore. + Args: keys (List[str]): List of keys to remove. custom_backend_meta (List[Any], optional): ... @@ -252,6 +253,13 @@ def clear(self, keys: list[str], custom_backend_meta=None): if ret < 0: logger.warning(f"remove failed for key '{gid}' with error code: {ret}") + # FIXME: controller returned BatchMeta may have mismatched fields in some case, preventing + # key-value based backends to accurately clear all existing keys.. + # for key in keys: + # ret = self._store.remove(key) + # if not (ret == 0 or ret == -704): + # logger.warning(f"remove failed for key '{key}' with error code: {ret}") + def close(self): """Closes MooncakeStore.""" if self._store: diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 621aa008..26c92f50 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -32,7 +32,7 @@ from tensordict import NonTensorStack, TensorDict from torch import Tensor -from transfer_queue.metadata import BatchMeta +from transfer_queue.metadata import BatchMeta, extract_field_schema from transfer_queue.storage.clients.factory import StorageClientFactory from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket @@ -313,41 +313,6 @@ async def clear_data(self, metadata: BatchMeta) -> None: """ raise NotImplementedError("Subclasses must implement clear_data") - @staticmethod - def _extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: - """Extract field-level schema from TensorDict. O(F) complexity.""" - field_schema: dict[str, dict[str, Any]] = {} - - for field_name in data.keys(): - field_data = data[field_name] - - is_tensor = isinstance(field_data, torch.Tensor) - is_nested = is_tensor and field_data.is_nested - - if is_nested: - unbound = field_data.unbind() - first_item = unbound[0] if unbound else None - elif is_tensor: - first_item = field_data[0] if field_data.shape[0] > 0 else None - else: - first_item = field_data[0] if len(field_data) > 0 else None - - is_non_tensor = not isinstance(first_item, torch.Tensor) if first_item is not None else False - - field_meta = { - "dtype": getattr(first_item, "dtype", type(first_item) if first_item is not None else None), - "shape": getattr(first_item, "shape", None) if is_tensor and not is_nested else None, - "is_nested": is_nested, - "is_non_tensor": is_non_tensor, - } - - if is_nested: - field_meta["per_sample_shapes"] = [tuple(t.shape) for t in unbound] - - field_schema[field_name] = field_meta - - return field_schema - def close(self) -> None: """Close all ZMQ sockets and context to prevent resource leaks.""" # Close handshake socket if it exists @@ -557,23 +522,17 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): shapes = [] dtypes = [] custom_backend_meta_list = [] - num_samples = len(metadata) - - # FIXME: Use BatchMeta.get_dtype and .get_shape instead for field_name in sorted(metadata.field_names): - field_meta = metadata.field_schema.get(field_name, {}) - field_shape = field_meta.get("shape") - field_dtype = field_meta.get("dtype") - per_sample_shapes = field_meta.get("per_sample_shapes") - - for index in range(num_samples): - if per_sample_shapes is not None: - shapes.append(per_sample_shapes[index]) - else: - shapes.append(field_shape) - dtypes.append(field_dtype) - custom_backend_meta_list.append(metadata._custom_backend_meta[index].get(field_name, None)) + field_shape = metadata.get_shapes(field_name) + field_dtype = metadata.get_dtypes(field_name) + + shapes.extend(field_shape) + dtypes.extend(field_dtype) + + custom_backend_meta_list.extend( + [metadata._custom_backend_meta[i].get(field_name, None) for i in range(metadata.size)] + ) return shapes, dtypes, custom_backend_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: @@ -590,7 +549,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: custom_backend_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) - field_schema = self._extract_field_schema(data) + field_schema = extract_field_schema(data) per_field_custom_backend_meta: dict[int, dict[str, Any]] = {} if custom_backend_meta: @@ -643,9 +602,12 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: async def clear_data(self, metadata: BatchMeta) -> None: """Remove stored data associated with the given metadata.""" + if not metadata.field_names: - logger.warning("Attempted to clear data, but metadata contains no fields.") - return + raise RuntimeError( + "Fail to clear_data for key-value based backends due to lack of `field_names` in BatchMeta" + ) + keys = self._generate_keys(metadata.field_names, metadata.global_indexes) _, _, custom_meta = self._get_shape_type_custom_backend_meta_list(metadata) self.storage_client.clear(keys=keys, custom_backend_meta=custom_meta) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index f757cbf2..b77e42e2 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -30,7 +30,7 @@ from omegaconf import DictConfig from tensordict import NonTensorStack, TensorDict -from transfer_queue.metadata import BatchMeta +from transfer_queue.metadata import BatchMeta, extract_field_schema from transfer_queue.storage.managers.base import TransferQueueStorageManager from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory from transfer_queue.utils.zmq_utils import ( @@ -252,7 +252,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: if batch_size == 0: return - field_schema = self._extract_field_schema(data) + field_schema = extract_field_schema(data) routing = self._group_by_hash(metadata.global_indexes) tasks = [ From d93733f95d54c0ac1bf77d96f81ab54e323d7f2f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 16 Mar 2026 12:02:30 +0800 Subject: [PATCH 07/13] fix Signed-off-by: 0oshowero0 --- tests/e2e/test_kv_interface_e2e.py | 1 - transfer_queue/client.py | 24 +++++++++---------- transfer_queue/controller.py | 2 ++ transfer_queue/interface.py | 19 +++++++++++---- transfer_queue/metadata.py | 11 +++++---- .../storage/clients/mooncake_client.py | 15 ++++++------ transfer_queue/storage/managers/base.py | 2 +- .../storage/managers/mooncake_manager.py | 6 +++++ .../managers/simple_backend_manager.py | 2 +- transfer_queue/storage/simple_backend.py | 2 +- 10 files changed, 51 insertions(+), 33 deletions(-) diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index a110c882..0a248ea7 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -609,7 +609,6 @@ def test_kv_clear_single_key(self, controller): assert key not in partition.keys_mapping assert other_key in partition.keys_mapping - tq.kv_clear(keys=key, partition_id=partition_id) tq.kv_clear(keys=other_key, partition_id=partition_id) def test_kv_clear_multiple_keys(self, controller): diff --git a/transfer_queue/client.py b/transfer_queue/client.py index a1154c0b..235c9b07 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -231,7 +231,7 @@ async def async_get_meta( ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get_meta response: {response_msg} from controller {self._controller.id}" @@ -307,7 +307,7 @@ async def async_set_custom_meta( ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client set_custom_meta response: {response_msg} from controller {self._controller.id}" @@ -554,7 +554,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE: @@ -582,7 +582,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) if response_msg.request_type != ZMQRequestType.GET_PARTITION_META_RESPONSE: @@ -610,7 +610,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): ) await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) if response_msg.request_type != ZMQRequestType.CLEAR_PARTITION_RESPONSE: @@ -661,7 +661,7 @@ async def async_get_consumption_status( try: await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get consumption response: {response_msg} " @@ -723,7 +723,7 @@ async def async_get_production_status( try: await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get production response: {response_msg} " @@ -855,7 +855,7 @@ async def async_reset_consumption( ) try: await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client reset consumption response: {response_msg} " @@ -901,7 +901,7 @@ async def async_get_partition_list( try: assert socket is not None await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get partition list response: {response_msg} " @@ -968,7 +968,7 @@ async def async_kv_retrieve_meta( try: assert socket is not None await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get kv_retrieve_keys response: {response_msg} " @@ -1029,7 +1029,7 @@ async def async_kv_retrieve_keys( try: assert socket is not None await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get kv_retrieve_indexes response: {response_msg} " @@ -1090,7 +1090,7 @@ async def async_kv_list( try: assert socket is not None await socket.send_multipart(request_msg.serialize()) - response_serialized = await socket.recv_multipart() + response_serialized = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( f"[{self.client_id}]: Client get kv_list response: {response_msg} from controller {self._controller.id}" diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 6d003dc4..b0ef7572 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -265,6 +265,8 @@ def remove_samples(self, indexes: list[int]): # All remaining samples have the same shape - update to non-nested self.is_nested = False self.shape = next(iter(remaining_shapes)) + # Clear per-sample shapes since we are no longer nested + self.per_sample_shapes.clear() def to_batch_schema(self, batch_global_indexes: list[int]) -> dict[str, Any]: """Export as a BatchMeta.field_schema-compatible dict for generate_batch_meta.""" diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 386d6073..7c4894cb 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -305,9 +305,6 @@ def close(): global _TRANSFER_QUEUE_CLIENT global _TRANSFER_QUEUE_STORAGE global _TRANSFER_QUEUE_CONTROLLER - if _TRANSFER_QUEUE_CLIENT: - _TRANSFER_QUEUE_CLIENT.close() - _TRANSFER_QUEUE_CLIENT = None try: if _TRANSFER_QUEUE_STORAGE: @@ -321,9 +318,17 @@ def close(): if check.returncode == 0: pids = check.stdout.strip().replace("\n", ", ") logger.warning( - f"mooncake_master process still exists with PID: {pids}. " - f"Consider manually killing mooncake_master." + f"TransferQueue will not stop mooncake_master process with PID: {pids}. " + f"Consider manually killing the mooncake_master." ) + + if _TRANSFER_QUEUE_CLIENT: + ret = _TRANSFER_QUEUE_CLIENT.storage_manager.storage_client._store.remove_all() + if ret < 0: + logger.error("Failed to remove existing keys in mooncake_master.") + else: + logger.info("Successfully removed all existing keys in mooncake_master.") + # os.system('pkill -f "mooncake_master"') # process = value # if process and process.poll() is None: @@ -345,6 +350,10 @@ def close(): except Exception: pass + if _TRANSFER_QUEUE_CLIENT: + _TRANSFER_QUEUE_CLIENT.close() + _TRANSFER_QUEUE_CLIENT = None + if _TRANSFER_QUEUE_CONTROLLER: try: ray.kill(_TRANSFER_QUEUE_CONTROLLER) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 42d24d25..6e6cf627 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -171,11 +171,12 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: if isinstance(value, torch.Tensor) and not is_nested and value.shape[0] > 0: # Check if first dimension is batch dimension assert value.shape[0] == batch_size - if len(value.shape) > 1: - # Multi-dim tensor: shape = value.shape[1:] - sample_shape = value.shape[1:] - else: - sample_shape = torch.Size([1]) + # if len(value.shape) > 1: + # # Multi-dim tensor: shape = value.shape[1:] + # sample_shape = value.shape[1:] + # else: + # sample_shape = torch.Size([1]) + sample_shape = value.shape[1:] else: sample_shape = getattr(first_item, "shape", None) if first_item is not None else None diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index be4fe675..3b8f8989 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -218,9 +218,10 @@ def _batch_get_tensors(self, keys: list[str], shapes: list, dtypes: list) -> lis if tensor is None: raise RuntimeError(f"batch_get_tensor returned None for key '{batch_keys[j]}'") if tensor.shape != torch.Size(shape): - raise RuntimeError( - f"Shape mismatch for key '{batch_keys[j]}': expected {shape}, got {tensor.shape}" - ) + if not (tensor.shape == torch.Size([1]) and torch.Size(shape) == torch.Size([])): + raise RuntimeError( + f"Shape mismatch for key '{batch_keys[j]}': expected {shape}, got {tensor.shape}" + ) if tensor.dtype != dtype: raise RuntimeError( f"Dtype mismatch for key '{batch_keys[j]}': expected {dtype}, got {tensor.dtype}" @@ -247,11 +248,11 @@ def clear(self, keys: list[str], custom_backend_meta=None): keys (List[str]): List of keys to remove. custom_backend_meta (List[Any], optional): ... """ - global_indexes = [key.split("@")[0] + "@*" for key in keys] - for gid in global_indexes: - ret = self._store.remove_by_regex(gid, force=True) + global_indexes_patterns = [key.split("@")[0] + "@.*" for key in keys] + for p in global_indexes_patterns: + ret = self._store.remove_by_regex(p, force=True) if ret < 0: - logger.warning(f"remove failed for key '{gid}' with error code: {ret}") + logger.warning(f"remove failed for key '{p}' with error code: {ret}") # FIXME: controller returned BatchMeta may have mismatched fields in some case, preventing # key-value based backends to accurately clear all existing keys.. diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 26c92f50..d33c591f 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -243,7 +243,7 @@ async def notify_data_update( while not response_received and timeout > 0: try: poll_interval = min(TQ_STORAGE_POLLER_TIMEOUT, timeout) - messages = await asyncio.wait_for(sock.recv_multipart(), timeout=poll_interval) + messages = await asyncio.wait_for(sock.recv_multipart(copy=False), timeout=poll_interval) response_msg = ZMQMessage.deserialize(messages) if response_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ACK: # type: ignore[arg-type] diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index 7b8219c4..705873fc 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -38,4 +38,10 @@ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): elif client_name != "MooncakeStoreClient": raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'MooncakeStoreClient'") + logger.warning( + "MooncakeStore backend doesn't support key update (upsert) for now. " + "You must delete the key before updating it. " + "Refer to https://github.com/kvcache-ai/Mooncake/issues/1645 for details." + ) + super().__init__(controller_info, config) diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index b77e42e2..3d66dcb8 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -305,7 +305,7 @@ async def _put_to_single_storage_unit( try: data = request_msg.serialize() await socket.send_multipart(data, copy=False) - messages = await socket.recv_multipart() + messages = await socket.recv_multipart(copy=False) response_msg = ZMQMessage.deserialize(messages) if response_msg.request_type != ZMQRequestType.PUT_DATA_RESPONSE: diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index d139f6b2..e6908e53 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -280,7 +280,7 @@ def _worker_routine(self) -> None: if worker_socket in socks: # Messages received from proxy: [identity, serialized_msg_frame1, ...] - messages = worker_socket.recv_multipart() + messages = worker_socket.recv_multipart(copy=False) identity = messages[0] serialized_msg = messages[1:] From 6f0e5a533b8ada7a0f113b4319b609fa23ca4105 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 16 Mar 2026 14:13:48 +0800 Subject: [PATCH 08/13] fix Signed-off-by: 0oshowero0 --- tests/test_controller_data_partitions.py | 5 ++- transfer_queue/interface.py | 37 ++++++++++--------- transfer_queue/metadata.py | 17 +++++---- .../storage/clients/mooncake_client.py | 9 ++--- 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index a5776924..904f4320 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -1057,10 +1057,11 @@ def test_remove_samples(self): fm = FieldMeta(is_nested=True) fm.per_sample_shapes = {0: (3,), 1: (5,), 2: (7,)} fm.remove_samples([0, 2]) - assert fm.per_sample_shapes == {1: (5,)} + assert fm.per_sample_shapes == {} + assert fm.shape == (5,) + assert not fm.is_nested # Removing non-existent index should not raise fm.remove_samples([99]) - assert fm.per_sample_shapes == {1: (5,)} def test_to_batch_schema_regular(self): from transfer_queue.controller import FieldMeta diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 7c4894cb..4698e7b5 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -138,19 +138,17 @@ def _maybe_create_transferqueue_storage(conf: DictConfig) -> DictConfig: ] log_file_path = "/tmp/mooncake_master.log" - log_file = open(log_file_path, "w") - - process = subprocess.Popen( - cmd, - stdout=log_file, - stderr=subprocess.STDOUT, - text=True, - bufsize=1, - universal_newlines=True, - start_new_session=True, - ) - - time.sleep(3) + with open(log_file_path, "w") as log_file: + process = subprocess.Popen( + cmd, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True, + start_new_session=True, + ) + time.sleep(3) if process.poll() is None: logger.info( @@ -323,11 +321,14 @@ def close(): ) if _TRANSFER_QUEUE_CLIENT: - ret = _TRANSFER_QUEUE_CLIENT.storage_manager.storage_client._store.remove_all() - if ret < 0: - logger.error("Failed to remove existing keys in mooncake_master.") - else: - logger.info("Successfully removed all existing keys in mooncake_master.") + try: + ret = _TRANSFER_QUEUE_CLIENT.storage_manager.storage_client._store.remove_all() + if ret < 0: + logger.error("Failed to remove existing keys in mooncake_master.") + else: + logger.info("Successfully removed all existing keys in mooncake_master.") + except Exception: + pass # os.system('pkill -f "mooncake_master"') # process = value diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 6e6cf627..fdb3adcc 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -169,13 +169,16 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: # When TensorDict converts a Python list to tensor, the first dimension equals batch_size # We need to strip this batch dimension to get per-sample shape if isinstance(value, torch.Tensor) and not is_nested and value.shape[0] > 0: - # Check if first dimension is batch dimension - assert value.shape[0] == batch_size - # if len(value.shape) > 1: - # # Multi-dim tensor: shape = value.shape[1:] - # sample_shape = value.shape[1:] - # else: - # sample_shape = torch.Size([1]) + if value.shape[0] != batch_size: + raise ValueError( + f"Inconsistent batch dimension for field '{field_name}': " + f"expected batch_size[0]={batch_size}, got value.shape[0]={value.shape[0]}" + ) + if len(value.shape) > 1: + sample_shape = value.shape[1:] + else: + # When input is 1D tensor, manually set to torch.Size([1]). + sample_shape = torch.Size([1]) sample_shape = value.shape[1:] else: sample_shape = getattr(first_item, "shape", None) if first_item is not None else None diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 3b8f8989..a6273210 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -218,10 +218,9 @@ def _batch_get_tensors(self, keys: list[str], shapes: list, dtypes: list) -> lis if tensor is None: raise RuntimeError(f"batch_get_tensor returned None for key '{batch_keys[j]}'") if tensor.shape != torch.Size(shape): - if not (tensor.shape == torch.Size([1]) and torch.Size(shape) == torch.Size([])): - raise RuntimeError( - f"Shape mismatch for key '{batch_keys[j]}': expected {shape}, got {tensor.shape}" - ) + raise RuntimeError( + f"Shape mismatch for key '{batch_keys[j]}': expected {shape}, got {tensor.shape}" + ) if tensor.dtype != dtype: raise RuntimeError( f"Dtype mismatch for key '{batch_keys[j]}': expected {dtype}, got {tensor.dtype}" @@ -248,7 +247,7 @@ def clear(self, keys: list[str], custom_backend_meta=None): keys (List[str]): List of keys to remove. custom_backend_meta (List[Any], optional): ... """ - global_indexes_patterns = [key.split("@")[0] + "@.*" for key in keys] + global_indexes_patterns = {key.split("@")[0] + "@.*" for key in keys} for p in global_indexes_patterns: ret = self._store.remove_by_regex(p, force=True) if ret < 0: From 62079db87dd08c4be2fa739322fb0d8877b67c88 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 16 Mar 2026 14:25:38 +0800 Subject: [PATCH 09/13] fix Signed-off-by: 0oshowero0 fix Signed-off-by: 0oshowero0 --- transfer_queue/interface.py | 20 ++------------------ transfer_queue/metadata.py | 1 - transfer_queue/utils/common.py | 2 +- 3 files changed, 3 insertions(+), 20 deletions(-) diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 4698e7b5..d0fd2f7f 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib.resources as pkg_resources import logging import math import os import subprocess import time +from importlib import resources from typing import Any, Optional from urllib.parse import urlparse @@ -241,8 +241,7 @@ def init(conf: Optional[DictConfig] = None) -> None: # create config final_conf = OmegaConf.create({}, flags={"allow_objects": True}) - with pkg_resources.path("transfer_queue", "config.yaml") as p: - default_conf = OmegaConf.load(p) + default_conf = OmegaConf.load(resources.files("transfer_queue") / "config.yaml") final_conf = OmegaConf.merge(final_conf, default_conf) if conf: final_conf = OmegaConf.merge(final_conf, conf) @@ -329,21 +328,6 @@ def close(): logger.info("Successfully removed all existing keys in mooncake_master.") except Exception: pass - - # os.system('pkill -f "mooncake_master"') - # process = value - # if process and process.poll() is None: - # try: - # import signal - # pgid = os.getpgid(process.pid) - # os.killpg(pgid, signal.SIGTERM) - # try: - # process.wait(timeout=5) - # except subprocess.TimeoutExpired: - # os.killpg(pgid, signal.SIGKILL) - # process.wait(timeout=5) - # except ProcessLookupError: - # logger.warning(f"MooncakeStore process already exited: pid={process.pid}") else: logger.warning(f"close for _TRANSFER_QUEUE_STORAGE with key {key} is not supported for now.") diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index fdb3adcc..24248e97 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -179,7 +179,6 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]: else: # When input is 1D tensor, manually set to torch.Size([1]). sample_shape = torch.Size([1]) - sample_shape = value.shape[1:] else: sample_shape = getattr(first_item, "shape", None) if first_item is not None else None diff --git a/transfer_queue/utils/common.py b/transfer_queue/utils/common.py index e25f6b09..a9d2b935 100644 --- a/transfer_queue/utils/common.py +++ b/transfer_queue/utils/common.py @@ -65,7 +65,7 @@ def limit_pytorch_auto_parallel_threads(target_num_threads: Optional[int] = None target_num_threads = physical_cores if target_num_threads > physical_cores: - logger.error( + logger.warning( f"target_num_threads {target_num_threads} should not exceed total " f"physical CPU cores {physical_cores}. Setting to {physical_cores}." ) From ab12d7b7d3009bdefc224ce49f6ff4671ba49b69 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 16 Mar 2026 15:17:43 +0800 Subject: [PATCH 10/13] try fix ci Signed-off-by: 0oshowero0 --- .github/workflows/python-package.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 9629df75..31c8e975 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -47,5 +47,7 @@ jobs: - name: Test with pytest run: | pytest tests - TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/ + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py + pkill -f "mooncake_master" + TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py pkill -f "mooncake_master" \ No newline at end of file From 7b1790897d1c33e1341a5936fa2f4832996ac214 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 16 Mar 2026 15:50:15 +0800 Subject: [PATCH 11/13] clean legacy config check Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/mooncake_manager.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index 705873fc..62dd06e2 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -30,14 +30,6 @@ class MooncakeStorageManager(KVStorageManager): """Storage manager for MooncakeStorage backend.""" def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): - client_name = config.get("client_name", None) - - if client_name is None: - logger.info("Missing 'client_name' in config, using default value('MooncakeStoreClient')") - config["client_name"] = "MooncakeStoreClient" - elif client_name != "MooncakeStoreClient": - raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'MooncakeStoreClient'") - logger.warning( "MooncakeStore backend doesn't support key update (upsert) for now. " "You must delete the key before updating it. " From def924f78f07e5858eb3d90663661529a92e4a97 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 16 Mar 2026 15:59:47 +0800 Subject: [PATCH 12/13] fix Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/mooncake_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index 62dd06e2..f370e657 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -36,4 +36,5 @@ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): "Refer to https://github.com/kvcache-ai/Mooncake/issues/1645 for details." ) + config["client_name"] = "MooncakeStorageClient" super().__init__(controller_info, config) From b154fad92ba15f180a6b1eb5381fc9ae5d551177 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 16 Mar 2026 16:09:01 +0800 Subject: [PATCH 13/13] fix Signed-off-by: 0oshowero0 --- transfer_queue/storage/managers/mooncake_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transfer_queue/storage/managers/mooncake_manager.py b/transfer_queue/storage/managers/mooncake_manager.py index f370e657..a24ffafd 100644 --- a/transfer_queue/storage/managers/mooncake_manager.py +++ b/transfer_queue/storage/managers/mooncake_manager.py @@ -36,5 +36,5 @@ def __init__(self, controller_info: ZMQServerInfo, config: dict[str, Any]): "Refer to https://github.com/kvcache-ai/Mooncake/issues/1645 for details." ) - config["client_name"] = "MooncakeStorageClient" + config["client_name"] = "MooncakeStoreClient" super().__init__(controller_info, config)