diff --git a/tests/e2e/test_kv_interface_e2e.py b/tests/e2e/test_kv_interface_e2e.py index e3e7ed2..a1a3976 100644 --- a/tests/e2e/test_kv_interface_e2e.py +++ b/tests/e2e/test_kv_interface_e2e.py @@ -946,6 +946,29 @@ def test_kv_clear_multiple_keys(self, controller, tq_api): tq_api.kv_clear(keys=keys[2:], partition_id=partition_id) + def test_kv_clear_does_not_leak_reused_index_across_partitions(self, controller, tq_api): + """Clearing a key in one partition must not make reused indexes visible there again.""" + field_name = "x" + p1 = "clear_reuse_partition_1" + p2 = "clear_reuse_partition_2" + + tq_api.kv_put(key="a", partition_id=p1, fields={field_name: torch.tensor([1])}) + tq_api.kv_clear(keys="a", partition_id=p1) + + tq_api.kv_put(key="b", partition_id=p2, fields={field_name: torch.tensor([2])}) + + leaked_meta = tq.get_client().get_meta( + data_fields=[field_name], + batch_size=1, + partition_id=p1, + mode="fetch", + task_name="clear_reuse_after_other_partition_put", + ) + + assert leaked_meta.size == 0 + assert leaked_meta.global_indexes == [] + assert not leaked_meta.is_ready + def test_kv_clear_idempotent(self, controller, tq_api): """Test kv_clear is idempotent for non-existent keys and partitions.""" partition_id = "test_partition" diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 8f14517..2f6b874 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -107,6 +107,41 @@ def test_data_partition_status(): print("DataPartitionStatus tests passed!\n") +def test_data_partition_status_production_status_is_partition_local(): + """Regression test that partitions do not share production_status tensors.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="partition_a") + other_partition = DataPartitionStatus(partition_id="partition_b") + + assert partition.production_status.data_ptr() != other_partition.production_status.data_ptr() + + +def test_cleared_partition_does_not_observe_reused_index_from_other_partition(): + """A cleared partition must not become ready when another partition reuses its index.""" + from transfer_queue.controller import DataPartitionStatus + + field_schema = { + "x": { + "dtype": "torch.int64", + "shape": (1,), + "is_nested": False, + "is_non_tensor": False, + } + } + + partition = DataPartitionStatus(partition_id="partition_a") + other_partition = DataPartitionStatus(partition_id="partition_b") + + assert partition.update_production_status([0], ["x"], field_schema=field_schema) + partition.clear_data([0], clear_consumption=True) + assert partition.scan_data_status(["x"], task_name="consumer_before_reuse") == [] + + assert other_partition.update_production_status([0], ["x"], field_schema=field_schema) + + assert partition.scan_data_status(["x"], task_name="consumer_after_reuse") == [] + + def test_partition_interface(): """Test the partition interface design.""" print("Testing partition interface design...") diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 407cb11..7c31522 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -332,7 +332,9 @@ class DataPartitionStatus: # Values: 0 = not produced, 1 = ready for consumption TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1)) - production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8) + production_status: Tensor = field( + default_factory=lambda: torch.zeros(DataPartitionStatus.TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8) + ) # Consumption status per task - task_name -> consumption_tensor # Each tensor tracks which samples have been consumed by that task