Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions tests/e2e/test_kv_interface_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,52 @@ def test_kv_clear_multiple_keys(self, controller, tq_api):

tq_api.kv_clear(keys=keys[2:], partition_id=partition_id)

def test_kv_clear_idempotent(self, controller, tq_api):
    """Check that kv_clear tolerates unknown keys and unknown partitions.

    Four keys are stored, then kv_clear is called with (a) only unknown
    keys, (b) a mix of known and unknown keys, (c) an already-cleared key,
    and (d) an unknown partition. None of the calls may raise, and only
    keys that actually exist may disappear from the listing.
    """
    partition_id = "test_partition"
    keys = ["idempotent_0", "idempotent_1", "idempotent_2", "idempotent_3"]
    removed_key, surviving_keys = keys[0], keys[1:]

    def current_listing():
        # Re-query the API every time so each assertion sees fresh state.
        return tq_api.kv_list(partition_id=partition_id)[partition_id]

    # Store all four keys in a single batch.
    payload = TensorDict({"data": torch.tensor([[0], [1], [2], [3]])}, batch_size=4)
    tq_api.kv_batch_put(
        keys=keys,
        partition_id=partition_id,
        fields=payload,
        tags=[{} for _ in keys],
    )

    # (a) Only unknown keys: nothing raises, nothing is removed.
    tq_api.kv_clear(keys=["nonexistent_key_1", "nonexistent_key_2"], partition_id=partition_id)
    listing = current_listing()
    assert len(listing) == 4
    assert all(key in listing for key in keys)

    # (b) Known + unknown keys: only the known key is removed.
    tq_api.kv_clear(keys=[removed_key, "nonexistent_key_3"], partition_id=partition_id)
    listing = current_listing()
    assert len(listing) == 3
    assert removed_key not in listing
    assert all(key in listing for key in surviving_keys)

    # (c) Clearing the same key again must leave the listing unchanged.
    tq_api.kv_clear(keys=[removed_key, "nonexistent_key_4"], partition_id=partition_id)
    listing = current_listing()
    assert len(listing) == 3
    assert removed_key not in listing

    # Cross-check against the controller's own key mapping.
    controller_partition = get_controller_partition(controller, partition_id)
    assert removed_key not in controller_partition.keys_mapping
    assert all(key in controller_partition.keys_mapping for key in surviving_keys)

    # (d) Unknown partition: must not raise.
    tq_api.kv_clear(keys=["any_key"], partition_id="nonexistent_partition")

    # Clean up
    tq_api.kv_clear(keys=keys, partition_id=partition_id)


class TestKVE2ECornerCases:
"""End-to-end tests for corner cases."""
Expand Down
126 changes: 122 additions & 4 deletions tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,98 @@ def test_controller_clear_meta(self, ray_setup):

print("✓ Clear meta correct")

def test_controller_clear_meta_idempotent(self, ray_setup):
    """Test clear_meta is idempotent when clearing non-existent data.

    Two partitions are created and marked as produced. clear_meta is then
    called with a non-existent partition, with a mix of existing and
    non-existing global indexes, with already-cleared indexes, and with a
    mix of existing and non-existent partitions. No call may raise, and
    only indexes that really exist in the addressed partition may be
    removed.

    The indexes to clear are taken from the metadata returned by the
    controller instead of being hard-coded, and their presence is asserted
    before clearing. Otherwise the "removed" assertions could pass without
    any removal having happened (e.g. when the chosen index was never part
    of the partition in the first place).
    """
    gbs = 4
    num_n_samples = 2
    partition_id = "test_clear_meta_idempotent"
    other_partition_id = "test_clear_meta_idempotent_other"

    tq_controller = TransferQueueController.remote()

    # Create two partitions
    data_fields = ["prompt_ids", "attention_mask"]
    metadata = ray.get(
        tq_controller.get_metadata.remote(
            data_fields=data_fields,
            batch_size=gbs * num_n_samples,
            partition_id=partition_id,
            mode="insert",
        )
    )
    other_metadata = ray.get(
        tq_controller.get_metadata.remote(
            data_fields=data_fields,
            batch_size=gbs * num_n_samples,
            partition_id=other_partition_id,
            mode="insert",
        )
    )

    # Update production status for both partitions
    field_schema = {
        "prompt_ids": {"dtype": "torch.int64", "shape": (32,)},
        "attention_mask": {"dtype": "torch.bool", "shape": (32,)},
    }
    ray.get(
        tq_controller.update_production_status.remote(
            partition_id=partition_id,
            global_indexes=metadata.global_indexes,
            field_schema=field_schema,
        )
    )
    ray.get(
        tq_controller.update_production_status.remote(
            partition_id=other_partition_id,
            global_indexes=other_metadata.global_indexes,
            field_schema=field_schema,
        )
    )

    # Pick real indexes of each partition rather than guessing their values.
    own_indexes = list(metadata.global_indexes)
    other_indexes = list(other_metadata.global_indexes)
    indexes_to_clear = own_indexes[:2]
    other_index_to_clear = other_indexes[0]
    # Indexes well beyond anything allocated above; assumed to be absent
    # from both partitions (the same assumption the literal 100/101 made).
    highest_allocated = max(own_indexes + other_indexes)
    missing_indexes = [highest_allocated + 100, highest_allocated + 101]

    # Clear non-existent partition should not raise
    ray.get(
        tq_controller.clear_meta.remote(
            global_indexes=indexes_to_clear,
            partition_ids=["non_existent_partition"] * len(indexes_to_clear),
        )
    )

    partition_before = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
    initial_global_indexes = set(partition_before.global_indexes)
    # Precondition: the targets exist (and the call above did not touch them).
    assert set(indexes_to_clear) <= initial_global_indexes
    expected_remaining = initial_global_indexes - set(indexes_to_clear)

    # Clear mix of existent and non-existent global indexes
    mixed_indexes = indexes_to_clear + missing_indexes
    ray.get(
        tq_controller.clear_meta.remote(
            global_indexes=mixed_indexes,
            partition_ids=[partition_id] * len(mixed_indexes),
        )
    )
    partition_after = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
    # Only the existing indexes should be removed; the missing ones are ignored
    assert set(partition_after.global_indexes) == expected_remaining

    # Clearing already-cleared indexes should be idempotent
    repeated_indexes = indexes_to_clear + missing_indexes[:1]
    ray.get(
        tq_controller.clear_meta.remote(
            global_indexes=repeated_indexes,
            partition_ids=[partition_id] * len(repeated_indexes),
        )
    )
    partition_after = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
    assert set(partition_after.global_indexes) == expected_remaining

    # Clearing across multiple partitions where one does not exist should still clear the other
    other_partition_before = ray.get(tq_controller.get_partition_snapshot.remote(other_partition_id))
    other_initial_global_indexes = set(other_partition_before.global_indexes)
    # Precondition: without this the "not in" check below would be vacuous.
    assert other_index_to_clear in other_initial_global_indexes
    ray.get(
        tq_controller.clear_meta.remote(
            global_indexes=[other_index_to_clear, indexes_to_clear[0]],
            partition_ids=[other_partition_id, "non_existent_partition"],
        )
    )
    other_partition_after = ray.get(tq_controller.get_partition_snapshot.remote(other_partition_id))
    assert other_index_to_clear not in other_partition_after.global_indexes
    assert set(other_partition_after.global_indexes) == other_initial_global_indexes - {other_index_to_clear}

    print("✓ Clear meta idempotent correct")


class TestTransferQueueControllerCustomMeta:
"""Integration tests for TransferQueueController custom_meta and custom_backend_meta methods.
Expand Down Expand Up @@ -807,8 +899,34 @@ def test_controller_kv_retrieve_meta_existing_keys(self, ray_setup):
# Clean up
ray.get(tq_controller.clear_partition.remote(partition_id))

def test_controller_kv_retrieve_meta_partial_keys_without_create(self, ray_setup):
    """Check that kv_retrieve_meta(create=False) drops keys that do not exist.

    Three keys are registered with create=True. A lookup mixing two of them
    with two unknown keys must return a BatchMeta covering exactly the two
    known keys.
    """
    partition_id = "kv_partial_test"
    tq_controller = TransferQueueController.remote()

    # Register the keys that are expected to be found later.
    seeded_keys = ["existing_key_1", "existing_key_2", "existing_key_3"]
    ray.get(tq_controller.kv_retrieve_meta.remote(keys=seeded_keys, partition_id=partition_id, create=True))

    # Look up two registered keys interleaved with two unknown ones.
    requested_keys = ["existing_key_1", "nonexistent_key", "existing_key_3", "another_missing_key"]
    pending_lookup = tq_controller.kv_retrieve_meta.remote(
        keys=requested_keys,
        partition_id=partition_id,
        create=False,
    )
    batch_meta = ray.get(pending_lookup)

    # Only the two registered keys may be present in the result.
    assert batch_meta.size == 2

    # Translate the returned global indexes back into keys via the partition snapshot.
    snapshot = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
    retrieved_keys = set()
    for global_index in batch_meta.global_indexes:
        retrieved_keys.add(snapshot.revert_keys_mapping[global_index])
    assert retrieved_keys == {"existing_key_1", "existing_key_3"}

    print("✓ kv_retrieve_meta filters out non-existent keys without create")

    # Clean up
    ray.get(tq_controller.clear_partition.remote(partition_id))

def test_controller_kv_retrieve_meta_non_existent_without_create(self, ray_setup):
"""Test kv_retrieve_meta raises error for non-existent keys without create."""
"""Test kv_retrieve_meta returns empty BatchMeta for non-existent keys without create."""
tq_controller = TransferQueueController.remote()
partition_id = "kv_nonexistent_test"

Expand All @@ -821,13 +939,13 @@ def test_controller_kv_retrieve_meta_non_existent_without_create(self, ray_setup
)
assert batch_meta.size == 0

print("✓ kv_retrieve_meta return an empty BatchMeta for non-existent keys without create")
print("✓ kv_retrieve_meta returns an empty BatchMeta for non-existent keys without create")

# Clean up
ray.get(tq_controller.clear_partition.remote(partition_id))

def test_controller_kv_retrieve_meta_empty_partition_without_create(self, ray_setup):
"""Test kv_retrieve_meta raises error for non-existent partition without create."""
"""Test kv_retrieve_meta returns empty BatchMeta for non-existent partition without create."""
tq_controller = TransferQueueController.remote()
partition_id = "nonexistent_partition"

Expand All @@ -836,7 +954,7 @@ def test_controller_kv_retrieve_meta_empty_partition_without_create(self, ray_se
)
assert batch_meta.size == 0

print("✓ kv_retrieve_meta return an empty BatchMeta for non-existent partition_id without create")
print("✓ kv_retrieve_meta returns an empty BatchMeta for non-existent partition_id without create")

def test_controller_kv_retrieve_meta_with_production_status(self, ray_setup):
"""Test kv_retrieve_meta works with production status update."""
Expand Down
5 changes: 4 additions & 1 deletion transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,10 @@ async def async_clear_partition(self, partition_id: str):
metadata = await self._get_partition_meta(partition_id)

if not metadata:
logger.warning(f"Try to clear an non-exist partition {partition_id}. No action will be taken.")
logger.warning(
f"[{self.client_id}]: Trying to clear a non-existent partition {partition_id}. "
f"No action will be taken."
)
return

# Clear the controller metadata
Expand Down
47 changes: 35 additions & 12 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,11 +1463,13 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True):
clear_consumption: Whether to also clear consumption status
"""

logger.debug(f"[{self.controller_id}]: clearing metadata in partition {partition_id}")
logger.debug(f"[{self.controller_id}]: Clearing metadata in partition {partition_id}.")

partition = self._get_partition(partition_id)
if not partition:
logger.warning(f"Try to clear an non-existent partition {partition_id}!")
logger.warning(
f"[{self.controller_id}]: Trying to clear a non-existent partition {partition_id}. No action taken."
)
return

global_indexes_range = list(self.index_manager.get_indexes_for_partition(partition_id))
Expand All @@ -1490,7 +1492,10 @@ def reset_consumption(self, partition_id: str, task_name: str | None = None):
logger.debug(f"[{self.controller_id}]: Resetting consumption for partition {partition_id}, task={task_name}")
partition = self._get_partition(partition_id)
if not partition:
logger.warning(f"Try to reset consumption of an non-existent partition {partition_id}!")
logger.warning(
f"[{self.controller_id}]: Trying to reset consumption of a non-existent partition {partition_id}. "
f"No action taken."
)
return
partition.reset_consumption(task_name)

Expand Down Expand Up @@ -1528,15 +1533,29 @@ def clear_meta(
for partition_id, group in groupby(combined, key=itemgetter(0)):
partition = self._get_partition(partition_id)
if not partition:
raise ValueError(f"Partition {partition_id} not found")
logger.info(
f"[{self.controller_id}]: Trying to clear data in a non-existent partition {partition_id}. "
f"Skipping operation for this partition."
)
continue

global_indexes_to_clear = [idx for _, idx in group]
if not set(global_indexes_to_clear).issubset(partition.global_indexes):
raise ValueError(
f"Some global_indexes to clear do not exist in partition {partition_id}. "
f"Target: {global_indexes_to_clear}, Existing: {partition.global_indexes}"
existing_global_indexes = partition.global_indexes
non_existent_global_indexes = set(global_indexes_to_clear) - existing_global_indexes
if non_existent_global_indexes:
logger.info(
f"[{self.controller_id}]: Some global_indexes to be cleared do not exist in "
f"partition {partition_id}: {non_existent_global_indexes}. They will be ignored."
)

global_indexes_to_clear = list(set(global_indexes_to_clear) & existing_global_indexes)
if not global_indexes_to_clear:
logger.info(
f"[{self.controller_id}]: No existing global indexes to clear in partition {partition_id}. "
f"Skipping operation for this partition."
)
continue

# Clear data from partition
partition.clear_data(global_indexes_to_clear, clear_consumption)

Expand All @@ -1551,6 +1570,7 @@ def kv_retrieve_meta(
) -> BatchMeta:
"""
Retrieve BatchMeta from the controller using a list of keys.
Non-existent keys are simply omitted from the returned BatchMeta.

Args:
keys: List of keys to retrieve from the controller
Expand All @@ -1567,7 +1587,9 @@ def kv_retrieve_meta(
partition = self._get_partition(partition_id)
if partition is None:
if not create:
logger.warning(f"Partition {partition_id} not found!")
logger.warning(
f"[{self.controller_id}]: Partition {partition_id} not found. Returning empty BatchMeta."
)
return BatchMeta.empty()

self.create_partition(partition_id)
Expand All @@ -1579,8 +1601,10 @@ def kv_retrieve_meta(
none_indexes = [idx for idx, value in enumerate(global_indexes) if value is None]
if len(none_indexes) > 0:
if not create:
logger.error(f"Keys {[keys[i] for i in none_indexes]} were not found in partition {partition_id}!")
return BatchMeta.empty()
logger.warning(
f"Keys {[keys[i] for i in none_indexes]} were not found in partition {partition_id}. "
f"They will be excluded from the retrieved BatchMeta."
)
else:
# create non-exist keys
batch_global_indexes = partition.activate_pre_allocated_indexes(len(none_indexes))
Expand All @@ -1603,7 +1627,6 @@ def kv_retrieve_meta(
partition.ensure_samples_capacity(max(batch_global_indexes) + 1)

verified_global_indexes = [idx for idx in global_indexes if idx is not None]
assert len(verified_global_indexes) == len(keys)

# must fetch fields that the requested samples all have
col_mask = partition.production_status[verified_global_indexes, :].sum(dim=0).reshape(-1) == len(
Expand Down
Loading