From e861d15bd2fa2a5e9cc2a24479f4e35616baa140 Mon Sep 17 00:00:00 2001 From: xupinjie Date: Tue, 30 Jun 2026 01:26:34 -0700 Subject: [PATCH] [feat] support GDR in mooncake backend Signed-off-by: xupinjie --- pyproject.toml | 3 +- scripts/performance_test/perftest.py | 42 +- scripts/performance_test/perftest_config.yaml | 2 + tests/e2e/test_e2e_lifecycle_consistency.py | 147 +++++-- tests/test_mooncake_utils.py | 394 ++++++++++++++++++ .../storage/clients/mooncake_client.py | 242 +++++++++-- transfer_queue/utils/mooncake_utils.py | 224 ++++++++++ 7 files changed, 988 insertions(+), 66 deletions(-) create mode 100644 tests/test_mooncake_utils.py create mode 100644 transfer_queue/utils/mooncake_utils.py diff --git a/pyproject.toml b/pyproject.toml index 20b148be..7ab0ba14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,7 +119,8 @@ yuanrong = [ "openyuanrong-datasystem" ] mooncake = [ - "mooncake-transfer-engine==0.3.10.post2" + "mooncake-transfer-engine==0.3.10.post2", + "cuda-python", ] # If you need to mimic `package_dir={'': '.'}`: diff --git a/scripts/performance_test/perftest.py b/scripts/performance_test/perftest.py index 7abe1ddb..fef6005f 100644 --- a/scripts/performance_test/perftest.py +++ b/scripts/performance_test/perftest.py @@ -212,12 +212,17 @@ class TQClientActor: def __init__(self, config: dict[str, Any], use_complex_case: bool = False): self.config = config self.use_complex_case = use_complex_case + mooncake_cfg = config.get("backend", {}).get("MooncakeStore", {}) + self.use_gdr = bool(mooncake_cfg.get("use_gdr", False)) + self.gdr_device = "cuda:0" self.test_data = None self.total_data_size_gb = 0.0 self.test_keys = None def initialize(self) -> None: """Initialize transfer_queue with the config.""" + if self.use_gdr: + torch.cuda.set_device(self.gdr_device) tq.init(OmegaConf.create(self.config)) def create_test_case( @@ -249,11 +254,16 @@ def list_keys(self, partition_id: str) -> list[str]: return list(partition_info[partition_id].keys()) return [] - def get_data(self, partition_id: str, keys: list[str] | None = None) -> None: + def get_data(self, partition_id: str, keys: list[str] | None = None, move_to_gpu: bool = False) -> None: """Get data from storage using kv_batch_get.""" if keys is None: keys = self.test_keys - tq.kv_batch_get(keys=keys, partition_id=partition_id) + result = tq.kv_batch_get(keys=keys, partition_id=partition_id) + if move_to_gpu: + cpu_tensors = [v for v in result.values() if torch.is_tensor(v) and not v.is_cuda] + torch.cuda.synchronize() + _ = [t.to(self.gdr_device) for t in cpu_tensors] + torch.cuda.synchronize() def delete(self, partition_id: str, keys: list[str] | None = None) -> None: """Delete data from storage using kv_clear.""" @@ -316,6 +326,9 @@ def __init__( # Get backend from config self.backend = self.full_config["backend"]["storage_backend"] + # GDR is configured via backend.MooncakeStore.use_gdr (no separate CLI flag). + self.use_gdr = bool(self.full_config["backend"].get("MooncakeStore", {}).get("use_gdr", False)) + # For Yuanrong, always use inter_node self.use_inter_node = self.backend == "Yuanrong" @@ -331,6 +344,18 @@ def _validate_args(self) -> None: if self.use_inter_node and self.worker_node_ip is None: raise ValueError("worker_node_ip is required for Yuanrong backend") + # GDR only applies to MooncakeStore on GPU; reject other combos up front. + if self.use_gdr: + if self.backend != "MooncakeStore": + raise ValueError( + f"backend.MooncakeStore.use_gdr=true requires the MooncakeStore backend, got '{self.backend}'." + ) + if self.device != "gpu": + raise ValueError( + f"backend.MooncakeStore.use_gdr=true requires --device gpu " + f"(CUDA tensors are needed for the GDR path), got '{self.device}'." + ) + def _prepare_config(self) -> dict[str, Any]: """Prepare the config by directly reading the backend_config file. @@ -393,9 +418,9 @@ def _initialize_clients(self) -> None: # Initialize transfer_queue logger.info(f"Using {self.backend} as storage backend.") - w = self.writer.initialize.remote() - r = self.reader.initialize.remote() - ray.get([w, r]) + # Writer first: ensures storage bootstrap binds to the head address before reader attaches. + ray.get(self.writer.initialize.remote()) + ray.get(self.reader.initialize.remote()) def run_throughput_test(self, skip_dataset_create=False) -> dict[str, Any]: """Run the throughput test and print results. @@ -438,10 +463,11 @@ def run_throughput_test(self, skip_dataset_create=False) -> dict[str, Any]: time.sleep(2) - # GET_DATA operation using kv_batch_get + # GET_DATA operation using kv_batch_get; move_to_gpu adds H2D into get_time + move_to_gpu = self.device == "gpu" and not self.use_gdr logger.info("Starting GET_DATA operation (kv_batch_get)...") start_get_data = time.perf_counter() - ray.get(self.reader.get_data.remote(partition_id=partition_id, keys=keys)) + ray.get(self.reader.get_data.remote(partition_id=partition_id, keys=keys, move_to_gpu=move_to_gpu)) end_get_data = time.perf_counter() get_time = end_get_data - start_get_data get_gbit_per_sec = (self.total_data_size_gb * 8) / get_time @@ -462,6 +488,7 @@ def run_throughput_test(self, skip_dataset_create=False) -> dict[str, Any]: logger.info("=" * 60) logger.info(f"Backend: {self.backend}") logger.info(f"Device: {self.device}") + logger.info(f"GDR: {self.use_gdr}") logger.info(f"Total Data Size: {self.total_data_size_gb:.6f} GB") logger.info(f"PUT Time: {put_time:.8f}s") logger.info(f"GET Time: {get_time:.8f}s") @@ -474,6 +501,7 @@ def run_throughput_test(self, skip_dataset_create=False) -> dict[str, Any]: return { "backend": self.backend, "device": self.device, + "use_gdr": self.use_gdr, "total_data_size_gb": self.total_data_size_gb, "put_time": put_time, "get_time": get_time, diff --git a/scripts/performance_test/perftest_config.yaml b/scripts/performance_test/perftest_config.yaml index 310653b4..98517090 100644 --- a/scripts/performance_test/perftest_config.yaml +++ b/scripts/performance_test/perftest_config.yaml @@ -53,6 +53,8 @@ backend: # Network device name. # Set to "" to let Mooncake auto-select available devices. device_name: "" + # GPU Direct RDMA. When true, CUDA tensors are transferred directly from GPU memory + use_gdr: false # For Yuanrong: Yuanrong: diff --git a/tests/e2e/test_e2e_lifecycle_consistency.py b/tests/e2e/test_e2e_lifecycle_consistency.py index ab8b8373..40b2e71c 100644 --- a/tests/e2e/test_e2e_lifecycle_consistency.py +++ b/tests/e2e/test_e2e_lifecycle_consistency.py @@ -85,6 +85,29 @@ }, }, }, + # MooncakeStore with GDR staging buffer enabled. + # GPU tensors are transferred via the persistent staging buffer path; + # CPU tensors and non-tensors fall through to the original CPU RDMA path. + # auto_init=true: TQ starts mooncake_master automatically on transfer_queue.init(). + "MooncakeStore_GDR": { + "controller": { + "polling_mode": True, + }, + "backend": { + "storage_backend": "MooncakeStore", + "MooncakeStore": { + "global_segment_size": 134217728, # 128MB + "local_buffer_size": 134217728, # 128MB + "metadata_server": "localhost:50050", + "master_server_address": "localhost:50051", + "protocol": "tcp", + "device_name": "", + "auto_init": True, + "use_gdr": True, + "gdr_staging_buffer_mb": 256, + }, + }, + }, } @@ -102,12 +125,13 @@ def backend_name(): """Get the backend name from environment variable. Environment variables: - TQ_TEST_BACKEND: Backend name (SimpleStorage, MooncakeStore, or Yuanrong) + TQ_TEST_BACKEND: Backend name (SimpleStorage, MooncakeStore, Yuanrong, or MooncakeStore_GDR) To run tests for a specific backend: TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_e2e_lifecycle_consistency.py TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py TQ_TEST_BACKEND=Yuanrong pytest tests/e2e/test_e2e_lifecycle_consistency.py + TQ_TEST_BACKEND=MooncakeStore_GDR pytest tests/e2e/test_e2e_lifecycle_consistency.py """ return os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") @@ -132,7 +156,7 @@ def e2e_client(ray_cluster, backend_name): transfer_queue.close() -def generate_complex_data(indices: list[int]) -> TensorDict: +def generate_complex_data(indices: list[int], device: torch.device | None = None) -> TensorDict: """Generate complex TensorDict with all supported field types.""" n = len(indices) @@ -185,6 +209,15 @@ def generate_complex_data(indices: list[int]) -> TensorDict: # List of objects (dicts) list_obj = [{"key": f"value_{i}", "num": i} for i in indices] + if device is not None: + tensor_f32 = tensor_f32.to(device) + tensor_i64 = tensor_i64.to(device) + tensor_bf16 = tensor_bf16.to(device) + tensor_f16 = tensor_f16.to(device) + nested_jagged = nested_jagged.to(device) + nested_strided = nested_strided.to(device) + special_val = special_val.to(device) + field_values = { "tensor_f32": tensor_f32, "tensor_i64": tensor_i64, @@ -237,6 +270,7 @@ def verify_special_values(retrieved: torch.Tensor, expected: torch.Tensor) -> bo if len(retrieved) != len(expected): return False for r, e in zip(retrieved, expected, strict=True): + r, e = r.cpu(), e.cpu() # Check Inf column if not (torch.isinf(r[0]) and r[0] > 0): return False @@ -256,6 +290,7 @@ def verify_nested_tensor_equal(retrieved, expected) -> bool: if len(r_list) != len(e_list): return False for r, e in zip(r_list, e_list, strict=True): + r, e = r.cpu(), e.cpu() # Handle NaN: positions must match r_nan = torch.isnan(r) e_nan = torch.isnan(e) @@ -350,7 +385,7 @@ def recover_local_index(global_index_order, new_global_index_order): # Scenario One: Core Read/Write Consistency -def test_core_consistency(e2e_client): +def test_core_consistency(e2e_client, backend_name): """Put full complex data then get - verify all field types are correctly round-tripped.""" client = e2e_client partition_id = "test_core_consistency" @@ -359,7 +394,8 @@ def test_core_consistency(e2e_client): # 1. Put full complex data indices = list(range(batch_size)) - original_data = generate_complex_data(indices) + device = torch.device("cuda") if backend_name == "MooncakeStore_GDR" and torch.cuda.is_available() else None + original_data = generate_complex_data(indices, device=device) fields = DEFAULT_FIELDS meta = client.put(data=original_data, partition_id=partition_id) @@ -372,16 +408,16 @@ def test_core_consistency(e2e_client): # 3. Verify Standard Tensors (may be returned as nested tensors) for i in range(batch_size): - assert torch.allclose(retrieved_data["tensor_f32"][i], original_data["tensor_f32"][i]), ( + assert torch.allclose(retrieved_data["tensor_f32"][i].cpu(), original_data["tensor_f32"][i].cpu()), ( f"tensor_f32 mismatch at index {i}" ) - assert torch.equal(retrieved_data["tensor_i64"][i], original_data["tensor_i64"][i]), ( + assert torch.equal(retrieved_data["tensor_i64"][i].cpu(), original_data["tensor_i64"][i].cpu()), ( f"tensor_i64 mismatch at index {i}" ) - assert torch.equal(retrieved_data["tensor_bf16"][i], original_data["tensor_bf16"][i]), ( + assert torch.equal(retrieved_data["tensor_bf16"][i].cpu(), original_data["tensor_bf16"][i].cpu()), ( f"tensor_bf16 mismatch at index {i}" ) - assert torch.equal(retrieved_data["tensor_f16"][i], original_data["tensor_f16"][i]), ( + assert torch.equal(retrieved_data["tensor_f16"][i].cpu(), original_data["tensor_f16"][i].cpu()), ( f"tensor_f16 mismatch at index {i}" ) @@ -402,7 +438,7 @@ def test_core_consistency(e2e_client): # 7. Verify NumPy Arrays (may be returned as nested tensors) for i in range(batch_size): - assert np.allclose(retrieved_data["np_array"][i].numpy(), original_data["np_array"][i]), ( + assert np.allclose(retrieved_data["np_array"][i].cpu().numpy(), original_data["np_array"][i]), ( f"np_array mismatch at index {i}" ) @@ -442,12 +478,13 @@ def test_core_consistency(e2e_client): # Scenario Two: Cross-Shard Update -def test_cross_shard_complex_update(e2e_client): +def test_cross_shard_complex_update(e2e_client, backend_name): """Cross-shard update: put A + put B, update overlapping region, verify all regions.""" client = e2e_client partition_id = "test_cross_shard_update" task_name = "cross_shard_task" + device = torch.device("cuda") if backend_name == "MooncakeStore_GDR" and torch.cuda.is_available() else None # Define index ranges idx_a = list(range(0, 20)) # Put A @@ -467,18 +504,18 @@ def test_cross_shard_complex_update(e2e_client): try: # 2. Put A: indices 0-19 - data_a = generate_complex_data(idx_a) + data_a = generate_complex_data(idx_a, device=device) meta_a = alloc_meta.select_samples(list(range(0, 20))) client.put(data=data_a, metadata=meta_a) # 3. Put B: indices 20-39 - data_b = generate_complex_data(idx_b) + data_b = generate_complex_data(idx_b, device=device) meta_b = alloc_meta.select_samples(list(range(20, 40))) client.put(data=data_b, metadata=meta_b) # 4. Update indices 10-29 with modified values and new fields modified_indices = [i + 1000 for i in idx_update] # Offset to make values distinguishable - data_update = generate_complex_data(modified_indices) + data_update = generate_complex_data(modified_indices, device=device) # Add new fields new_extra_tensor = torch.stack([torch.ones(3) * i for i in idx_update]) # Shape: (20, 3) @@ -505,23 +542,23 @@ def test_cross_shard_complex_update(e2e_client): full_data = _reorder_tensordict(full_data, sorted_order) # 6. Verify region 0-9: original Put A values - original_data_0_9 = generate_complex_data(list(range(0, 10))) + original_data_0_9 = generate_complex_data(list(range(0, 10)), device=device) for i in range(10): - assert torch.allclose(full_data["tensor_f32"][i], original_data_0_9["tensor_f32"][i]), ( + assert torch.allclose(full_data["tensor_f32"][i].cpu(), original_data_0_9["tensor_f32"][i].cpu()), ( f"Region 0-9 tensor_f32 mismatch at index {i}" ) # 7. Verify region 10-29: updated values (using offset indices 1010-1029) - updated_expected = generate_complex_data([i + 1000 for i in range(10, 30)]) + updated_expected = generate_complex_data([i + 1000 for i in range(10, 30)], device=device) for i in range(20): - assert torch.allclose(full_data["tensor_f32"][10 + i], updated_expected["tensor_f32"][i]), ( + assert torch.allclose(full_data["tensor_f32"][10 + i].cpu(), updated_expected["tensor_f32"][i].cpu()), ( f"Region 10-29 tensor_f32 mismatch at index {10 + i}" ) # 8. Verify region 30-39: original Put B values - original_data_30_39 = generate_complex_data(list(range(30, 40))) + original_data_30_39 = generate_complex_data(list(range(30, 40)), device=device) for i in range(10): - assert torch.allclose(full_data["tensor_f32"][30 + i], original_data_30_39["tensor_f32"][i]), ( + assert torch.allclose(full_data["tensor_f32"][30 + i].cpu(), original_data_30_39["tensor_f32"][i].cpu()), ( f"Region 30-39 tensor_f32 mismatch at index {30 + i}" ) @@ -818,7 +855,7 @@ def test_dynamic_tensor_shape_nested_transition(e2e_client): # Scenario Seven: Retrieved Data Writability and Memory Safety -def test_retrieved_data_writability_and_memory_safety(e2e_client): +def test_retrieved_data_writability_and_memory_safety(e2e_client, backend_name): """Verify that all data types retrieved via GET are writable and memory-independent. This test validates the ZMQ copy=False GET path (Plan 1): @@ -832,9 +869,10 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client): batch_size = 8 task_name = "writability_task" fields = DEFAULT_FIELDS + device = torch.device("cuda") if backend_name == "MooncakeStore_GDR" and torch.cuda.is_available() else None indices = list(range(batch_size)) - original_data = generate_complex_data(indices) + original_data = generate_complex_data(indices, device=device) original_meta = client.put(data=original_data, partition_id=partition_id) global_index_order = original_meta.global_indexes @@ -901,25 +939,25 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client): # tensor_f32[0,0] should be the original value, not 99999.0 for i in range(batch_size): - assert torch.allclose(retrieved2["tensor_f32"][i], original_data["tensor_f32"][i]), ( + assert torch.allclose(retrieved2["tensor_f32"][i].cpu(), original_data["tensor_f32"][i].cpu()), ( "Modifying retrieved tensor_f32 should not affect stored data" ) # tensor_i64[0,0] should be the original value, not 88888 for i in range(batch_size): - assert torch.equal(retrieved2["tensor_i64"][i], original_data["tensor_i64"][i]), ( + assert torch.equal(retrieved2["tensor_i64"][i].cpu(), original_data["tensor_i64"][i].cpu()), ( "Modifying retrieved tensor_i64 should not affect stored data" ) # tensor_bf16 should match original for i in range(batch_size): - assert torch.equal(retrieved2["tensor_bf16"][i], original_data["tensor_bf16"][i]), ( + assert torch.equal(retrieved2["tensor_bf16"][i].cpu(), original_data["tensor_bf16"][i].cpu()), ( "Modifying retrieved tensor_bf16 should not affect stored data" ) # tensor_f16 should match original for i in range(batch_size): - assert torch.equal(retrieved2["tensor_f16"][i], original_data["tensor_f16"][i]), ( + assert torch.equal(retrieved2["tensor_f16"][i].cpu(), original_data["tensor_f16"][i].cpu()), ( "Modifying retrieved tensor_f16 should not affect stored data" ) @@ -932,5 +970,64 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client): client.clear_partition(partition_id) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GDR test") +def test_gdr_gpu_tensor_roundtrip(e2e_client, backend_name): + """GDR path: put GPU tensors, get back and verify data, clear and verify gone. + + Only meaningful when backend_name == "MooncakeStore_GDR"; skipped otherwise. + Tests the full staging-buffer path: D2D pack → RDMA PUT → RDMA GET → D2D unpack. + """ + if backend_name != "MooncakeStore_GDR": + pytest.skip("GDR roundtrip only runs with TQ_TEST_BACKEND=MooncakeStore_GDR") + + client = e2e_client + partition_id = "test_gdr_gpu_tensor_roundtrip" + task_name = "gdr_roundtrip_task" + batch_size = 8 + device = torch.device("cuda", torch.cuda.current_device()) + + # Build a TensorDict with GPU tensors only (GDR path). + indices = list(range(batch_size)) + gpu_data = TensorDict( + { + "gpu_f32": torch.stack([torch.arange(i, i + 64, dtype=torch.float32, device=device) for i in indices]), + "gpu_bf16": torch.stack( + [torch.full((32,), float(i), dtype=torch.bfloat16, device=device) for i in indices] + ), + "gpu_f16": torch.stack([torch.linspace(0, i, 16, dtype=torch.float16, device=device) for i in indices]), + "gpu_i64": torch.stack( + [torch.arange(i * 10, i * 10 + 8, dtype=torch.int64, device=device) for i in indices] + ), + }, + batch_size=batch_size, + ) + + meta = client.put(data=gpu_data, partition_id=partition_id) + assert meta.size == batch_size + + try: + # GET: verify data integrity + retrieved_meta = poll_for_meta(client, partition_id, list(gpu_data.keys()), batch_size, task_name, mode="fetch") + assert retrieved_meta is not None and retrieved_meta.size == batch_size, "Failed to retrieve GDR metadata" + retrieved = client.get_data(retrieved_meta) + + for i in range(batch_size): + assert torch.equal(retrieved["gpu_f32"][i].cpu(), gpu_data["gpu_f32"][i].cpu()), ( + f"gpu_f32 mismatch at index {i}" + ) + assert torch.equal(retrieved["gpu_bf16"][i].cpu(), gpu_data["gpu_bf16"][i].cpu()), ( + f"gpu_bf16 mismatch at index {i}" + ) + assert torch.equal(retrieved["gpu_f16"][i].cpu(), gpu_data["gpu_f16"][i].cpu()), ( + f"gpu_f16 mismatch at index {i}" + ) + assert torch.equal(retrieved["gpu_i64"][i].cpu(), gpu_data["gpu_i64"][i].cpu()), ( + f"gpu_i64 mismatch at index {i}" + ) + + finally: + client.clear_partition(partition_id) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", __file__])) diff --git a/tests/test_mooncake_utils.py b/tests/test_mooncake_utils.py new file mode 100644 index 00000000..d0c766a8 --- /dev/null +++ b/tests/test_mooncake_utils.py @@ -0,0 +1,394 @@ +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for transfer_queue.utils.mooncake_utils and MooncakeStoreClient.clear().""" + +import logging +import threading +import time +from unittest.mock import MagicMock + +import pytest +import torch + +from transfer_queue.utils.mooncake_utils import ( + GdrStaging, + _aligned_offsets, + chunk_subkeys, + split_by_bytes, +) + +_DEFAULT_ALIGN = 256 + + +def _aligned(n: int) -> int: + return (n + _DEFAULT_ALIGN - 1) // _DEFAULT_ALIGN * _DEFAULT_ALIGN + + +# =========================================================================== +# _aligned_offsets +# =========================================================================== + + +class TestAlignedOffsets: + def test_empty(self): + offsets, total = _aligned_offsets([]) + assert offsets == [] + assert total == 0 + + def test_single_already_aligned(self): + offsets, total = _aligned_offsets([256]) + assert offsets == [0] + assert total == 256 + + def test_single_unaligned(self): + offsets, total = _aligned_offsets([100]) + assert offsets == [0] + assert total == 256 # ceil(100/256)*256 + + def test_multiple_unaligned(self): + # 100 → pad to 256, 200 → pad to 256, 300 → pad to 512 + offsets, total = _aligned_offsets([100, 200, 300]) + assert offsets == [0, 256, 512] + assert total == 512 + 512 # last slot: 300 → 512 + + def test_exact_multiples(self): + offsets, total = _aligned_offsets([256, 512, 256]) + assert offsets == [0, 256, 768] + assert total == 1024 + + +# =========================================================================== +# chunk_subkeys +# =========================================================================== + + +class TestChunkSubkeys: + def test_fits_exactly(self): + assert chunk_subkeys("k", 1024, 1024) == ["k"] + + def test_fits_below(self): + assert chunk_subkeys("k", 100, 1024) == ["k"] + + def test_oversized_two_chunks(self): + result = chunk_subkeys("k", 1025, 1024) + assert result == ["k:c0", "k:c1"] + + def test_oversized_exact_multiple(self): + result = chunk_subkeys("k", 2048, 1024) + assert result == ["k:c0", "k:c1"] + + def test_oversized_three_chunks(self): + result = chunk_subkeys("k", 2049, 1024) + assert result == ["k:c0", "k:c1", "k:c2"] + + def test_key_format_preserved(self): + result = chunk_subkeys("field@0", 3000, 1024) + assert all(s.startswith("field@0:c") for s in result) + assert [s.split(":c")[1] for s in result] == ["0", "1", "2"] + + def test_zero_bytes_fits(self): + assert chunk_subkeys("k", 0, 1024) == ["k"] + + +# =========================================================================== +# split_by_bytes +# =========================================================================== + + +class TestSplitByBytes: + def test_empty(self): + assert split_by_bytes([], 1024) == [] + + def test_single_fits(self): + groups = split_by_bytes([100], 1024) + assert groups == [[0]] + + def test_all_fit_one_group(self): + # 100+100+100 aligned = 256*3 = 768 <= 1024 + groups = split_by_bytes([100, 100, 100], 1024) + assert len(groups) == 1 + assert sorted(groups[0]) == [0, 1, 2] + + def test_splits_into_two_groups(self): + # 500 aligned=512, 500 aligned=512; 512+512=1024 fits; third pushes to new group + groups = split_by_bytes([500, 500, 500], 1024) + assert len(groups) == 2 + total_indices = sorted(idx for g in groups for idx in g) + assert total_indices == [0, 1, 2] + + def test_oversized_singleton(self): + # 2000 > 1024 → own group + groups = split_by_bytes([2000], 1024) + assert groups == [[0]] + + def test_oversized_in_mixed_batch(self): + # [100, 2000, 100]: the 2000-byte tensor must be its own singleton group + groups = split_by_bytes([100, 2000, 100], 1024) + singleton_groups = [g for g in groups if len(g) == 1] + multi_groups = [g for g in groups if len(g) > 1] + oversized_idx = next(g[0] for g in singleton_groups if g[0] == 1) + assert oversized_idx == 1 + assert sorted(multi_groups[0]) == [0, 2] + + def test_multiple_oversized_each_gets_singleton(self): + groups = split_by_bytes([2000, 3000, 2000], 1024) + assert len(groups) == 3 + assert all(len(g) == 1 for g in groups) + + def test_ascending_sort_prevents_fragmentation(self): + # Without ascending sort, processing [200, 900, 200, 200] in order would produce 3 groups: + # group0=[0], group1=[1(900)], group2=[2,3] + # With ascending sort the three 200-byte tensors pack together first, 900-byte goes last: + # group0=[0,2,3], group1=[1] → only 2 groups + # + # buffer_size=1024; aligned sizes: 200→256, 900→1024 + # 256*3=768 ≤ 1024 (three smalls fit); adding 900's 1024 would overflow → separate group + groups = split_by_bytes([200, 900, 200, 200], 1024) + assert len(groups) == 2 + # The three small-tensor indices must share a group + small_indices = {0, 2, 3} + assert any(small_indices == set(g) for g in groups) + # The large tensor must be alone + assert any(g == [1] for g in groups) + + def test_alignment_boundary(self): + # Two tensors each aligned to exactly buffer_size/2 should fit in one group + # 512 bytes each → aligned=512; 512+512=1024 == buffer_size + groups = split_by_bytes([512, 512], 1024) + assert len(groups) == 1 + assert sorted(groups[0]) == [0, 1] + + def test_all_indices_covered(self): + nbytes = [100, 200, 900, 50, 1200, 300] + groups = split_by_bytes(nbytes, 1024) + covered = sorted(idx for g in groups for idx in g) + assert covered == list(range(len(nbytes))) + + +# =========================================================================== +# GdrStaging – lock only (no CUDA required) +# =========================================================================== + + +class TestGdrStagingLock: + def test_acquire_blocks_concurrent_thread(self): + staging = GdrStaging(1024 * 1024) + entered = threading.Event() + released = threading.Event() + results: list[str] = [] + + def holder(): + with staging.acquire(): + entered.set() + released.wait(timeout=2.0) + results.append("holder_done") + + def waiter(): + entered.wait(timeout=2.0) + with staging.acquire(): + results.append("waiter_in") + + t1 = threading.Thread(target=holder) + t2 = threading.Thread(target=waiter) + t1.start() + t2.start() + + entered.wait(timeout=1.0) + time.sleep(0.05) + assert "waiter_in" not in results # still blocked + + released.set() + t1.join(timeout=1.0) + t2.join(timeout=1.0) + assert results == ["holder_done", "waiter_in"] + + +# =========================================================================== +# GdrStaging – CUDA-dependent tests +# =========================================================================== + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestGdrStagingCuda: + def _mock_store(self): + store = MagicMock() + store.register_buffer.return_value = None + store.unregister_buffer.return_value = None + return store + + def test_lazy_init_idempotent(self): + store = self._mock_store() + staging = GdrStaging(1024 * 1024) + staging.lazy_init(store) + staging.lazy_init(store) + assert store.register_buffer.call_count == 1 + assert staging._initialized + staging.close(store) + + def test_close_resets_state(self): + store = self._mock_store() + staging = GdrStaging(1024 * 1024) + staging.lazy_init(store) + staging.close(store) + assert not staging._initialized + assert store.unregister_buffer.call_count == 1 + + def test_pack_unpack_roundtrip_single(self): + store = self._mock_store() + staging = GdrStaging(4 * 1024 * 1024) + staging.lazy_init(store) + try: + original = torch.arange(1024, dtype=torch.float32, device="cuda") + with staging.acquire(): + sub_ptrs, sizes = staging.pack([original]) + result = staging.unpack(sub_ptrs, sizes, [original.dtype], [tuple(original.shape)], original.device) + assert torch.equal(result[0], original) + finally: + staging.close(store) + + def test_pack_unpack_roundtrip_multiple(self): + store = self._mock_store() + staging = GdrStaging(4 * 1024 * 1024) + staging.lazy_init(store) + try: + tensors = [ + torch.randn(128, dtype=torch.float32, device="cuda"), + torch.randint(0, 100, (64,), dtype=torch.int64, device="cuda"), + torch.ones(256, dtype=torch.float16, device="cuda"), + ] + dtypes = [t.dtype for t in tensors] + shapes = [tuple(t.shape) for t in tensors] + with staging.acquire(): + sub_ptrs, sizes = staging.pack(tensors) + results = staging.unpack(sub_ptrs, sizes, dtypes, shapes, torch.device("cuda")) + for orig, got in zip(tensors, results, strict=True): + assert torch.equal(orig, got) + finally: + staging.close(store) + + def test_pack_contiguous_required(self): + # Non-contiguous tensor is contiguous()-ed by the caller before pack; staging + # itself only receives contiguous tensors. Verify pack/unpack still works. + store = self._mock_store() + staging = GdrStaging(4 * 1024 * 1024) + staging.lazy_init(store) + try: + base = torch.arange(256, dtype=torch.float32, device="cuda").reshape(16, 16) + t = base[:, :8].contiguous() # caller makes contiguous + with staging.acquire(): + sub_ptrs, sizes = staging.pack([t]) + result = staging.unpack(sub_ptrs, sizes, [t.dtype], [tuple(t.shape)], t.device) + assert torch.equal(result[0], t) + finally: + staging.close(store) + + +# =========================================================================== +# MooncakeStoreClient.clear() – sub-key expansion and cleanup +# =========================================================================== + + +def _make_clear_client(use_gdr: bool = True): + """Construct a minimal MooncakeStoreClient-like object for clear() testing. + + Bypasses __init__ to avoid needing a real Mooncake store connection. + """ + from transfer_queue.storage.clients.mooncake_client import MooncakeStoreClient + + client = object.__new__(MooncakeStoreClient) + client._store = MagicMock() + # Default: all batch_remove calls succeed + client._store.batch_remove.side_effect = lambda keys, force: [0] * len(keys) + client._gdr_staging = MagicMock() if use_gdr else None + return client + + +class TestClear: + def test_no_gdr_keys_pass_through(self): + client = _make_clear_client(use_gdr=False) + keys = ["a", "b", "c"] + client.clear(keys) + client._store.batch_remove.assert_called_once_with(keys, force=True) + + def test_gdr_all_normal_meta_no_expansion(self): + client = _make_clear_client(use_gdr=True) + keys = ["a", "b"] + client.clear(keys, custom_backend_meta=[None, None]) + client._store.batch_remove.assert_called_once_with(["a", "b"], force=True) + + def test_gdr_single_chunked_key_expands(self): + client = _make_clear_client(use_gdr=True) + client.clear(["big"], custom_backend_meta=[{"n_chunks": 3}]) + client._store.batch_remove.assert_called_once_with(["big:c0", "big:c1", "big:c2"], force=True) + + def test_gdr_mixed_chunked_and_normal(self): + client = _make_clear_client(use_gdr=True) + client.clear(["normal", "big"], custom_backend_meta=[None, {"n_chunks": 2}]) + client._store.batch_remove.assert_called_once_with(["normal", "big:c0", "big:c1"], force=True) + + def test_gdr_multiple_chunked_keys_all_expanded(self): + client = _make_clear_client(use_gdr=True) + keys = ["a", "b", "c"] + meta = [{"n_chunks": 2}, {"n_chunks": 3}, None] + client.clear(keys, custom_backend_meta=meta) + client._store.batch_remove.assert_called_once_with(["a:c0", "a:c1", "b:c0", "b:c1", "b:c2", "c"], force=True) + + def test_gdr_no_subkeys_leaked_after_clear(self): + # All keys (original + sub-keys) must appear exactly once in batch_remove; + # none should be silently dropped. + client = _make_clear_client(use_gdr=True) + keys = ["x", "y"] + meta = [{"n_chunks": 4}, {"n_chunks": 2}] + client.clear(keys, custom_backend_meta=meta) + call_args = client._store.batch_remove.call_args + removed = call_args[0][0] + assert removed == ["x:c0", "x:c1", "x:c2", "x:c3", "y:c0", "y:c1"] + # No original keys leaked through + assert "x" not in removed + assert "y" not in removed + + def test_gdr_meta_none_warns_and_uses_original_keys(self, caplog): + client = _make_clear_client(use_gdr=True) + keys = ["k0", "k1"] + with caplog.at_level(logging.WARNING, logger="transfer_queue.storage.clients.mooncake_client"): + client.clear(keys, custom_backend_meta=None) + assert "custom_backend_meta is None" in caplog.text + client._store.batch_remove.assert_called_once_with(keys, force=True) + + def test_no_gdr_meta_none_no_warning(self, caplog): + client = _make_clear_client(use_gdr=False) + with caplog.at_level(logging.WARNING): + client.clear(["k0"], custom_backend_meta=None) + assert "custom_backend_meta" not in caplog.text + + def test_error_code_triggers_log(self, caplog): + client = _make_clear_client(use_gdr=False) + client._store.batch_remove.side_effect = lambda keys, force: [-1] * len(keys) + with caplog.at_level(logging.ERROR, logger="transfer_queue.storage.clients.mooncake_client"): + client.clear(["k0"]) + assert "remove failed" in caplog.text + + def test_already_removed_code_704_is_silent(self, caplog): + client = _make_clear_client(use_gdr=False) + client._store.batch_remove.side_effect = lambda keys, force: [-704] * len(keys) + with caplog.at_level(logging.ERROR): + client.clear(["k0"]) + assert "remove failed" not in caplog.text + + def test_success_code_zero_is_silent(self, caplog): + client = _make_clear_client(use_gdr=False) + with caplog.at_level(logging.ERROR): + client.clear(["k0"]) + assert "remove failed" not in caplog.text diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index 297b7cf3..cb519541 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -23,6 +23,12 @@ from transfer_queue.storage.clients.base import StorageClientFactory, StorageKVClient from transfer_queue.utils import serial_utils from transfer_queue.utils.logging_utils import get_logger +from transfer_queue.utils.mooncake_utils import ( + GdrStaging, + _aligned_offsets, + chunk_subkeys, + split_by_bytes, +) from transfer_queue.utils.tensor_utils import allocate_empty_tensors, get_nbytes, merge_contiguous_memory logger = get_logger(__name__) @@ -66,6 +72,16 @@ def __init__(self, config: dict[str, Any]): if self.device_name is None: self.device_name = "" + self.use_gdr = bool(config.get("use_gdr", False)) + # gdr_staging_buffer_mb > 0: use persistent staging buffer (GDR path). + # gdr_staging_buffer_mb = 0: fall back to CPU RDMA path even if use_gdr=True. + self.gdr_staging_buffer_mb = int(config.get("gdr_staging_buffer_mb", 1024)) + buffer_bytes = self.gdr_staging_buffer_mb * 1024 * 1024 + # GdrStaging instance created eagerly but cudaMalloc is deferred to first use. + # Skip GDR if CUDA context is not initialized in this process (e.g. CPU-only workers) + gdr_eligible = self.use_gdr and buffer_bytes > 0 and torch.cuda.is_initialized() + self._gdr_staging: GdrStaging | None = GdrStaging(buffer_bytes) if gdr_eligible else None + if self.local_hostname is None or self.local_hostname == "": from transfer_queue.utils.zmq_utils import get_node_ip_address @@ -117,10 +133,12 @@ def put(self, keys: list[str], values: list[Any]) -> list[dict | None]: if len(keys) != len(values): raise ValueError("Number of keys must match number of values") - tensor_keys = [] - tensor_values = [] - non_tensor_keys = [] - non_tensor_values = [] + use_gdr_path = self.use_gdr and self._gdr_staging is not None + + tensor_keys: list[str] = [] + tensor_values: list[Tensor] = [] + non_tensor_keys: list[str] = [] + non_tensor_values: list[Any] = [] for key, value in zip(keys, values, strict=True): if isinstance(value, torch.Tensor): @@ -130,13 +148,18 @@ def put(self, keys: list[str], values: list[Any]) -> list[dict | None]: non_tensor_keys.append(key) non_tensor_values.append(value) + gdr_meta: dict[str, dict | None] = {} + if use_gdr_path and tensor_keys: + gdr_meta = dict(zip(tensor_keys, self._put_tensors_gdr(tensor_keys, tensor_values), strict=False)) + tensor_futures: list[Future[None]] = [] bytes_futures: list[Future[list[int]]] = [] with ThreadPoolExecutor(max_workers=MAX_BATCH_WORKER_THREADS) as executor: - for i in range(0, len(tensor_keys), BATCH_SIZE_LIMIT): - batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT] - batch_tensors = tensor_values[i : i + BATCH_SIZE_LIMIT] - tensor_futures.append(executor.submit(self._put_tensors_thread_worker, batch_keys, batch_tensors)) + if not use_gdr_path: + for i in range(0, len(tensor_keys), BATCH_SIZE_LIMIT): + batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT] + batch_tensors = tensor_values[i : i + BATCH_SIZE_LIMIT] + tensor_futures.append(executor.submit(self._put_tensors_thread_worker, batch_keys, batch_tensors)) for i in range(0, len(non_tensor_keys), BATCH_SIZE_LIMIT): batch_keys = non_tensor_keys[i : i + BATCH_SIZE_LIMIT] @@ -150,19 +173,65 @@ def put(self, keys: list[str], values: list[Any]) -> list[dict | None]: for tf in tensor_futures: tf.result() - # bytes results arrive in non-tensor submit order, which matches the order of - # non-tensor values; walk values once to scatter packed_size back to its key slot. + # Walk keys/values once to scatter results back to original slots. sizes_iter = iter(packed_sizes) custom_backend_meta: list[dict | None] = [ - {"packed_size": next(sizes_iter)} if not isinstance(value, torch.Tensor) else None for value in values + gdr_meta.get(key) if isinstance(value, torch.Tensor) else {"packed_size": next(sizes_iter)} + for key, value in zip(keys, values, strict=False) ] return custom_backend_meta + def _put_tensors_gdr(self, batch_keys: list[str], batch_tensors: list[Tensor]) -> list[dict | None]: + """GDR tensor PUT path using the persistent pre-registered staging buffer. + + split_by_bytes() groups tensors so each group's aligned total fits within the + staging buffer. Oversized tensors (nbytes > buffer_size) get a singleton group + and are stored as :c{i} sub-keys. Normal groups are packed and upserted together. + + Returns per-key meta: None for normal tensors, {"n_chunks": n} for oversized + tensors that were split into :c{i} sub-keys. clear() uses this to expand keys. + """ + assert self._gdr_staging is not None + self._gdr_staging.lazy_init(self._store) + staging = self._gdr_staging + buffer_size = staging.size + + # Caller guarantees all tensors are CUDA; make contiguous outside the lock. + contiguous_tensors = [t.contiguous() for t in batch_tensors] + nbytes = [t.nbytes for t in contiguous_tensors] + groups = split_by_bytes(nbytes, buffer_size) + + meta: list[dict | None] = [None] * len(batch_keys) + + with staging.acquire(): + for idxs in groups: + g_keys = [batch_keys[i] for i in idxs] + g_tensors = [contiguous_tensors[i] for i in idxs] + + if len(idxs) == 1 and g_tensors[0].nbytes > buffer_size: + # Oversized tensor: split into :c{i} sub-keys. + tensor = g_tensors[0] + key = g_keys[0] + sub_keys = chunk_subkeys(key, tensor.nbytes, buffer_size) + memcpy_chunk = staging.memcpy_d2d_async if tensor.is_cuda else staging.memcpy_h2d_async + for i, sub_key in enumerate(sub_keys): + chunk_size = min(buffer_size, tensor.nbytes - i * buffer_size) + memcpy_chunk(staging.ptr, tensor.data_ptr() + i * buffer_size, chunk_size) + staging.synchronize() + self._batch_upsert_with_retry([sub_key], [staging.ptr], [chunk_size]) + meta[idxs[0]] = {"n_chunks": len(sub_keys)} + else: + # Normal group: aligned total fits in buffer; pack and upsert together. + sub_ptrs, sizes = staging.pack(g_tensors) + self._batch_upsert_with_retry(g_keys, sub_ptrs, sizes) + + return meta + def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]) -> None: - """Worker thread for putting batch of tensors to MooncakeStore.""" + """Worker thread for putting batch of cpu tensors to MooncakeStore.""" - batch_ptrs, batch_sizes, _contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors) + batch_ptrs, batch_sizes, _ = self._preprocess_tensors_for_put(batch_tensors) batch_ptr_reduced, batch_sizes_reduced = merge_contiguous_memory(batch_ptrs, batch_sizes) self._register_all_buffers(batch_ptr_reduced, batch_sizes_reduced) try: @@ -227,26 +296,42 @@ def get( if not (len(keys) == len(shapes) == len(dtypes)): raise ValueError("Lengths of keys, shapes, dtypes must match") - tensor_indices: list[int] = [] - tensor_keys: list[str] = [] - tensor_shapes: list[Any] = [] - tensor_dtypes: list[Any] = [] + use_gdr_path = self.use_gdr and self._gdr_staging is not None + + gpu_tensor_indices: list[int] = [] + gpu_tensor_keys: list[str] = [] + gpu_tensor_shapes: list[Any] = [] + gpu_tensor_dtypes: list[Any] = [] + cpu_tensor_indices: list[int] = [] + cpu_tensor_keys: list[str] = [] + cpu_tensor_shapes: list[Any] = [] + cpu_tensor_dtypes: list[Any] = [] non_tensor_indices: list[int] = [] non_tensor_keys: list[str] = [] non_tensor_packed_sizes: list[int] = [] for i, dtype in enumerate(dtypes): if dtype is not None: - tensor_indices.append(i) - tensor_keys.append(keys[i]) - tensor_shapes.append(shapes[i]) - tensor_dtypes.append(dtype) + if use_gdr_path: + gpu_tensor_indices.append(i) + gpu_tensor_keys.append(keys[i]) + gpu_tensor_shapes.append(shapes[i]) + gpu_tensor_dtypes.append(dtype) + else: + cpu_tensor_indices.append(i) + cpu_tensor_keys.append(keys[i]) + cpu_tensor_shapes.append(shapes[i]) + cpu_tensor_dtypes.append(dtype) else: non_tensor_indices.append(i) non_tensor_keys.append(keys[i]) - if non_tensor_indices and (custom_backend_meta is None or len(custom_backend_meta) != len(keys)): - raise ValueError("custom_backend_meta with per-key packed_size is required when any dtype is None.") + if (gpu_tensor_indices and use_gdr_path) or non_tensor_indices: + if custom_backend_meta is None or len(custom_backend_meta) != len(keys): + raise ValueError( + "custom_backend_meta is required when GDR is enabled (for n_chunks) " + "or when any dtype is None (for packed_size)." + ) if non_tensor_indices: assert custom_backend_meta is not None @@ -257,13 +342,22 @@ def get( results = [None] * len(keys) + if gpu_tensor_keys: + assert custom_backend_meta is not None + gpu_tensor_meta = [custom_backend_meta[i] for i in gpu_tensor_indices] + retrieved, batch_idx = self._get_tensors_gdr( + gpu_tensor_keys, gpu_tensor_shapes, gpu_tensor_dtypes, gpu_tensor_indices, gpu_tensor_meta + ) + for idx, val in zip(batch_idx, retrieved, strict=True): + results[idx] = val + futures = [] with ThreadPoolExecutor(max_workers=MAX_BATCH_WORKER_THREADS) as executor: - for i in range(0, len(tensor_indices), BATCH_SIZE_LIMIT): - batch_keys = tensor_keys[i : i + BATCH_SIZE_LIMIT] - batch_shapes = tensor_shapes[i : i + BATCH_SIZE_LIMIT] - batch_dtypes = tensor_dtypes[i : i + BATCH_SIZE_LIMIT] - batch_indexes = tensor_indices[i : i + BATCH_SIZE_LIMIT] + for i in range(0, len(cpu_tensor_indices), BATCH_SIZE_LIMIT): + batch_keys = cpu_tensor_keys[i : i + BATCH_SIZE_LIMIT] + batch_shapes = cpu_tensor_shapes[i : i + BATCH_SIZE_LIMIT] + batch_dtypes = cpu_tensor_dtypes[i : i + BATCH_SIZE_LIMIT] + batch_indexes = cpu_tensor_indices[i : i + BATCH_SIZE_LIMIT] futures.append( executor.submit( self._get_tensors_thread_worker, batch_keys, batch_shapes, batch_dtypes, batch_indexes @@ -301,6 +395,74 @@ def _get_tensors_thread_worker( return batch_buffer_tensors, indexes + def _get_tensors_gdr( + self, + batch_keys: list[str], + batch_shapes: list[tuple], + batch_dtypes: list[torch.dtype], + indexes: list[int], + batch_meta: list[dict | None], + ) -> tuple[list[Tensor], list[int]]: + """GDR tensor GET path using the persistent pre-registered staging buffer. + + split_by_bytes() groups tensors so each group's aligned total fits within the + staging buffer. Oversized singleton groups reassemble from :c{i} sub-keys. + Normal groups use a single batch_get_into + unpack. + + NOTE: An alternative design is to skip the staging buffer entirely: for each group, + cudaMalloc a fresh buffer, register it, RDMA GET directly into it, unregister it, + then slice into tensors via torch.from_blob (eliminating the D2D copy and the staging + buffer lock). However, all tensors in a group would share one underlying buffer via + PyTorch's storage refcount — the buffer is freed only when the last tensor in the + group is GC'd. A single long-lived tensor silently keeps the entire batch allocation + alive, which is a hard-to-debug memory leak. We keep the D2D copy for now to give + each returned tensor an independent PyTorch-managed lifetime. + """ + assert self._gdr_staging is not None + self._gdr_staging.lazy_init(self._store) + staging = self._gdr_staging + device = torch.device("cuda", torch.cuda.current_device()) + buffer_size = staging.size + batch_nbytes = get_nbytes(batch_dtypes, batch_shapes) + + # Grouping happens outside the lock. + groups = split_by_bytes(batch_nbytes, buffer_size) + + tensors: list[torch.Tensor] = [None] * len(batch_keys) # type: ignore[list-item] + + with staging.acquire(): + for idxs in groups: + g_keys = [batch_keys[i] for i in idxs] + g_nbytes = [batch_nbytes[i] for i in idxs] + g_dtypes = [batch_dtypes[i] for i in idxs] + g_shapes = [batch_shapes[i] for i in idxs] + + if len(idxs) == 1 and g_nbytes[0] > buffer_size: + # Oversized tensor: reassemble from :c{i} sub-keys. + key = g_keys[0] + total = g_nbytes[0] + meta_entry = batch_meta[idxs[0]] + assert meta_entry is not None + n_chunks = meta_entry["n_chunks"] + sub_keys = [f"{key}:c{i}" for i in range(n_chunks)] + final_tensor = torch.empty(tuple(g_shapes[0]), dtype=g_dtypes[0], device=device) + for i, sub_key in enumerate(sub_keys): + chunk_size = min(buffer_size, total - i * buffer_size) + self._batch_get_into_with_retry([sub_key], [staging.ptr], [chunk_size]) + staging.memcpy_d2d_async(final_tensor.data_ptr() + i * buffer_size, staging.ptr, chunk_size) + staging.synchronize() + tensors[idxs[0]] = final_tensor + else: + # Normal group: aligned total fits; batch_get_into then unpack. + offsets, _ = _aligned_offsets(g_nbytes) + sub_ptrs = [staging._ptr + off for off in offsets] + self._batch_get_into_with_retry(g_keys, sub_ptrs, g_nbytes) + unpacked = staging.unpack(sub_ptrs, g_nbytes, g_dtypes, g_shapes, device) + for pos, t in zip(idxs, unpacked, strict=True): + tensors[pos] = t + + return tensors, indexes + def _get_bytes_thread_worker( self, batch_keys: list[str], batch_packed_sizes: list[int], indexes: list[int] ) -> tuple[list[Any], list[int]]: @@ -331,13 +493,30 @@ def clear(self, keys: list[str], custom_backend_meta: list[Any] | None = None) - keys (List[str]): List of keys to remove. custom_backend_meta (List[Any], optional): ... """ - ret_codes = self._store.batch_remove(keys, force=True) + if self._gdr_staging is not None and custom_backend_meta is not None: + actual_keys: list[str] = [] + for key, meta in zip(keys, custom_backend_meta, strict=True): + if isinstance(meta, dict) and "n_chunks" in meta: + actual_keys.extend(f"{key}:c{i}" for i in range(meta["n_chunks"])) + else: + actual_keys.append(key) + else: + if self._gdr_staging is not None: + logger.warning( + "GDR is enabled but custom_backend_meta is None; chunked sub-keys (if any) will not be removed." + ) + actual_keys = keys + + ret_codes = self._store.batch_remove(actual_keys, force=True) for i, ret in enumerate(ret_codes): if not (ret == 0 or ret == -704): - logger.error(f"remove failed for key `{keys[i]}` with error code: {ret}") + logger.error(f"remove failed for key `{actual_keys[i]}` with error code: {ret}") def close(self): """Closes MooncakeStore.""" + if self._gdr_staging is not None: + self._gdr_staging.close(self._store) + self._gdr_staging = None if self._store: self._store.close() self._store = None @@ -472,9 +651,6 @@ def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[int], list[i size_list: list[int] = [] tensor_list: list[Tensor] = [] # hold reference for the contiguous tensor for t in values: - # TODO: support gpu direct rdma and use different data paths. - # For GPU, it's more reasonable to perform data copy since - # The register overhead is much higher than CPU if t.device.type == "cuda": t = t.cpu() t = t.contiguous() diff --git a/transfer_queue/utils/mooncake_utils.py b/transfer_queue/utils/mooncake_utils.py new file mode 100644 index 00000000..f2e63b4d --- /dev/null +++ b/transfer_queue/utils/mooncake_utils.py @@ -0,0 +1,224 @@ +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Mooncake GDR transfers: persistent staging buffer and sub-key helpers.""" + +import contextlib +import threading +from math import ceil +from typing import Any + +import torch + +_DEFAULT_ALIGN = 256 + + +def _aligned_offsets(sizes: list[int], align: int = _DEFAULT_ALIGN) -> tuple[list[int], int]: + """Lay sizes out back-to-back with ``align``-byte alignment; return (offsets, total).""" + offsets: list[int] = [] + off = 0 + for sz in sizes: + offsets.append(off) + off += (sz + align - 1) // align * align + return offsets, off + + +def chunk_subkeys(key: str, nbytes: int, buffer_size: int) -> list[str]: + """Return the list of storage keys for a tensor of ``nbytes`` under ``buffer_size``. + + - nbytes <= buffer_size: returns [key] (no chunking) + - nbytes > buffer_size: returns ["{key}:c0", ..., "{key}:c{n-1}"] + """ + if nbytes <= buffer_size: + return [key] + n = ceil(nbytes / buffer_size) + return [f"{key}:c{i}" for i in range(n)] + + +def split_by_bytes(nbytes: list[int], buffer_size: int) -> list[list[int]]: + """Partition tensor indices into groups that fit within the staging buffer. + + Args: + nbytes: Per-tensor byte counts (same order as the tensor list being transferred). + buffer_size: Capacity of the staging buffer in bytes. + + Returns: + List of groups, each group is a list of indices into ``nbytes``. + Every group's 256-byte-aligned cumulative size fits within ``buffer_size``. + A tensor whose nbytes > buffer_size gets its own singleton group; + the caller handles it via the chunked ``:c{i}`` sub-key path. + Indices are processed in ascending size order so that large tensors do not + fragment the packing of small tensors. + + Call this before acquiring the staging-buffer lock; it does only integer arithmetic. + """ + groups: list[list[int]] = [] + current: list[int] = [] + current_total = 0 + + for i in sorted(range(len(nbytes)), key=lambda i: nbytes[i]): + nb = nbytes[i] + aligned = (nb + _DEFAULT_ALIGN - 1) // _DEFAULT_ALIGN * _DEFAULT_ALIGN + if nb > buffer_size: + if current: + groups.append(current) + current, current_total = [], 0 + groups.append([i]) + elif current and current_total + aligned > buffer_size: + groups.append(current) + current, current_total = [i], aligned + else: + current.append(i) + current_total += aligned + + if current: + groups.append(current) + return groups + + +class GdrStaging: + """Process-level persistent CUDA staging buffer for GDR transfers. + + One cudaMalloc buffer, registered once for the process lifetime. + All callers (PUT and GET) serialize through a single lock. + """ + + def __init__(self, buffer_size_bytes: int) -> None: + self._size = buffer_size_bytes + self._ptr: int = 0 + self._lock = threading.Lock() + self._rt: Any = None + self._stream: torch.cuda.Stream | None = None + self._initialized = False + + def lazy_init(self, store) -> None: + """Import cuda-python, cudaMalloc, register_buffer; idempotent.""" + if self._initialized: + return + try: + from cuda import cuda as cuda_driver + from cuda import cudart + except ImportError as exc: + raise ImportError( + "cuda-python is required for GDR transfers; install with: pip install 'TransferQueue[mooncake]'" + ) from exc + self._rt = cudart + err, device_ordinal = cudart.cudaGetDevice() + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaGetDevice() failed: {err.name}") + _, supported = cuda_driver.cuDeviceGetAttribute( + cuda_driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED, + device_ordinal, + ) + if not supported: + raise RuntimeError( + f"GPUDirect RDMA is not supported on device {device_ordinal}. " + "Please ensure the device supports GDR, or set use_gdr=False." + ) + err, ptr = cudart.cudaMalloc(self._size) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaMalloc({self._size}) failed: {err.name}") + self._ptr = ptr + store.register_buffer(self._ptr, self._size) + self._stream = torch.cuda.Stream() + self._initialized = True + + def close(self, store) -> None: + """store.unregister_buffer + cudaFree. Called by MooncakeStoreClient.close().""" + if self._initialized: + store.unregister_buffer(self._ptr) + (err,) = self._rt.cudaFree(self._ptr) + if err != self._rt.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaFree(0x{self._ptr:x}) failed: {err.name}") + self._initialized = False + + @contextlib.contextmanager + def acquire(self): + """Context manager that holds the internal mutex for the duration of one transfer.""" + with self._lock: + yield + + def memcpy_d2d_async(self, dst: int, src: int, nbytes: int) -> None: + """Enqueue a D2D async copy on the internal stream; call synchronize() when done.""" + assert self._stream is not None + rt = self._rt + (err,) = rt.cudaMemcpyAsync( + dst, src, nbytes, rt.cudaMemcpyKind.cudaMemcpyDeviceToDevice, self._stream.cuda_stream + ) + if err != rt.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaMemcpyAsync D2D failed: {err.name}") + + def memcpy_h2d_async(self, dst: int, src: int, nbytes: int) -> None: + """Enqueue a H2D async copy on the internal stream; call synchronize() when done.""" + assert self._stream is not None + rt = self._rt + (err,) = rt.cudaMemcpyAsync( + dst, src, nbytes, rt.cudaMemcpyKind.cudaMemcpyHostToDevice, self._stream.cuda_stream + ) + if err != rt.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaMemcpyAsync H2D failed: {err.name}") + + def synchronize(self) -> None: + """Synchronize the internal CUDA stream.""" + assert self._stream is not None + self._stream.synchronize() + + def pack(self, tensors: list[torch.Tensor]) -> tuple[list[int], list[int]]: + """Memcpy each tensor into the staging buffer at 256-byte aligned offsets. + + Supports both CPU (H2D) and CUDA (D2D) tensors transparently. + Caller must hold the lock (call inside acquire()). + Total packed size must fit in buffer_size (caller must ensure this). + Returns (sub_ptrs, sizes). + """ + sizes = [t.nbytes for t in tensors] + offsets, _ = _aligned_offsets(sizes) + for t, off in zip(tensors, offsets, strict=True): + if t.is_cuda: + self.memcpy_d2d_async(self._ptr + off, t.data_ptr(), t.nbytes) + else: + self.memcpy_h2d_async(self._ptr + off, t.data_ptr(), t.nbytes) + self.synchronize() + sub_ptrs = [self._ptr + off for off in offsets] + return sub_ptrs, sizes + + def unpack( + self, + sub_ptrs: list[int], + sizes: list[int], + dtypes: list[torch.dtype], + shapes: list[tuple], + device: torch.device, + ) -> list[torch.Tensor]: + """D2D memcpyAsync from each sub_ptr in staging into fresh tensors on device. + + Caller must hold the lock (call inside acquire()). + """ + out: list[torch.Tensor] = [] + for sub_ptr, sz, dt, shp in zip(sub_ptrs, sizes, dtypes, shapes, strict=True): + t = torch.empty(tuple(shp), dtype=dt, device=device) + self.memcpy_d2d_async(t.data_ptr(), sub_ptr, sz) + out.append(t) + self.synchronize() + return out + + @property + def ptr(self) -> int: + """Raw CUDA device pointer to the start of the staging buffer.""" + return self._ptr + + @property + def size(self) -> int: + """Capacity of the staging buffer in bytes.""" + return self._size