Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/e2e/test_kv_interface_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,29 @@ def test_kv_clear_multiple_keys(self, controller, tq_api):

tq_api.kv_clear(keys=keys[2:], partition_id=partition_id)

def test_kv_clear_does_not_leak_reused_index_across_partitions(self, controller, tq_api):
"""Clearing a key in one partition must not make reused indexes visible there again."""
field_name = "x"
p1 = "clear_reuse_partition_1"
p2 = "clear_reuse_partition_2"

tq_api.kv_put(key="a", partition_id=p1, fields={field_name: torch.tensor([1])})
tq_api.kv_clear(keys="a", partition_id=p1)

tq_api.kv_put(key="b", partition_id=p2, fields={field_name: torch.tensor([2])})

leaked_meta = tq.get_client().get_meta(
data_fields=[field_name],
batch_size=1,
partition_id=p1,
mode="fetch",
task_name="clear_reuse_after_other_partition_put",
)

assert leaked_meta.size == 0
assert leaked_meta.global_indexes == []
assert not leaked_meta.is_ready

def test_kv_clear_idempotent(self, controller, tq_api):
"""Test kv_clear is idempotent for non-existent keys and partitions."""
partition_id = "test_partition"
Expand Down
35 changes: 35 additions & 0 deletions tests/test_controller_data_partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,41 @@ def test_data_partition_status():
print("DataPartitionStatus tests passed!\n")


def test_data_partition_status_production_status_is_partition_local():
"""Regression test that partitions do not share production_status tensors."""
from transfer_queue.controller import DataPartitionStatus

partition = DataPartitionStatus(partition_id="partition_a")
other_partition = DataPartitionStatus(partition_id="partition_b")

assert partition.production_status.data_ptr() != other_partition.production_status.data_ptr()


def test_cleared_partition_does_not_observe_reused_index_from_other_partition():
"""A cleared partition must not become ready when another partition reuses its index."""
from transfer_queue.controller import DataPartitionStatus

field_schema = {
"x": {
"dtype": "torch.int64",
"shape": (1,),
"is_nested": False,
"is_non_tensor": False,
}
}

partition = DataPartitionStatus(partition_id="partition_a")
other_partition = DataPartitionStatus(partition_id="partition_b")

assert partition.update_production_status([0], ["x"], field_schema=field_schema)
partition.clear_data([0], clear_consumption=True)
assert partition.scan_data_status(["x"], task_name="consumer_before_reuse") == []

assert other_partition.update_production_status([0], ["x"], field_schema=field_schema)

assert partition.scan_data_status(["x"], task_name="consumer_after_reuse") == []


def test_partition_interface():
"""Test the partition interface design."""
print("Testing partition interface design...")
Expand Down
4 changes: 3 additions & 1 deletion transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,9 @@ class DataPartitionStatus:
# Values: 0 = not produced, 1 = ready for consumption
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
production_status: Tensor = field(
default_factory=lambda: torch.zeros(DataPartitionStatus.TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
)

# Consumption status per task - task_name -> consumption_tensor
# Each tensor tracks which samples have been consumed by that task
Expand Down
Loading