Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
python -m pip install --upgrade pip
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install -e ".[test,build,yuanrong]"
pip install mooncake-transfer-engine-non-cuda
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand All @@ -43,11 +44,10 @@ jobs:
run: |
python -m build --wheel
pip install dist/*.whl --force-reinstall
- name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=False)
- name: Test with pytest
run: |
pytest tests
- name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=True)
run: |
ray stop --force
export TQ_ZERO_COPY_SERIALIZATION=True
pytest tests
TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py
pkill -f "mooncake_master"
TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py
pkill -f "mooncake_master"
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ test = [
yuanrong = [
"openyuanrong-datasystem"
]
mooncake = [
"mooncake-transfer-engine"
]

# If you need to mimic `package_dir={'': '.'}`:
[tool.setuptools.package-dir]
Expand Down
112 changes: 92 additions & 20 deletions tests/e2e/test_e2e_lifecycle_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""E2E lifecycle consistency tests for TransferQueue."""

import os
import sys
import time
from pathlib import Path
Expand All @@ -23,6 +22,7 @@
import pytest
import ray
import torch
from omegaconf import OmegaConf
from tensordict import TensorDict
from tensordict.tensorclass import NonTensorData

Expand All @@ -48,6 +48,38 @@
"non_tensor_stack",
]

# Backend configurations for E2E tests, keyed by storage backend name.
# Every entry enables controller polling mode and nests the backend-specific
# settings under a key equal to the backend name.
BACKEND_CONFIGS = {
    name: {
        "controller": {"polling_mode": True},
        "backend": {"storage_backend": name, name: settings},
    }
    for name, settings in (
        (
            "SimpleStorage",
            {
                "total_storage_size": 200,
                "num_data_storage_units": 2,
            },
        ),
        (
            "MooncakeStore",
            {
                "global_segment_size": 128 * 1024 * 1024,  # 128 MiB
                "local_buffer_size": 128 * 1024 * 1024,  # 128 MiB
                "metadata_server": "localhost:50050",
                "master_server_address": "localhost:50051",
                "protocol": "tcp",
                "device_name": "",
            },
        ),
    )
}


@pytest.fixture(scope="module")
def ray_cluster():
Expand All @@ -59,24 +91,33 @@ def ray_cluster():


@pytest.fixture(scope="module")
def e2e_client(ray_cluster):
"""Create a client using transfer_queue.init() for lifecycle testing."""
from omegaconf import OmegaConf
def backend_name():
    """Return the storage backend name selected for this test run.

    The name is read from the ``TQ_TEST_BACKEND`` environment variable
    (``SimpleStorage`` or ``MooncakeStore``); when the variable is not set,
    ``SimpleStorage`` is used.

    Examples:
        TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_e2e_lifecycle_consistency.py
        TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py
    """
    selected_backend = os.getenv("TQ_TEST_BACKEND", "SimpleStorage")
    return selected_backend


@pytest.fixture(scope="module")
def e2e_client(ray_cluster, backend_name):
"""Create a client using transfer_queue.init() for lifecycle testing.

Args:
ray_cluster: Ray cluster fixture
backend_name: Backend name from TQ_TEST_BACKEND env var
"""
import transfer_queue

config = {
"controller": {
"polling_mode": True,
},
"backend": {
"storage_backend": "SimpleStorage",
"SimpleStorage": {
"total_storage_size": 200,
"num_data_storage_units": 2,
},
},
}
if backend_name not in BACKEND_CONFIGS:
raise ValueError(f"Unknown backend: {backend_name}. Available: {list(BACKEND_CONFIGS.keys())}")

config = BACKEND_CONFIGS[backend_name]
transfer_queue.init(OmegaConf.create(config))
client = transfer_queue.get_client()
yield client
Expand Down Expand Up @@ -244,7 +285,7 @@ def verify_list_equal(retrieved, expected) -> bool:
if isinstance(retrieved, NonTensorStack):
retrieved = retrieved.tolist()
elif isinstance(retrieved, torch.Tensor):
retrieved = retrieved.tolist()
retrieved = retrieved.reshape(-1).tolist() # may get 2D tensor back using key-value based backend
if isinstance(expected, NonTensorStack):
expected = expected.tolist()
elif isinstance(expected, torch.Tensor):
Expand Down Expand Up @@ -283,9 +324,21 @@ def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict:
return TensorDict(reordered, batch_size=td.batch_size)


def recover_local_index(global_index_order, new_global_index_order):
    """Locate each original global index inside a reordered index sequence.

    Args:
        global_index_order: Global indexes in the order to be recovered.
        new_global_index_order: The same global indexes in a different order.

    Returns:
        A list with one entry per element of ``global_index_order``, giving
        that element's position in ``new_global_index_order``.

    Raises:
        KeyError: If a value of ``global_index_order`` does not occur in
            ``new_global_index_order``.
    """
    # If a value repeats in the new order, its last position is the one kept.
    position_by_value = {value: position for position, value in enumerate(new_global_index_order)}
    return [position_by_value[value] for value in global_index_order]


# Scenario One: Core Read/Write Consistency
def test_core_consistency(e2e_client):
"""Put full complex data then get verify all field types are correctly round-tripped."""
"""Put full complex data then get - verify all field types are correctly round-tripped."""
client = e2e_client
partition_id = "test_core_consistency"
batch_size = 20
Expand Down Expand Up @@ -362,6 +415,12 @@ def test_core_consistency(e2e_client):
# Scenario Two: Cross-Shard Update
def test_cross_shard_complex_update(e2e_client):
"""Cross-shard update: put A + put B, update overlapping region, verify all regions."""

# FIXME: Add data update test to MooncakeStore after Upsert function is ready
# https://github.com/kvcache-ai/Mooncake/issues/1645
if os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") == "MooncakeStore":
return

client = e2e_client
partition_id = "test_cross_shard_update"
task_name = "cross_shard_task"
Expand Down Expand Up @@ -744,12 +803,19 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):

indices = list(range(batch_size))
original_data = generate_complex_data(indices)
client.put(data=original_data, partition_id=partition_id)
original_meta = client.put(data=original_data, partition_id=partition_id)

global_index_order = original_meta.global_indexes
try:
# === Phase 1: Retrieve and verify writability ===
meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch")
assert meta is not None and meta.size == batch_size

# the global_index_order in retrieved meta is different from the original one.
# we need to reorder first.
local_index_order = recover_local_index(global_index_order, meta.global_indexes)
meta = meta.select_samples(local_index_order)

retrieved = client.get_data(meta)

# 1. tensor_f32: writable
Expand Down Expand Up @@ -793,6 +859,12 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):
# Re-retrieve the same data — modifications above should NOT have affected storage
meta2 = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch")
assert meta2 is not None and meta2.size == batch_size

# the global_index_order in retrieved meta is different from the original one.
# we need to reorder first.
local_index_order = recover_local_index(global_index_order, meta2.global_indexes)
meta2 = meta2.select_samples(local_index_order)

retrieved2 = client.get_data(meta2)

# tensor_f32[0,0] should be the original value, not 99999.0
Expand Down
Loading
Loading