Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 298 additions & 0 deletions tests/test_controller_data_partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,3 +1392,301 @@ def test_fieldmeta_to_batch_schema_nested_missing_sample(self):
schema = field_meta.to_batch_schema([0, 1])

assert schema["per_sample_shapes"] == [(3,), None]


def _ready_schema():
return {"x": {"dtype": "torch.float32", "shape": (4,), "is_nested": False, "is_non_tensor": False}}


class TestStreamingDrain:
"""No-preset-global-batch streaming end-of-stream: production_completed +
actual_sample_count, exercised via DataPartitionStatus directly (the controller
insert path just accumulates these fields and stashes pending_last_indexes)."""

def _make_partition(self):
from transfer_queue.controller import DataPartitionStatus

return DataPartitionStatus(partition_id="train@stream_0")

def _produce(self, partition, indices):
partition.update_production_status(
global_indices=list(indices),
field_names=["x"],
field_schema=_ready_schema(),
)

def _simulate_insert(self, partition, indices, is_last=False):
# Mirror controller.get_metadata(mode="insert") accounting.
partition.global_indexes.update(indices)
partition.actual_sample_count += len(indices)
if is_last:
partition.pending_last_indexes.update(indices)
partition.has_pending_last = True
partition.pending_last_fields.update(["x"]) # producer field, see _produce

def test_not_drained_before_completed(self):
p = self._make_partition()
self._simulate_insert(p, [0, 1, 2])
self._produce(p, [0, 1, 2])
p.mark_consumed("actor_train", [0, 1, 2])
# No is_last yet -> not completed -> not drained even though all consumed.
assert p.production_completed is False
assert p.is_stream_drained("actor_train") is False

def test_completed_flips_only_after_last_batch_ready(self):
p = self._make_partition()
# First batch.
self._simulate_insert(p, [0, 1])
self._produce(p, [0, 1])
# Final batch announced at insert, but data not yet produced.
self._simulate_insert(p, [2, 3], is_last=True)
assert p.has_pending_last is True
assert p.production_completed is False # data of final batch not ready yet
# Producing the final batch flips the flag.
self._produce(p, [2, 3])
assert p.production_completed is True

def test_drained_requires_completed_and_all_consumed(self):
p = self._make_partition()
self._simulate_insert(p, [0, 1])
self._produce(p, [0, 1])
self._simulate_insert(p, [2, 3], is_last=True)
self._produce(p, [2, 3])
assert p.production_completed is True
# Partial consumption -> not drained.
p.mark_consumed("actor_train", [0, 1, 2])
assert p.is_stream_drained("actor_train") is False
# Full consumption -> drained.
p.mark_consumed("actor_train", [3])
assert p.is_stream_drained("actor_train") is True

def test_unactivated_prealloc_rows_do_not_block_drain(self):
from transfer_queue.controller import DataPartitionStatus

p = DataPartitionStatus(partition_id="train@stream_1")
# Pre-allocate extra rows that are never activated/inserted.
p.register_pre_allocated_indexes([0, 1, 2, 3, 4, 5, 6, 7])
# Only 4 samples are actually inserted+produced+consumed.
self._simulate_insert(p, [0, 1])
self._produce(p, [0, 1])
self._simulate_insert(p, [2, 3], is_last=True)
self._produce(p, [2, 3])
p.mark_consumed("actor_train", [0, 1, 2, 3])
# consumption tensor has 8 rows (4 of them never consumed) but drain only
# counts the actually-inserted samples.
assert p.actual_sample_count == 4
assert p.is_stream_drained("actor_train") is True

def test_non_contiguous_indices(self):
from transfer_queue.controller import DataPartitionStatus

p = DataPartitionStatus(partition_id="train@stream_2")
# Non-contiguous activated global indexes (drain must not assume [:N]).
self._simulate_insert(p, [5, 9])
self._produce(p, [5, 9])
self._simulate_insert(p, [11, 20], is_last=True)
self._produce(p, [11, 20])
assert p.production_completed is True
p.mark_consumed("actor_train", [5, 9, 11])
assert p.is_stream_drained("actor_train") is False
p.mark_consumed("actor_train", [20])
assert p.is_stream_drained("actor_train") is True

def test_drained_false_for_unknown_task(self):
p = self._make_partition()
self._simulate_insert(p, [0, 1], is_last=True)
self._produce(p, [0, 1])
assert p.production_completed is True
# A task that never consumed anything is not drained.
assert p.is_stream_drained("never_seen_task") is False


class TestStreamingDrainOutOfBounds:
"""Regression: an is_last batch announced at insert registers high
pending_last_indexes; an EARLIER batch's production notify must not index
production_status out of bounds in _maybe_mark_production_completed."""

def _schema(self):
return {"x": {"dtype": "torch.float32", "shape": (4,), "is_nested": False, "is_non_tensor": False}}

def test_pending_last_beyond_tensor_does_not_raise(self):
from transfer_queue.controller import DataPartitionStatus

p = DataPartitionStatus(partition_id="train@oob_0")
# Final batch announced at insert with HIGH indexes (e.g. 6,7), but the
# tensor has only been grown for the earlier batch (indexes 0,1).
p.global_indexes.update([6, 7])
p.actual_sample_count = 8
p.pending_last_indexes.update([6, 7])
p.has_pending_last = True
p.pending_last_fields.update(["x"])

# Earlier batch's notify: produces indexes 0,1 only. This calls
# _maybe_mark_production_completed internally; indexes 6,7 are beyond the
# tensor at this point — must NOT raise, must NOT mark completed.
ok = p.update_production_status(
global_indices=[0, 1],
field_names=["x"],
field_schema=self._schema(),
)
assert ok is True
assert p.production_completed is False

# Now the final batch's own notify lands (indexes 6,7) -> tensor grows ->
# completion flips True.
ok = p.update_production_status(
global_indices=[6, 7],
field_names=["x"],
field_schema=self._schema(),
)
assert ok is True
assert p.production_completed is True


class TestCheckProductionCompleted:
"""Producer-side admission gate: check_production_completed reflects
production_completed only (no consumption), unlike is_stream_drained."""

def _schema(self):
return {"x": {"dtype": "torch.float32", "shape": (4,), "is_nested": False, "is_non_tensor": False}}

def test_completed_independent_of_consumption(self):
from transfer_queue.controller import DataPartitionStatus

p = DataPartitionStatus(partition_id="train@gate_0")
# Not completed before is_last data lands.
p.global_indexes.update([0, 1])
p.actual_sample_count = 2
p.pending_last_indexes.update([0, 1])
p.has_pending_last = True
p.pending_last_fields.update(["x"])
assert p.production_completed is False

# Produce the final batch -> completed True, with ZERO consumption.
p.update_production_status([0, 1], ["x"], field_schema=self._schema())
assert p.production_completed is True
# is_stream_drained still False (nothing consumed) — the two gates differ.
assert p.is_stream_drained("actor_train") is False

def test_notify_before_is_last_flag_recheck_flips_completed(self):
"""Regression: the is_last batch's production notify can arrive BEFORE the
insert that sets has_pending_last (get_meta and put_data are separate RPCs).

Order: data is produced (update_production_status) while has_pending_last is
still False → that notify's completion check is a no-op. Then the is_last
insert sets the flag but, without a re-check, no further notify fires →
production_completed would never flip → drain deadlock. The fix re-runs the
completion check right after the insert sets the flag.
"""
from transfer_queue.controller import DataPartitionStatus

p = DataPartitionStatus(partition_id="train@race_0")

# 1) Producer's NOTIFY arrives first: data for the final batch is marked
# ready while has_pending_last is still False (no-op completion check).
p.update_production_status([0, 1], ["x"], field_schema=self._schema())
p.actual_sample_count = 2
assert p.production_completed is False # flag not set yet → not completed

# 2) The is_last insert now sets the flag. Simulate the controller's
# insert-path: set flag, then re-run the completion check (the fix).
p.global_indexes.update([0, 1])
p.pending_last_indexes.update([0, 1])
p.has_pending_last = True
p.pending_last_fields.update(["x"])
p._maybe_mark_production_completed() # re-check after flag set

# The data was already ready, so the re-check must flip completion True.
assert p.production_completed is True


class TestStreamingDrainBackfillFields:
"""Regression for the step-8 hang: downstream consumers (advantages, ref/
actor_fwd) backfill EXTRA fields into the same partition. production_completed
must only require the PRODUCER's fields on the is_last samples — not the
backfilled columns the producer never writes."""

def _producer_schema(self):
return {
"tokens": {"dtype": "torch.int32", "shape": (8,), "is_nested": False, "is_non_tensor": False},
"rewards": {"dtype": "torch.float32", "shape": (1,), "is_nested": False, "is_non_tensor": False},
}

def _adv_schema(self):
return {
"advantages": {"dtype": "torch.float32", "shape": (8,), "is_nested": False, "is_non_tensor": False},
"returns": {"dtype": "torch.float32", "shape": (8,), "is_nested": False, "is_non_tensor": False},
}

def test_backfilled_fields_do_not_block_completion(self):
from transfer_queue.controller import DataPartitionStatus

p = DataPartitionStatus(partition_id="train@bf_0")
# Producer inserts the (only, final) batch declaring its own fields.
p.global_indexes.update([0, 1])
p.actual_sample_count = 2
p.pending_last_indexes.update([0, 1])
p.has_pending_last = True
p.pending_last_fields.update(["tokens", "rewards"])

# Producer writes its fields -> completion should flip True even though
# downstream fields (advantages/returns) have NOT been written yet.
p.update_production_status([0, 1], ["tokens", "rewards"], field_schema=self._producer_schema())
assert p.production_completed is True

def test_completion_not_blocked_when_adv_backfill_grows_columns_first(self):
from transfer_queue.controller import DataPartitionStatus

p = DataPartitionStatus(partition_id="train@bf_1")
p.global_indexes.update([0, 1])
p.actual_sample_count = 2
p.pending_last_indexes.update([0, 1])
p.has_pending_last = True
p.pending_last_fields.update(["tokens", "rewards"])

# Producer writes its fields first -> completed True.
p.update_production_status([0, 1], ["tokens", "rewards"], field_schema=self._producer_schema())
assert p.production_completed is True

# A later advantages backfill adds advantages/returns columns. The producer
# samples are 0 on those new columns, but completion already (correctly)
# latched True and must stay True.
p.update_production_status([0, 1], ["advantages", "returns"], field_schema=self._adv_schema())
assert p.production_completed is True


class TestStreamDrainedConsumptionUndersized:
"""Regression: is_stream_drained must not index past a lazily-sized per-task
consumption tensor (crash: 'index 56 out of bounds for size 56'). The tensor
grows lazily; once production_completed there may be no further production
notify to expand it, so is_stream_drained must ensure capacity itself."""

def _schema(self):
return {"x": {"dtype": "torch.float32", "shape": (4,), "is_nested": False, "is_non_tensor": False}}

def test_drained_with_undersized_consumption_tensor(self):
from transfer_queue.controller import DataPartitionStatus

p = DataPartitionStatus(partition_id="train@undersize_0")
# Produce + complete a partition spanning indexes 0..7 (8 samples).
idxs = list(range(8))
p.global_indexes.update(idxs)
p.actual_sample_count = 8
p.pending_last_indexes.update(idxs)
p.has_pending_last = True
p.pending_last_fields.update(["x"])
p.update_production_status(idxs, ["x"], field_schema=self._schema())
assert p.production_completed is True

# Force the task's consumption tensor to be SMALLER than max(active)+1,
# mimicking a tensor that was sized before later samples were activated.
import torch

p.consumption_status["actor_train"] = torch.zeros(3, dtype=torch.int8)

# Must not raise (previously IndexError), and not be drained (nothing consumed).
assert p.is_stream_drained("actor_train") is False

# After consuming all, drained becomes True (capacity ensured internally).
p.mark_consumed("actor_train", idxs)
assert p.is_stream_drained("actor_train") is True
Loading
Loading