From f98d76aa89ed49714db3b1eef35d9d08a52a71f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=81=E6=9C=AC=E5=93=B2?= Date: Thu, 11 Jun 2026 09:56:32 +0000 Subject: [PATCH 1/2] [feat] Add StreamingTokenBudgetSampler for token-budget streaming fetch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce StreamingTokenBudgetSampler and wire up the controller, client, and storage managers to support a token-budget fetch mode for fully-async / dynamic-batch consumers: - get_metadata/get_meta accept token_budget (mutually exclusive with batch_size); the controller polls the streaming sampler instead of waiting for N ready samples. - user_custom_meta is written before samples are marked ready, so streaming consumers never observe a ready sample without its custom_meta. - async_put accepts inline custom_meta that lands atomically with readiness, avoiding the put/set_custom_meta round-trip and race. Signed-off-by: 宁本哲 --- tests/test_samplers.py | 522 +++++++++++++++ transfer_queue/__init__.py | 2 + transfer_queue/client.py | 30 +- transfer_queue/controller.py | 226 +++++-- transfer_queue/sampler/__init__.py | 10 +- .../sampler/streaming_token_budget_sampler.py | 600 ++++++++++++++++++ transfer_queue/storage/managers/base.py | 17 + .../managers/simple_backend_manager.py | 16 +- 8 files changed, 1361 insertions(+), 62 deletions(-) create mode 100644 transfer_queue/sampler/streaming_token_budget_sampler.py diff --git a/tests/test_samplers.py b/tests/test_samplers.py index c8e8f843..21782837 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -27,6 +27,7 @@ get_seqlen_balanced_partitions, ) from transfer_queue.sampler.sequential_sampler import SequentialSampler +from transfer_queue.sampler.streaming_token_budget_sampler import StreamingTokenBudgetSampler class TestBaseSampler: @@ -1130,6 +1131,527 @@ def test_identical_seqlens(self): assert sums[0] == sums[1] == 100 +class TestStreamingTokenBudgetSampler: + """Test 
cases for StreamingTokenBudgetSampler.""" + + class MockPartition: + """Minimal mock for DataPartitionStatus providing get_custom_meta.""" + + def __init__(self, custom_meta: dict[int, dict]): + self._custom_meta = custom_meta + + def get_custom_meta(self, global_indices: list[int]) -> dict[int, dict]: + return {idx: self._custom_meta.get(idx, {}) for idx in global_indices} + + @staticmethod + def _partition_with_uniform_lengths(indexes, length): + return TestStreamingTokenBudgetSampler.MockPartition({i: {"total_lengths": length} for i in indexes}) + + # ---- Initialization ---- + + def test_initialization_defaults(self): + sampler = StreamingTokenBudgetSampler() + assert isinstance(sampler, GRPOGroupNSampler) + assert sampler.n_samples_per_prompt == 1 + assert sampler.balance_unit_multiplier == 1 + assert sampler._buckets == {} + assert sampler._assigned_global == {} + assert sampler._resolved_lengths == {} + + def test_initialization_invalid_balance_unit_multiplier(self): + with pytest.raises(ValueError) as exc_info: + StreamingTokenBudgetSampler(balance_unit_multiplier=0) + assert "balance_unit_multiplier must be positive" in str(exc_info.value) + + with pytest.raises(ValueError): + StreamingTokenBudgetSampler(balance_unit_multiplier=-3) + + def test_initialization_invalid_n_samples_per_prompt(self): + # Inherited validation from GRPOGroupNSampler. 
+ with pytest.raises(ValueError) as exc_info: + StreamingTokenBudgetSampler(n_samples_per_prompt=0) + assert "must be positive" in str(exc_info.value) + + # ---- Fallback path (no token_budget) ---- + + def test_fallback_to_grpo_without_token_budget(self): + """Without token_budget, delegate to the inherited GRPO sample().""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=2) + ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] + + sampled, consumed = sampler.sample(ready_indexes, batch_size=4, task_name="ref", partition_id="p0") + + assert sampled == [0, 1, 2, 3] + assert consumed == [0, 1, 2, 3] + + def test_fallback_strips_streaming_only_kwargs(self): + """Streaming-only kwargs must not leak into the GRPO fallback call.""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=2) + ready_indexes = [0, 1, 2, 3] + + # dp_size / allow_underfill / partition are streaming extras; GRPO must + # not choke on them. (token_budget absent → fallback path.) + sampled, consumed = sampler.sample( + ready_indexes, + batch_size=2, + task_name="ref", + partition_id="p0", + dp_size=2, + allow_underfill=True, + partition=object(), + ) + + assert sampled == [0, 1] + assert consumed == [0, 1] + + # ---- Argument validation (token_budget path) ---- + + def test_requires_dp_rank_and_dp_size(self): + sampler = StreamingTokenBudgetSampler() + with pytest.raises(ValueError) as exc_info: + sampler.sample([0, 1], batch_size=0, token_budget=100, partition=object()) + assert "dp_rank" in str(exc_info.value) + + def test_requires_partition(self): + sampler = StreamingTokenBudgetSampler() + with pytest.raises(ValueError) as exc_info: + sampler.sample( + [0, 1, 2, 3], + batch_size=0, + token_budget=100, + dp_rank=0, + dp_size=1, + ) + assert "partition" in str(exc_info.value) + + # ---- Basic token-budget slicing (single DP) ---- + + def test_single_dp_packs_without_overshooting_budget(self): + """Single DP packs the largest prefix that does NOT overshoot the budget. 
+ + With samples of length 100 and budget 250, the third sample would push the + slice to 300 > 250, so the slice stops at 2 samples (200). Including a + sample that overshoots the budget risks an oversized micro-batch (OOM). + """ + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + ready_indexes = [0, 1, 2, 3] + partition = self._partition_with_uniform_lengths(ready_indexes, 100) + + sampled, consumed = sampler.sample( + ready_indexes, + batch_size=0, + task_name="actor", + partition_id="p0", + token_budget=250, + dp_rank=0, + dp_size=1, + batch_index=0, + partition=partition, + ) + + assert sampled == consumed + # 100 + 100 = 200 <= 250; adding a third (300) would overshoot. + assert sampled == [0, 1] + + def test_single_dp_exact_budget(self): + """When a prefix sums exactly to the budget, it is returned in full.""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + ready_indexes = [0, 1, 2, 3] + partition = self._partition_with_uniform_lengths(ready_indexes, 100) + + sampled, _ = sampler.sample( + ready_indexes, + batch_size=0, + task_name="actor", + partition_id="p0", + token_budget=200, # exactly two samples + dp_rank=0, + dp_size=1, + batch_index=0, + partition=partition, + ) + + assert sampled == [0, 1] + + def test_single_dp_oversized_sample_yields_at_least_one(self): + """A single sample exceeding the budget is still returned (progress).""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + ready_indexes = [0, 1, 2, 3] + partition = self._partition_with_uniform_lengths(ready_indexes, 1000) + + sampled, consumed = sampler.sample( + ready_indexes, + batch_size=0, + task_name="actor", + partition_id="p0", + token_budget=100, # smaller than a single sample + dp_rank=0, + dp_size=1, + batch_index=0, + partition=partition, + ) + + assert len(sampled) == 1 + assert sampled == consumed + + # ---- Cross-DP balancing ---- + + def test_cross_dp_token_balance(self): + """Two DPs at the same batch_index get token-balanced, 
disjoint slices.""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + ready_indexes = [0, 1, 2, 3] + # Two long, two short → balanced split should pair long+short per DP. + partition = self.MockPartition( + { + 0: {"total_lengths": 100}, + 1: {"total_lengths": 100}, + 2: {"total_lengths": 10}, + 3: {"total_lengths": 10}, + } + ) + common = dict( + task_name="actor", + partition_id="p0", + dp_size=2, + batch_index=0, + partition=partition, + token_budget=110, + ) + + sampled_0, consumed_0 = sampler.sample(ready_indexes, 0, dp_rank=0, **common) + sampled_1, consumed_1 = sampler.sample(ready_indexes, 0, dp_rank=1, **common) + + # Disjoint and fully covering. + assert set(sampled_0).isdisjoint(sampled_1) + assert set(sampled_0 + sampled_1) == {0, 1, 2, 3} + assert sampled_0 == consumed_0 + assert sampled_1 == consumed_1 + + lengths = {0: 100, 1: 100, 2: 10, 3: 10} + tok_0 = sum(lengths[i] for i in sampled_0) + tok_1 = sum(lengths[i] for i in sampled_1) + # Perfect balance: each DP gets one long + one short = 110. + assert tok_0 == tok_1 == 110 + + # ---- PP-stage cache (batch_index alignment) ---- + + def test_batch_index_cache_is_stable(self): + """Repeated requests for the same (dp_rank, batch_index) hit the cache.""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + ready_indexes = [0, 1, 2, 3] + partition = self._partition_with_uniform_lengths(ready_indexes, 50) + kwargs = dict( + task_name="actor", + partition_id="p0", + dp_rank=0, + dp_size=2, + batch_index=0, + partition=partition, + token_budget=50, + ) + + first, _ = sampler.sample(ready_indexes, 0, **kwargs) + # Second call with a DIFFERENT ready pool must still return the cached slice. 
+ second, _ = sampler.sample([10, 11, 12, 13], 0, **kwargs) + + assert first == second + assert first # non-empty + + def test_different_batch_index_advances_stream(self): + """Different batch_index consumes the next slice (no re-issue).""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + ready_all = [0, 1, 2, 3, 4, 5, 6, 7] + partition = self._partition_with_uniform_lengths(ready_all, 100) + base = dict( + task_name="actor", + partition_id="p0", + dp_rank=0, + dp_size=1, + partition=partition, + token_budget=100, + ) + + b0, _ = sampler.sample(ready_all, 0, batch_index=0, **base) + # Remaining ready pool excludes what batch 0 took. + remaining = [i for i in ready_all if i not in b0] + b1, _ = sampler.sample(remaining, 0, batch_index=1, **base) + + assert b0 + assert b1 + assert set(b0).isdisjoint(b1) + + # ---- End-of-stream tail flush ---- + + def test_tail_flush_on_production_done_drains_all(self): + """production_done releases a sub-balance_unit remainder into the buckets. + + A single batch_index returns one budget-sized micro-batch per DP, but the + tail flush guarantees EVERY ready sample is assigned to some DP bucket so + nothing is orphaned (which would livelock end-of-stream). We drain across + successive batch_index calls and assert full coverage with no overlap. + """ + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=2, balance_unit_multiplier=4) + # balance_unit = dp_size(2) * n(2) * mult(4) = 16, but only 4 ready. + all_indexes = [0, 1, 2, 3] + partition = self._partition_with_uniform_lengths(all_indexes, 10) + + # The controller removes consumed samples from the ready pool, so we model + # that by feeding only the not-yet-drained indexes on each new batch_index. 
+ drained: list[int] = [] + for batch_index in range(4): # more than enough to fully drain + ready_now = [i for i in all_indexes if i not in drained] + for dp_rank in (0, 1): + sampled, consumed = sampler.sample( + ready_now, + 0, + task_name="actor", + partition_id="p0", + dp_rank=dp_rank, + dp_size=2, + batch_index=batch_index, + partition=partition, + token_budget=10, + production_done=True, + ) + assert sampled == consumed + drained.extend(sampled) + + # Every produced sample drained exactly once across the stream. + assert sorted(drained) == [0, 1, 2, 3] + assert len(drained) == len(set(drained)) + + def test_tail_flush_assigns_all_to_buckets(self): + """First production_done call must assign all ready samples to buckets.""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=2, balance_unit_multiplier=4) + ready_indexes = [0, 1, 2, 3] + partition = self._partition_with_uniform_lengths(ready_indexes, 10) + + sampler.sample( + ready_indexes, + 0, + task_name="actor", + partition_id="p0", + dp_rank=0, + dp_size=2, + batch_index=0, + partition=partition, + token_budget=10, + production_done=True, + ) + + assigned = sampler._assigned_global[("p0", "actor")] + bucketed = set() + for bucket in sampler._buckets[("p0", "actor")].values(): + bucketed.update(bucket) + # Everything not yet popped is still tracked in assigned and lives in a bucket. + assert assigned == bucketed + # All four samples are accounted for (either popped this call or bucketed). + popped = set(range(4)) - assigned + assert (assigned | popped) == {0, 1, 2, 3} + + def test_no_complete_group_without_production_done_returns_empty(self): + """No complete GRPO group + not end-of-stream → return empty (wait for more).""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=2, balance_unit_multiplier=4) + # Non-consecutive indexes → GRPOGroupNSampler finds no complete group of 2. 
+ ready_indexes = [0, 2] + partition = self._partition_with_uniform_lengths(ready_indexes, 10) + + sampled, consumed = sampler.sample( + ready_indexes, + 0, + task_name="actor", + partition_id="p0", + dp_rank=0, + dp_size=2, + batch_index=0, + partition=partition, + token_budget=10, + production_done=False, + ) + + assert sampled == [] + assert consumed == [] + + def test_complete_group_drains_below_balance_unit(self): + """A single complete group is released even below balance_unit (trickle drain). + + With long responses the producer trickles one group at a time and the ready + pool stays below balance_unit; the sampler must still make progress rather + than wait for a full unit forever. + """ + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=2, balance_unit_multiplier=4) + # balance_unit = 16, but a single consecutive group [0,1] is ready. + ready_indexes = [0, 1] + partition = self._partition_with_uniform_lengths(ready_indexes, 10) + + got_any = False + for dp_rank in (0, 1): + sampled, _ = sampler.sample( + ready_indexes, + 0, + task_name="actor", + partition_id="p0", + dp_rank=dp_rank, + dp_size=2, + batch_index=0, + partition=partition, + token_budget=10, + production_done=False, + ) + got_any = got_any or bool(sampled) + + assert got_any, "a complete group below balance_unit should still be released" + + # ---- Missing total_lengths fallback ---- + + def test_missing_total_lengths_uses_budget_fallback(self): + """Samples lacking total_lengths fall back to token_budget (safe over-estimate).""" + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + ready_indexes = [0, 1, 2, 3] + # No custom_meta at all → every sample missing total_lengths. 
+ partition = self.MockPartition({}) + + sampled, consumed = sampler.sample( + ready_indexes, + 0, + task_name="actor", + partition_id="p0", + dp_rank=0, + dp_size=1, + batch_index=0, + partition=partition, + token_budget=500, + ) + + # With fallback length == budget, one sample already meets the budget. + assert len(sampled) == 1 + assert sampled == consumed + + # ---- Assignment / no double-issue ---- + + def test_bucketed_samples_not_reassigned(self): + """Samples already sitting in a bucket are filtered out of the ready pool. + + The sampler tracks bucketed-but-not-yet-popped samples in ``assigned`` and + excludes them from ``available_ready`` so a balance round never re-assigns a + sample that is already waiting in some DP's bucket. (Consumed samples are + filtered by the controller, not the sampler.) + """ + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + # dp_size=2, budget small → each balance round assigns 2 samples (one per + # DP) but each DP pops only its budget slice, leaving the rest bucketed. + ready_all = [0, 1, 2, 3] + partition = self._partition_with_uniform_lengths(ready_all, 100) + + # First call (batch 0) seeds buckets for both DPs. + sampler.sample( + ready_all, + 0, + task_name="actor", + partition_id="p0", + dp_rank=0, + dp_size=2, + batch_index=0, + partition=partition, + token_budget=100, + ) + sampler.sample( + ready_all, + 0, + task_name="actor", + partition_id="p0", + dp_rank=1, + batch_index=0, + dp_size=2, + partition=partition, + token_budget=100, + ) + + assigned = sampler._assigned_global[("p0", "actor")] + bucketed = set() + for bucket in sampler._buckets[("p0", "actor")].values(): + bucketed.update(bucket) + # Invariant: assigned set == union of bucket contents (no leak, no double-count). + assert assigned == bucketed + # No index appears in more than one DP bucket. 
+ all_bucket_items = [i for b in sampler._buckets[("p0", "actor")].values() for i in b] + assert len(all_bucket_items) == len(set(all_bucket_items)) + + # ---- clear_cache ---- + + def test_clear_cache_removes_partition_state(self): + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + ready_indexes = [0, 1, 2, 3] + partition = self._partition_with_uniform_lengths(ready_indexes, 50) + + sampler.sample( + ready_indexes, + 0, + task_name="actor", + partition_id="p0", + dp_rank=0, + dp_size=1, + batch_index=0, + partition=partition, + token_budget=50, + ) + + # State should now exist for p0. + assert any(k[0] == "p0" for k in sampler._buckets) + assert "p0" in sampler._states + + sampler.clear_cache("p0") + + assert all(k[0] != "p0" for k in sampler._buckets) + assert all(k[0] != "p0" for k in sampler._assigned_global) + assert all(k[0] != "p0" for k in sampler._resolved_lengths) + assert "p0" not in sampler._states + # Internal GRPO scratch state used by balance rounds must be cleared too. 
+ assert "__streaming_internal__" not in sampler._states + + def test_clear_cache_only_affects_target_partition(self): + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + for pid in ("p0", "p1"): + partition = self._partition_with_uniform_lengths([0, 1, 2, 3], 50) + sampler.sample( + [0, 1, 2, 3], + 0, + task_name="actor", + partition_id=pid, + dp_rank=0, + dp_size=1, + batch_index=0, + partition=partition, + token_budget=50, + ) + + sampler.clear_cache("p0") + + assert all(k[0] != "p0" for k in sampler._buckets) + assert any(k[0] == "p1" for k in sampler._buckets) + + # ---- Empty input ---- + + def test_empty_ready_indexes_returns_empty(self): + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + partition = self.MockPartition({}) + + sampled, consumed = sampler.sample( + [], + 0, + task_name="actor", + partition_id="p0", + dp_rank=0, + dp_size=1, + batch_index=0, + partition=partition, + token_budget=100, + ) + + assert sampled == [] + assert consumed == [] + + class TestSamplerIntegration: """Integration tests for samplers.""" diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index c74a3b87..b5c38cd6 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -40,6 +40,7 @@ from .sampler.rank_aware_sampler import RankAwareSampler from .sampler.seqlen_balanced_sampler import SeqlenBalancedSampler from .sampler.sequential_sampler import SequentialSampler +from .sampler.streaming_token_budget_sampler import StreamingTokenBudgetSampler __all__ = ( [ @@ -78,6 +79,7 @@ "SequentialSampler", "RankAwareSampler", "SeqlenBalancedSampler", + "StreamingTokenBudgetSampler", ] ) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 79b5c5d2..b8e32f2b 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -159,11 +159,12 @@ async def wrapper(self, *args, **kwargs): async def async_get_meta( self, data_fields: list[str], - batch_size: int, - partition_id: str, + batch_size: 
Optional[int] = None, + partition_id: str = "", mode: str = "fetch", task_name: Optional[str] = None, sampling_config: Optional[dict[str, Any]] = None, + token_budget: Optional[int] = None, socket: Optional[zmq.asyncio.Socket] = None, ) -> BatchMeta: """Asynchronously fetch data metadata from the controller via ZMQ. @@ -227,6 +228,7 @@ async def async_get_meta( "mode": mode, "task_name": task_name, "sampling_config": sampling_config, + "token_budget": token_budget, }, ) @@ -324,6 +326,7 @@ async def async_put( data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None, + custom_meta: Optional[list[dict[str, Any]]] = None, ) -> BatchMeta: """Asynchronously write data to storage units based on metadata. @@ -333,6 +336,13 @@ async def async_put( During put, the custom_meta in metadata will update the corresponding custom_meta in TransferQueue Controller. + If ``custom_meta`` is provided (a per-sample list aligned with ``data`` rows), + it is attached to the metadata BEFORE the data is written, so it rides the same + readiness notification and lands atomically with the samples becoming consumable. + This avoids a separate ``async_set_custom_meta`` round-trip and the associated + race where a streaming consumer fetches a ready sample before its custom_meta + has been set. + Note: When using multiple workers for distributed execution, there may be data ordering inconsistencies between workers during put operations. @@ -408,12 +418,20 @@ async def async_put( if not metadata or metadata.size == 0: raise ValueError("metadata cannot be none or empty") + # Attach inline custom_meta before the write so it rides put_data's readiness + # notification and lands atomically with the samples becoming consumable. 
+ if custom_meta is not None: + metadata.update_custom_meta(custom_meta) + with limit_pytorch_auto_parallel_threads( target_num_threads=TQ_NUM_THREADS, info=f"[{self.client_id}] async_put" ): await self.storage_manager.put_data(data, metadata) - await self.async_set_custom_meta(metadata) + # Inline custom_meta is already delivered atomically via put_data → notify; + # only fall back to the separate RPC when it was set out-of-band on metadata. + if custom_meta is None: + await self.async_set_custom_meta(metadata) logger.debug( f"[{self.client_id}]: partition {partition_id} put {metadata.size} samples to storage units successfully." @@ -1189,11 +1207,12 @@ def wrapper(*args, **kwargs): def get_meta( self, data_fields: list[str], - batch_size: int, - partition_id: str, + batch_size: Optional[int] = None, + partition_id: str = "", mode: str = "fetch", task_name: Optional[str] = None, sampling_config: Optional[dict[str, Any]] = None, + token_budget: Optional[int] = None, ) -> BatchMeta: """Synchronously fetch data metadata from the controller via ZMQ. @@ -1252,6 +1271,7 @@ def get_meta( mode=mode, task_name=task_name, sampling_config=sampling_config, + token_budget=token_budget, ) def set_custom_meta(self, metadata: BatchMeta) -> None: diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 90304f7c..dd6b62d7 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -58,7 +58,7 @@ logger.addHandler(handler) TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 1)) -TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 5)) +TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = float(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 0.5)) # Sample pre-allocation for StreamingDataLoader compatibility. 
# By pre-allocating sample indices (typically global_batch_size), consumers can accurately @@ -506,6 +506,7 @@ def update_production_status( field_names: list[str], field_schema: dict[str, dict[str, Any]], custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + user_custom_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> bool: """ Update production status for specific samples and fields. @@ -521,6 +522,11 @@ def update_production_status( field_schema: Columnar field schema {field_name: {dtype, shape, is_nested, ...}} custom_backend_meta: Optional per-sample per-field custom metadata provided by storage backend + user_custom_meta: Optional user-defined per-sample custom_meta in + {global_index: {...}} format. When provided, it is written + BEFORE production_status is flipped to ready, so any sample a + sampler observes as ready is guaranteed to carry its custom_meta + (eliminates the put/set_custom_meta race for streaming consumers). Returns: True if update was successful, False on error @@ -548,6 +554,11 @@ def update_production_status( with self.data_status_lock: self.ensure_fields_capacity(required_fields) + # Write user custom_meta BEFORE marking samples ready, so a sampler + # that observes production_status==1 always sees the custom_meta too. + if user_custom_meta: + self.set_custom_meta(user_custom_meta) + with self.data_status_lock: # Update production status if self.production_status is not None and global_indices and field_names: @@ -1128,6 +1139,7 @@ def update_production_status( global_indexes: list[int], field_schema: dict[str, dict[str, Any]], custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + user_custom_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> bool: """ Update production status for specific samples and fields in a partition. 
@@ -1138,6 +1150,8 @@ def update_production_status( global_indexes: List of sample indices to update field_schema: Columnar field schema {field_name: {dtype, shape, is_nested, ...}} custom_backend_meta: Optional custom backend metadata + user_custom_meta: Optional user-defined per-sample custom_meta in + {global_index: {...}} format, written atomically before ready. Returns: True if update was successful, False otherwise @@ -1148,7 +1162,9 @@ def update_production_status( logger.error(f"Partition {partition_id} not found") return False - success = partition.update_production_status(global_indexes, field_names, field_schema, custom_backend_meta) + success = partition.update_production_status( + global_indexes, field_names, field_schema, custom_backend_meta, user_custom_meta + ) if success: logger.debug( f"[{self.controller_id}]: Updated production status for partition {partition_id}: " @@ -1246,6 +1262,7 @@ def get_metadata( task_name: str | None = None, batch_size: int | None = None, sampling_config: Optional[dict[str, Any]] = None, + token_budget: int | None = None, *args, **kwargs, ) -> BatchMeta: @@ -1261,7 +1278,15 @@ def get_metadata( - mode="force_fetch": Get metadata for unconsumed samples without sampling (excludes already consumed samples) task_name: Name of the consumer task (required for fetch modes) - batch_size: Number of samples to retrieve + batch_size: Number of samples to retrieve (sample-count fetch mode). + Mutually exclusive with ``token_budget``. + sampling_config: Sampler-specific kwargs (e.g. dp_rank, dp_size). + token_budget: If provided, switch to token-budget fetch mode. The + configured sampler (must be a streaming token-aware sampler such + as :class:`StreamingTokenBudgetSampler`) returns a slice whose + accumulated ``total_lengths`` reaches ``token_budget``. The + sampler decides readiness, so the controller skips the + "wait until N ready" loop and asks the sampler directly. 
*args: Additional positional arguments **kwargs: Additional keyword arguments @@ -1304,64 +1329,153 @@ def get_metadata( assert task_name is not None # Find ready samples within current data partition and package into BatchMeta when reading - if batch_size is None: - raise ValueError("must provide batch_size in fetch mode") + if token_budget is not None: + # Token-budget fetch path: the sampler is stateful (streaming) and + # decides readiness internally, so we skip the "wait for N ready" + # gate and instead poll the sampler with the current ready pool. + # If the sampler returns empty AND the partition has no unconsumed + # samples left, the rollout for this DP is done — return empty meta. + if batch_size is not None: + raise ValueError("token_budget and batch_size are mutually exclusive") + + sampling_config = sampling_config or {} + + # Polling semantics: return empty meta quickly if the sampler + # can't produce data right now. The request thread is single- + # threaded for all GET_META; holding it here would deadlock + # other consumers (e.g. reference / advantages) that need to + # push data into TQ before this actor can drain anything. + # The actor-side drain loop uses check_consumption_status to + # know when the partition is truly done. + start_time = time.time() + while True: + partition = self._get_partition(partition_id) + if partition is None: + # Producer hasn't created this partition yet + # (race: consumer called before any insert). Wait. + batch_global_indexes = [] + consumed_indexes: list[int] = [] + else: + ready_for_consume_indexes = self.scan_data_status(partition_id, data_fields, task_name) + + # Both task_name and partition_id arrive via the + # sampling_config splat — the relax client composes them + # into the sampling_config dict before calling get_meta + # (see relax/utils/data/stream_dataloader.py + # get_data_from_transfer_queue: config = {**sampling_config, + # "batch_index": batch_index, "partition_id": partition_id}). 
+ # Passing them again here causes a "multiple values for + # keyword argument" TypeError. + # production_done: every pre-allocated sample of this + # partition has been produced for the requested fields, + # so no more data is coming. The streaming sampler uses + # this to trigger its end-of-stream tail flush (dump all + # remaining ready samples even if they can't form a full + # balance unit), preventing a tail livelock under DP>1. + production_done = False + try: + _, prod = partition.get_production_status_for_fields(data_fields, mask=True) + if prod is not None and prod.numel() > 0: + production_done = bool((prod == 1).all().item()) + except Exception: + production_done = False + + batch_global_indexes, consumed_indexes = self.sampler( + ready_for_consume_indexes, + 0, # batch_size unused in token-budget mode + partition=partition, + token_budget=token_budget, + production_done=production_done, + **sampling_config, + **kwargs, + ) - start_time = time.time() - while True: - # ready_for_consume_indexes: samples where all required fields are produced - # (production status is ready) and not yet consumed - ready_for_consume_indexes = self.scan_data_status(partition_id, data_fields, task_name) + if batch_global_indexes: + break - if len(ready_for_consume_indexes) < batch_size: - if self.polling_mode: - # Return cached result if available - if self.sampler.has_cached_result(partition_id, task_name, sampling_config): - break + # Sampler returned nothing. Two possibilities: + # 1) Partition still has unconsumed samples but not enough for a + # balance round → wait for more production. + # 2) Partition is fully consumed (production complete AND every + # allocated sample marked consumed) → signal end-of-stream. 
+ if partition is not None: + _, cons = partition.get_consumption_status(task_name, mask=True) + if ( + partition.production_status is not None + and cons.numel() > 0 + and bool((cons == 1).all().item()) + ): + # True end-of-stream: everything produced has been + # consumed. Signal drain completion. + return BatchMeta.empty() + + # Bounded wait: short backoff inside the handler so other + # consumers' GET_META requests aren't starved. When the + # wait exceeds TQ_CONTROLLER_GET_METADATA_TIMEOUT, return + # empty meta and let the actor-side drain loop retry. + if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT: + return BatchMeta.empty() + time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL) + else: + if batch_size is None: + raise ValueError("must provide batch_size or token_budget in fetch mode") + + start_time = time.time() + while True: + # ready_for_consume_indexes: samples where all required fields are produced + # (production status is ready) and not yet consumed + ready_for_consume_indexes = self.scan_data_status(partition_id, data_fields, task_name) + + if len(ready_for_consume_indexes) < batch_size: + if self.polling_mode: + # Return cached result if available + if self.sampler.has_cached_result(partition_id, task_name, sampling_config): + break + else: + logger.debug( + f"[{self.controller_id}]: Not enough data for task {task_name} in " + f"partition {partition_id}. Required: {batch_size}, " + f"Available: {len(ready_for_consume_indexes)}." + f" Returning None due to polling mode." + ) + return BatchMeta.empty() else: - logger.debug( - f"[{self.controller_id}]: Not enough data for task {task_name} in " - f"partition {partition_id}. Required: {batch_size}, " - f"Available: {len(ready_for_consume_indexes)}." - f" Returning None due to polling mode." + logger.warning( + f"[{self.controller_id}]: Insufficient data for task {task_name}. 
" + f"Required: {batch_size} " + f"samples with fields {data_fields} in partition {partition_id}, but only have " + f"{len(ready_for_consume_indexes)} samples meeting the criteria. " + f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..." + ) + time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL) + if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT: + raise TimeoutError( + f"Timeout while waiting for sufficient data for task {task_name}. " + f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}" ) - return BatchMeta.empty() else: - logger.warning( - f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} " - f"samples with fields {data_fields} in partition {partition_id}, but only have " - f"{len(ready_for_consume_indexes)} samples meeting the criteria. " - f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..." - ) - time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL) - if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT: - raise TimeoutError( - f"Timeout while waiting for sufficient data for task {task_name}. " - f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}" - ) - else: - break - - batch_global_indexes, consumed_indexes = self.sampler( - ready_for_consume_indexes, - batch_size, - partition=self._get_partition(partition_id), - **(sampling_config or {}), - **kwargs, - ) - - # Check if we got valid results from the sampler. - # Some samplers (e.g. SeqlenBalancedSampler) may return variable-size - # batches per DP rank, so we only check for empty results. - if len(batch_global_indexes) == 0: - if self.polling_mode: - return BatchMeta.empty() - raise RuntimeError( - f"Sampler returned no samples. Please check the sampler logic. 
" - f"Expected: {batch_size}, before sampling: {len(ready_for_consume_indexes)}, " - f"after sampling: {len(batch_global_indexes)}" + break + + batch_global_indexes, consumed_indexes = self.sampler( + ready_for_consume_indexes, + batch_size, + partition=self._get_partition(partition_id), + **(sampling_config or {}), + **kwargs, ) + # Check if we got valid results from the sampler. + # Some samplers (e.g. SeqlenBalancedSampler) may return variable-size + # batches per DP rank, so we only check for empty results. + if len(batch_global_indexes) == 0: + if self.polling_mode: + return BatchMeta.empty() + raise RuntimeError( + f"Sampler returned no samples. Please check the sampler logic. " + f"Expected: {batch_size}, before sampling: {len(ready_for_consume_indexes)}, " + f"after sampling: {len(batch_global_indexes)}" + ) + # Mark samples as consumed if in fetch mode if consumed_indexes: partition = self.partitions[partition_id] @@ -1835,11 +1949,12 @@ def _process_request(self): metadata = self.get_metadata( data_fields=params["data_fields"], - batch_size=params["batch_size"], + batch_size=params.get("batch_size"), partition_id=params["partition_id"], mode=params.get("mode", "fetch"), task_name=params.get("task_name"), sampling_config=params.get("sampling_config", {}), + token_budget=params.get("token_budget"), ) response_msg = ZMQMessage.create( @@ -2083,6 +2198,7 @@ def _update_data_status(self): global_indexes=message_data.get("global_indexes", []), field_schema=message_data.get("field_schema", {}), custom_backend_meta=message_data.get("custom_backend_meta", {}), + user_custom_meta=message_data.get("user_custom_meta", None), ) if success: diff --git a/transfer_queue/sampler/__init__.py b/transfer_queue/sampler/__init__.py index 302ed3f9..ed6e13ab 100644 --- a/transfer_queue/sampler/__init__.py +++ b/transfer_queue/sampler/__init__.py @@ -18,5 +18,13 @@ from .rank_aware_sampler import RankAwareSampler from .seqlen_balanced_sampler import SeqlenBalancedSampler from 
.sequential_sampler import SequentialSampler +from .streaming_token_budget_sampler import StreamingTokenBudgetSampler -__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler", "SeqlenBalancedSampler"] +__all__ = [ + "BaseSampler", + "SequentialSampler", + "GRPOGroupNSampler", + "RankAwareSampler", + "SeqlenBalancedSampler", + "StreamingTokenBudgetSampler", +] diff --git a/transfer_queue/sampler/streaming_token_budget_sampler.py b/transfer_queue/sampler/streaming_token_budget_sampler.py new file mode 100644 index 00000000..9ea827e5 --- /dev/null +++ b/transfer_queue/sampler/streaming_token_budget_sampler.py @@ -0,0 +1,600 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any + +from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler +from transfer_queue.sampler.seqlen_balanced_sampler import get_seqlen_balanced_partitions + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + + +class StreamingTokenBudgetSampler(GRPOGroupNSampler): + """Streaming sampler that returns samples up to a token budget per call. + + Designed for fully-async + dynamic-batch consumers that pull data from + TransferQueue as a stream rather than waiting for the entire rollout. 
+ + Semantics + --------- + Each call to :meth:`sample` is parameterised (via ``sampling_config`` / + kwargs) by ``dp_rank`` and ``token_budget``. The sampler returns a + contiguous slice of GRPO-complete prompt groups assigned to this + ``dp_rank`` whose accumulated ``total_lengths`` reaches ``token_budget``. + + Internally the sampler maintains per-DP buckets of "assigned but not yet + consumed" indexes. When a bucket cannot satisfy the budget, the sampler + pulls the next ``balance_unit`` (= ``dp_size * n_samples_per_prompt * k``) + GRPO-complete groups out of the currently-ready pool, runs a token-balanced + partition across ``dp_size`` DP buckets, and tries again. Inside one + balance unit, token totals are balanced across DPs; across units totals + may differ (acceptable carry-over). + + When the ready pool cannot supply another full balance unit AND + ``allow_underfill=True``, the sampler returns whatever remains in this + DP's bucket (possibly underfill, possibly empty). + + Required ``custom_meta``: each sample must have + ``{"total_lengths": <int>}`` populated by the producer at insert time. + + Required ``sampling_config`` keys: + - ``dp_rank``: int, DP rank of the caller. + - ``dp_size``: int, total DP world size. + - ``token_budget``: int, target accumulated token count for this fetch. + - ``allow_underfill``: bool, default True. If False, return ``([], [])`` + when budget cannot be reached (controller will retry). + + The ``partition`` kwarg (``DataPartitionStatus``) must be passed by the + controller; the sampler uses ``partition.get_custom_meta`` to read + ``total_lengths``. + + The ``batch_size`` argument is ignored (the budget governs slice size). + """ + + def __init__( + self, + n_samples_per_prompt: int = 1, + balance_unit_multiplier: int = 1, + ): + """Create a streaming token-budget sampler. + + Args: + n_samples_per_prompt: GRPO group size. Must be > 0.
+ balance_unit_multiplier: A balance unit pulls + ``balance_unit_multiplier * dp_size * n_samples_per_prompt`` + samples from the ready pool at a time. Larger values give + better token balance but require waiting for more samples to + be ready before any DP can progress. Default 1 = minimum unit. + """ + super().__init__(n_samples_per_prompt=n_samples_per_prompt) + if balance_unit_multiplier <= 0: + raise ValueError(f"balance_unit_multiplier must be positive, got {balance_unit_multiplier}") + self.balance_unit_multiplier = balance_unit_multiplier + + # Per (partition_id, task_name) state. + # _buckets[(pid, tn)][dp_rank] -> list[int] of indexes assigned to this DP + # but not yet returned to the caller. + self._buckets: dict[tuple[str, str], dict[int, list[int]]] = {} + # _assigned_global[(pid, tn)] -> set[int] of indexes currently in any bucket + # (used to filter ready_indexes coming from the controller, since they + # are not yet marked consumed until the caller actually fetches them). + self._assigned_global: dict[tuple[str, str], set[int]] = {} + # _resolved_lengths[(pid, tn)] -> dict[idx, int] of total_lengths used + # for token-budget accounting. Populated in _run_balance_round from + # custom_meta (with fallback to the token budget when missing). Reused + # by _select_up_to_budget so a single sample sees a stable length + # across calls even if custom_meta later changes. + self._resolved_lengths: dict[tuple[str, str], dict[int, int]] = {} + + def sample( + self, + ready_indexes: list[int], + batch_size: int, + task_name: str = "", + partition_id: str = "", + *args: Any, + **kwargs: Any, + ) -> tuple[list[int], list[int]]: + """Return up to ``token_budget`` worth of samples assigned to ``dp_rank``. + + See class docstring for the streaming semantics. + + Fallback behaviour: when called WITHOUT ``token_budget`` (e.g.
by + :func:`compute_ref_log_prob` / :func:`compute_actor_log_prob` which + use the legacy sample-count fetch), delegate to the inherited + :class:`GRPOGroupNSampler.sample` so those consumers keep working + unchanged on the same controller. + """ + token_budget = kwargs.get("token_budget", None) + if token_budget is None: + # Strip kwargs that GRPO doesn't expect (extras supplied for the + # streaming path by the controller / relax client). + grpo_kwargs = { + k: v for k, v in kwargs.items() if k not in ("token_budget", "dp_size", "allow_underfill", "partition") + } + return super().sample( + ready_indexes, + batch_size, + task_name=task_name, + partition_id=partition_id, + **grpo_kwargs, + ) + + dp_rank = kwargs.get("dp_rank", None) + dp_size = kwargs.get("dp_size", None) + allow_underfill = kwargs.get("allow_underfill", True) + partition = kwargs.get("partition", None) + batch_index = kwargs.get("batch_index", None) + # production_done: the controller tells us when EVERY pre-allocated + # sample of this partition has been produced (no more data is coming). + # At that point the tail-flush may dump all remaining ready samples. + production_done = bool(kwargs.get("production_done", False)) + + if dp_rank is None or dp_size is None: + raise ValueError( + "StreamingTokenBudgetSampler requires dp_rank and dp_size in sampling_config " + "when token_budget is provided" + ) + + # PP-stage cache: when multiple PP stages request the same + # (partition_id, task_name, dp_rank, batch_index), return the + # cached result from the first call so all stages see identical data. 
+ if batch_index is not None: + cached = self._states.get(partition_id, {}).get(task_name, {}).get(dp_rank, {}).get(batch_index, None) + if cached is not None: + logger.debug( + "[stream-sampler] cache HIT: task=%s pid=%s dp=%s batch_idx=%s", + task_name, + partition_id, + dp_rank, + batch_index, + ) + return cached + + if partition is None: + raise ValueError("StreamingTokenBudgetSampler requires partition kwarg from the controller") + + # batch_index is the alignment key. When it is present (always true on + # the streaming train path), the FIRST request for a given batch_index + # (from any dp_rank / PP stage) atomically prepares the micro-batch + # slices for ALL dp_ranks against a single ``available_ready`` snapshot + # and caches every dp's result. Every subsequent request — other PP + # stages of this dp, or other dp_ranks — hits the cache and gets the + # identical, pre-determined data. This keeps all dp_ranks in lockstep + # by batch_index and removes the order-dependent state mutation that + # broke PP>1 (orphaned-in-``assigned`` samples → under-consume/deadlock). + if batch_index is not None: + self._prepare_batch_index( + partition_id, + task_name, + batch_index, + dp_size, + token_budget, + allow_underfill, + ready_indexes, + partition, + production_done, + ) + return self._states.get(partition_id, {}).get(task_name, {}).get(dp_rank, {}).get(batch_index, ([], [])) + + # Fallback: no batch_index (should not happen on the streaming path) — + # serve this single dp_rank immediately from shared state. 
+ return self._extract_one_dp( + partition_id, + task_name, + dp_rank, + dp_size, + token_budget, + allow_underfill, + ready_indexes, + partition, + batch_index, + ) + + def _prepare_batch_index( + self, + partition_id: str, + task_name: str, + batch_index: int, + dp_size: int, + token_budget: int, + allow_underfill: bool, + ready_indexes: list[int], + partition, + production_done: bool = False, + ) -> None: + """Atomically prepare and cache one micro-batch slice for every dp_rank + at ``batch_index``. + + Empty slices are NOT cached: if no data can be served for this + batch_index yet (rollout still producing), we leave the cache untouched + so the next poll re-evaluates against freshly produced samples. Once a + round yields real data, all participating dp_ranks' (non-empty) slices + are cached together against a single ``available_ready`` snapshot, so + every PP stage / dp request for this batch_index becomes a pure cache + read with identical, pre-determined data. + + Early-return when (dp_rank=0, batch_index) already cached: a real round + was prepared before; all dp entries were written together. + """ + already = self._states.get(partition_id, {}).get(task_name, {}).get(0, {}).get(batch_index, None) + if already is not None: + return + + key = (partition_id, task_name) + buckets = self._buckets.setdefault(key, {}) + assigned = self._assigned_global.setdefault(key, set()) + resolved_lengths = self._resolved_lengths.setdefault(key, {}) + + # Single shared snapshot for the whole batch_index across all dp_ranks. 
+ available_ready = [i for i in ready_indexes if i not in assigned] + balance_unit = dp_size * self.n_samples_per_prompt * self.balance_unit_multiplier + self._token_budget_for_fallback = token_budget + + def _bucket_tokens(dp_i: int) -> int: + return sum(resolved_lengths.get(i, 0) for i in buckets.get(dp_i, [])) + + is_eos = production_done + + logger.debug( + "[stream-sampler] prepare batch_idx=%s task=%s pid=%s dp_size=%d budget=%d ready=%d avail=%d eos=%s", + batch_index, + task_name, + partition_id, + dp_size, + token_budget, + len(ready_indexes), + len(available_ready), + is_eos, + ) + + # ── Phase 1: balance rounds ────────────────────────────────────── + # Pull balance rounds (each token-balances a chunk of complete GRPO + # groups across ALL dp buckets) until every dp bucket holds a + # token-budget worth or the ready pool runs dry. + # + # Round size: ideally a full ``balance_unit`` for best token balance, + # but we must NOT require a full balance_unit to be ready before making + # progress. With long responses the producer trickles ~1 group at a + # time and the slow per-mb consumer keeps ``ready`` pinned below + # balance_unit (observed: stuck at ready=8 < balance_unit=16 forever). + # So the minimum round is ONE complete group (``n_samples_per_prompt``), + # split per-sample across dps. This drains the trickle; token balance + # is slightly worse on small rounds, which the dummy-pad tolerates. + group = self.n_samples_per_prompt + max_rounds = len(available_ready) // max(group, 1) + 2 + rounds = 0 + while len(available_ready) >= group and rounds < max_rounds: + if all(_bucket_tokens(dp_i) >= token_budget for dp_i in range(dp_size)): + break + # Round size = as many full groups as are ready, capped at balance_unit. 
+ round_size = min(balance_unit, (len(available_ready) // group) * group) + prev_avail = len(available_ready) + if not self._run_balance_round( + available_ready, round_size, dp_size, partition, buckets, assigned, resolved_lengths + ): + break + available_ready = [i for i in available_ready if i not in assigned] + rounds += 1 + if len(available_ready) >= prev_avail: + break # no progress — avoid spinning + + # ── Phase 2: unconditional end-of-stream tail flush ────────────── + # Once the producer is DONE for this partition (is_eos) and a sub- + # balance_unit remainder is still sitting in the ready pool, distribute + # EVERY remaining ready sample across the dp buckets — bypassing the + # GRPO-group and balance_unit constraints entirely. This guarantees no + # produced sample is ever orphaned (which would keep all_consumed False + # forever → livelock). GRPO group integrity is not needed here: + # advantages are precomputed per-sample upstream. Cross-dp imbalance + # from this uneven flush is corrected by the iterator's dummy-pad. + if is_eos and available_ready: + self._flush_tail(available_ready, dp_size, partition, buckets, assigned, resolved_lengths) + available_ready = [i for i in available_ready if i not in assigned] + + # ── Phase 3: slice one mb per dp and commit independently ──────── + # Each dp pops up to a token-budget slice from its own bucket and + # caches it independently. We do NOT require all dps to be non-empty + # (no all-or-nothing rollback — that livelocked at the tail): cross-dp + # mb-count differences are handled by the iterator's dummy-pad barrier. 
+ for dp_i in range(dp_size): + bucket = buckets.setdefault(dp_i, []) + if not bucket: + continue + sel_count = self._select_up_to_budget(bucket, resolved_lengths, token_budget) + sel_count = max(sel_count, 1) # always make progress + result = self._pop_and_return(bucket, sel_count, assigned) + if result[0]: + self._cache_result(partition_id, task_name, dp_i, batch_index, result) + + logger.debug( + "[stream-sampler] batch_idx=%s prepared: per-dp cached sizes=%s remaining_avail=%d eos=%s", + batch_index, + [ + len(self._states.get(partition_id, {}).get(task_name, {}).get(dp_i, {}).get(batch_index, ([],))[0]) + for dp_i in range(dp_size) + ], + len(available_ready), + is_eos, + ) + + def _extract_one_dp( + self, + partition_id: str, + task_name: str, + dp_rank: int, + dp_size: int, + token_budget: int, + allow_underfill: bool, + ready_indexes: list[int], + partition, + batch_index: int | None, + ) -> tuple[list[int], list[int]]: + """Legacy single-dp extraction (fallback when batch_index is None).""" + key = (partition_id, task_name) + buckets = self._buckets.setdefault(key, {}) + bucket = buckets.setdefault(dp_rank, []) + assigned = self._assigned_global.setdefault(key, set()) + resolved_lengths = self._resolved_lengths.setdefault(key, {}) + + available_ready = [i for i in ready_indexes if i not in assigned] + balance_unit = dp_size * self.n_samples_per_prompt * self.balance_unit_multiplier + + while True: + if bucket: + sel_count = self._select_up_to_budget(bucket, resolved_lengths, token_budget) + if sel_count > 0: + cur_tokens = sum(resolved_lengths.get(i, 0) for i in bucket[:sel_count]) + if cur_tokens >= token_budget: + result = self._pop_and_return(bucket, sel_count, assigned) + self._cache_result(partition_id, task_name, dp_rank, batch_index, result) + return result + if sel_count < len(bucket): + result = self._pop_and_return(bucket, sel_count, assigned) + self._cache_result(partition_id, task_name, dp_rank, batch_index, result) + return result + + if 
len(available_ready) >= balance_unit: + self._token_budget_for_fallback = token_budget + assigned_this_round = self._run_balance_round( + available_ready, + balance_unit, + dp_size, + partition, + buckets, + assigned, + resolved_lengths, + ) + if assigned_this_round: + available_ready = [i for i in available_ready if i not in assigned] + bucket = buckets[dp_rank] + continue + + if allow_underfill and bucket: + result = self._pop_and_return(bucket, len(bucket), assigned) + self._cache_result(partition_id, task_name, dp_rank, batch_index, result) + return result + return [], [] + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _cache_result( + self, + partition_id: str, + task_name: str, + dp_rank: int, + batch_index: int | None, + result: tuple[list[int], list[int]], + ) -> None: + """Store a sampling result so other PP stages can retrieve it.""" + if batch_index is None: + return + self._states.setdefault(partition_id, {}).setdefault(task_name, {}).setdefault(dp_rank, {})[batch_index] = ( + result + ) + + def _select_up_to_budget(self, bucket: list[int], resolved_lengths: dict[int, int], token_budget: int) -> int: + """Return the largest prefix count k that packs into the budget. + + k is grown while ``sum(lengths[bucket[:k]]) <= token_budget``; the first + sample that would push the running total OVER the budget stops the slice + (so the returned slice never overshoots, avoiding oversized micro-batches). + A single leading sample larger than the whole budget is still included + (k>=1) so the stream always makes progress. Returns ``len(bucket)`` when + the entire bucket fits. 
Reads per-sample lengths from ``resolved_lengths`` + (populated by :meth:`_run_balance_round` with custom_meta + fallback).""" + if not bucket: + return 0 + accum = 0 + for k, idx in enumerate(bucket): + tl = resolved_lengths.get(idx, 0) + # Always include at least one sample even if it alone exceeds budget + # (otherwise we'd never make progress on oversized samples). + if accum > 0 and accum + tl > token_budget: + return k + accum += tl + if accum >= token_budget: + return k + 1 + return len(bucket) + + def _pop_and_return(self, bucket: list[int], n: int, assigned: set) -> tuple[list[int], list[int]]: + """Pop n items from the front of bucket, mark them as consumed + (remove from assigned set), and return as (sampled, consumed).""" + sampled = bucket[:n] + del bucket[:n] + assigned.difference_update(sampled) + return sampled, sampled.copy() + + def _run_balance_round( + self, + available_ready: list[int], + balance_unit: int, + dp_size: int, + partition, + buckets: dict[int, list[int]], + assigned: set, + resolved_lengths: dict[int, int], + ) -> bool: + """Pull balance_unit GRPO-complete groups from available_ready, balance + token totals across dp_size DP buckets, and append to per-DP buckets. + + Returns True if at least one balance round was completed.""" + # Use parent GRPO logic to find balance_unit complete groups. + # We bypass the cache by using a unique task_name/partition_id for + # this internal call (so it does not interfere with the consumer-facing + # state cache of GRPOGroupNSampler). + grpo_sampled, _ = super().sample( + sorted(available_ready), + balance_unit, + task_name="__streaming_internal__", + partition_id="__streaming_internal__", + ) + if not grpo_sampled: + return False + + # Read per-sample total_lengths. 
If some samples in the round lack + # total_lengths in custom_meta (producer race where rollout pushed + # samples but set_custom_meta hasn't landed yet, or a producer that + # skipped set_custom_meta entirely), fall back to the token budget for the + # missing samples so the round can still proceed. This sacrifices + # exact token-budget accuracy but avoids infinite-defer deadlock. + custom_meta = partition.get_custom_meta(grpo_sampled) + missing = [i for i in grpo_sampled if "total_lengths" not in custom_meta.get(i, {})] + if missing: + # Be conservative: assume missing samples are as large as the full + # token budget (passed via kwargs). Using avg or 0 would risk + # packing many real-but-unknown long samples into one mb → OOM + # (observed with avg fallback when an entire GRPO group lacks + # custom_meta). Over-estimating means each missing sample tends + # to occupy a whole mb by itself, which is safe but inefficient. + fallback = self._token_budget_for_fallback + logger.warning( + "[stream-sampler] %d/%d samples missing total_lengths in this " + "round; using fallback=%d (token_budget) for safety " + "(picked=%s missing=%s)", + len(missing), + len(grpo_sampled), + fallback, + grpo_sampled[:8], + missing[:4], + ) + sample_lengths = [custom_meta.get(i, {}).get("total_lengths", fallback) for i in grpo_sampled] + else: + sample_lengths = [custom_meta[i]["total_lengths"] for i in grpo_sampled] + + # Record the resolved lengths so _select_up_to_budget can reuse them + # later without re-querying custom_meta (which may still race). + for idx, tl in zip(grpo_sampled, sample_lengths, strict=False): + resolved_lengths[idx] = tl + + # Per-sample balance across DPs. GRPO group integrity is NOT required + # at the training DP split: group-relative advantages are computed + # upstream by the Advantages service and stored per-sample, so the + # actor's loss (grpo/gspo/sapo) only reads per-sample advantages — it + # never re-normalizes across a group.
We therefore balance individual + # samples (not whole groups) across DPs, which keeps each DP's sample + # count and token total close and minimizes the dummy-mb padding the + # streaming schedule needs for cross-DP micro-batch alignment. + # + # ``equal_size=True`` forces equal sample COUNT per DP (balance_unit is + # a multiple of dp_size), so each DP gets the same number of samples + # with balanced token sums — making per-DP micro-batch counts equal in + # the common case (token packing may still differ by one mb at the + # margins, which the schedule's dummy-pad barrier handles). + balanced = get_seqlen_balanced_partitions(sample_lengths, dp_size, equal_size=True) + for dp_i, sample_idx_list in enumerate(balanced): + dp_samples = [grpo_sampled[j] for j in sample_idx_list] + buckets.setdefault(dp_i, []).extend(dp_samples) + assigned.update(dp_samples) + + return True + + def _flush_tail( + self, + available_ready: list[int], + dp_size: int, + partition, + buckets: dict[int, list[int]], + assigned: set, + resolved_lengths: dict[int, int], + ) -> None: + """End-of-stream flush: distribute ALL remaining ready samples across + DP buckets, bypassing GRPO-group and balance_unit constraints. + + Called only when the producer is done for this partition. Whatever is + left in the ready pool — a sub-balance_unit remainder, even a partial + GRPO group — is token-balanced across DPs and dumped into their buckets + so every produced sample is guaranteed to be consumed (otherwise + all_consumed never becomes True → livelock). Group integrity is not + needed (advantages are precomputed per-sample); cross-DP count + imbalance from this uneven flush is handled by the iterator's dummy-pad. + """ + leftover = sorted(i for i in available_ready if i not in assigned) + if not leftover: + return + + # Resolve lengths (custom_meta, with token-budget fallback for missing). 
+ custom_meta = partition.get_custom_meta(leftover) + fallback = self._token_budget_for_fallback + lengths = [custom_meta.get(i, {}).get("total_lengths", fallback) for i in leftover] + for idx, tl in zip(leftover, lengths, strict=False): + resolved_lengths[idx] = tl + + if len(leftover) >= dp_size: + # Token-balance the leftover across all DPs (variable count per DP). + parts = get_seqlen_balanced_partitions(lengths, dp_size, equal_size=False) + for dp_i, idx_list in enumerate(parts): + dp_samples = [leftover[j] for j in idx_list] + buckets.setdefault(dp_i, []).extend(dp_samples) + assigned.update(dp_samples) + else: + # Fewer leftover than DPs — round-robin; some DPs get nothing (the + # dummy-pad will align them). + for n, idx in enumerate(leftover): + dp_i = n % dp_size + buckets.setdefault(dp_i, []).append(idx) + assigned.add(idx) + + logger.debug( + "[stream-sampler] tail-flush %d leftover samples across %d dps", + len(leftover), + dp_size, + ) + + # ------------------------------------------------------------------ + # Cache / lifecycle + # ------------------------------------------------------------------ + + def clear_cache(self, partition_id: str): + """Drop all per-DP buckets and assignment tracking for this partition.""" + super().clear_cache(partition_id) + keys_to_remove = [k for k in self._buckets if k[0] == partition_id] + for k in keys_to_remove: + del self._buckets[k] + for k in list(self._assigned_global): + if k[0] == partition_id: + del self._assigned_global[k] + for k in list(self._resolved_lengths): + if k[0] == partition_id: + del self._resolved_lengths[k] + # Also clear the parent GRPO sampler's internal cache used by + # _run_balance_round (keyed under "__streaming_internal__"). 
+ if "__streaming_internal__" in self._states: + del self._states["__streaming_internal__"] diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 42f17db8..8adc9879 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -197,6 +197,7 @@ async def notify_data_update( global_indexes: list[int], field_schema: dict[str, dict[str, Any]], custom_backend_meta: Optional[dict[int, dict[str, Any]]] = None, + user_custom_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> None: """ Notify controller that new data is ready. @@ -206,6 +207,8 @@ async def notify_data_update( global_indexes: Data update related global_indexes. field_schema: Columnar field schema {field_name: {dtype, shape, is_nested, ...}}. custom_backend_meta: Per-field custom_meta for each sample, in {global_index: {field: custom_meta}} format. + user_custom_meta: User-defined per-sample custom_meta in {global_index: {...}} format. When provided, + the controller writes it before marking samples ready, so it lands atomically with readiness. """ if not self.controller_info: @@ -253,6 +256,7 @@ async def notify_data_update( "global_indexes": global_indexes, "field_schema": normalized_field_schema, "custom_backend_meta": custom_backend_meta, + "user_custom_meta": user_custom_meta, }, ).serialize() @@ -608,11 +612,24 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: # Get current data partition id partition_id = metadata.partition_ids[0] + # Forward any user-defined custom_meta carried on the BatchMeta so it lands + # atomically with the readiness notification (avoids the put/set_custom_meta + # race for streaming consumers). Only sent when at least one sample has it. 
+ user_custom_meta_list = metadata.get_all_custom_meta() + user_custom_meta: Optional[dict[int, dict[str, Any]]] = None + if any(user_custom_meta_list): + user_custom_meta = { + metadata.global_indexes[i]: user_custom_meta_list[i] + for i in range(len(user_custom_meta_list)) + if user_custom_meta_list[i] + } + await self.notify_data_update( partition_id, metadata.global_indexes, field_schema, per_field_custom_backend_meta, + user_custom_meta, ) async def get_data(self, metadata: BatchMeta) -> TensorDict: diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 00d87822..5f8aab2e 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -21,7 +21,7 @@ from collections.abc import Mapping from functools import wraps from operator import itemgetter -from typing import Any, Callable, NamedTuple +from typing import Any, Callable, NamedTuple, Optional from uuid import uuid4 import torch @@ -329,10 +329,24 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: raise partition_id = metadata.partition_ids[0] + + # Forward any user-defined custom_meta carried on the BatchMeta so it lands + # atomically with the readiness notification (avoids the put/set_custom_meta + # race for streaming consumers). Only sent when at least one sample has it. 
+ user_custom_meta_list = metadata.get_all_custom_meta() + user_custom_meta: Optional[dict[int, dict[str, Any]]] = None + if any(user_custom_meta_list): + user_custom_meta = { + metadata.global_indexes[i]: user_custom_meta_list[i] + for i in range(len(user_custom_meta_list)) + if user_custom_meta_list[i] + } + await self.notify_data_update( partition_id, metadata.global_indexes, field_schema, + user_custom_meta=user_custom_meta, ) @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) From a447841a7f8022c11f7fdd9658a5fe5e1617cd13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=81=E6=9C=AC=E5=93=B2?= Date: Wed, 24 Jun 2026 10:17:11 +0000 Subject: [PATCH 2/2] fix(streaming): no-preset-global-batch end-of-stream for dynamic batch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Track per-partition completion via the producer's is_last signal instead of pre-allocating to global_batch_size: add production_completed, actual_sample_count, pending_last_indexes/fields on DataPartitionStatus - Thread is_last through put -> get_meta(insert) -> get_metadata; accumulate the true inserted sample count on the insert path only - Flip production_completed only once the is_last batch is produced for the PRODUCER's own fields (not downstream-backfilled columns), avoiding a completion deadlock when advantages/ref backfill extra columns - Add check_stream_drained (production_completed AND all inserted samples consumed) and check_production_completed (producer-side gate) ZMQ ops + client wrappers - Pad/grow the lazily-sized per-task consumption tensor in is_stream_drained (read-only) and mark_consumed (write) so a high global index never overruns it (was: IndexError that killed the controller request thread) - Re-run the completion check right after the insert sets has_pending_last, in case the is_last batch's production NOTIFY arrived before the insert RPC - At EOS slice each DP bucket by 
token budget over successive batch_index rounds (not one oversized pop -> OOM); cache an explicit empty result for an empty DP so batch_index alignment is not frozen and residue is never stranded --- - production_completed gating, backfilled-field independence, OOB consumption tensor, notify-before-is_last race, EOS multi-round drain within token budget Signed-off-by: 宁本哲 --- tests/test_controller_data_partitions.py | 298 ++++++++++++++++++ tests/test_samplers.py | 121 +++++++ transfer_queue/client.py | 153 ++++++++- transfer_queue/controller.py | 238 ++++++++++++-- .../sampler/streaming_token_budget_sampler.py | 19 +- transfer_queue/utils/zmq_utils.py | 11 + 6 files changed, 803 insertions(+), 37 deletions(-) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 8f14517a..ab6c160a 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -1392,3 +1392,301 @@ def test_fieldmeta_to_batch_schema_nested_missing_sample(self): schema = field_meta.to_batch_schema([0, 1]) assert schema["per_sample_shapes"] == [(3,), None] + + +def _ready_schema(): + return {"x": {"dtype": "torch.float32", "shape": (4,), "is_nested": False, "is_non_tensor": False}} + + +class TestStreamingDrain: + """No-preset-global-batch streaming end-of-stream: production_completed + + actual_sample_count, exercised via DataPartitionStatus directly (the controller + insert path just accumulates these fields and stashes pending_last_indexes).""" + + def _make_partition(self): + from transfer_queue.controller import DataPartitionStatus + + return DataPartitionStatus(partition_id="train@stream_0") + + def _produce(self, partition, indices): + partition.update_production_status( + global_indices=list(indices), + field_names=["x"], + field_schema=_ready_schema(), + ) + + def _simulate_insert(self, partition, indices, is_last=False): + # Mirror controller.get_metadata(mode="insert") accounting. 
+ partition.global_indexes.update(indices) + partition.actual_sample_count += len(indices) + if is_last: + partition.pending_last_indexes.update(indices) + partition.has_pending_last = True + partition.pending_last_fields.update(["x"]) # producer field, see _produce + + def test_not_drained_before_completed(self): + p = self._make_partition() + self._simulate_insert(p, [0, 1, 2]) + self._produce(p, [0, 1, 2]) + p.mark_consumed("actor_train", [0, 1, 2]) + # No is_last yet -> not completed -> not drained even though all consumed. + assert p.production_completed is False + assert p.is_stream_drained("actor_train") is False + + def test_completed_flips_only_after_last_batch_ready(self): + p = self._make_partition() + # First batch. + self._simulate_insert(p, [0, 1]) + self._produce(p, [0, 1]) + # Final batch announced at insert, but data not yet produced. + self._simulate_insert(p, [2, 3], is_last=True) + assert p.has_pending_last is True + assert p.production_completed is False # data of final batch not ready yet + # Producing the final batch flips the flag. + self._produce(p, [2, 3]) + assert p.production_completed is True + + def test_drained_requires_completed_and_all_consumed(self): + p = self._make_partition() + self._simulate_insert(p, [0, 1]) + self._produce(p, [0, 1]) + self._simulate_insert(p, [2, 3], is_last=True) + self._produce(p, [2, 3]) + assert p.production_completed is True + # Partial consumption -> not drained. + p.mark_consumed("actor_train", [0, 1, 2]) + assert p.is_stream_drained("actor_train") is False + # Full consumption -> drained. + p.mark_consumed("actor_train", [3]) + assert p.is_stream_drained("actor_train") is True + + def test_unactivated_prealloc_rows_do_not_block_drain(self): + from transfer_queue.controller import DataPartitionStatus + + p = DataPartitionStatus(partition_id="train@stream_1") + # Pre-allocate extra rows that are never activated/inserted. 
+ p.register_pre_allocated_indexes([0, 1, 2, 3, 4, 5, 6, 7]) + # Only 4 samples are actually inserted+produced+consumed. + self._simulate_insert(p, [0, 1]) + self._produce(p, [0, 1]) + self._simulate_insert(p, [2, 3], is_last=True) + self._produce(p, [2, 3]) + p.mark_consumed("actor_train", [0, 1, 2, 3]) + # consumption tensor has 8 rows (4 of them never consumed) but drain only + # counts the actually-inserted samples. + assert p.actual_sample_count == 4 + assert p.is_stream_drained("actor_train") is True + + def test_non_contiguous_indices(self): + from transfer_queue.controller import DataPartitionStatus + + p = DataPartitionStatus(partition_id="train@stream_2") + # Non-contiguous activated global indexes (drain must not assume [:N]). + self._simulate_insert(p, [5, 9]) + self._produce(p, [5, 9]) + self._simulate_insert(p, [11, 20], is_last=True) + self._produce(p, [11, 20]) + assert p.production_completed is True + p.mark_consumed("actor_train", [5, 9, 11]) + assert p.is_stream_drained("actor_train") is False + p.mark_consumed("actor_train", [20]) + assert p.is_stream_drained("actor_train") is True + + def test_drained_false_for_unknown_task(self): + p = self._make_partition() + self._simulate_insert(p, [0, 1], is_last=True) + self._produce(p, [0, 1]) + assert p.production_completed is True + # A task that never consumed anything is not drained. 
+ assert p.is_stream_drained("never_seen_task") is False + + +class TestStreamingDrainOutOfBounds: + """Regression: an is_last batch announced at insert registers high + pending_last_indexes; an EARLIER batch's production notify must not index + production_status out of bounds in _maybe_mark_production_completed.""" + + def _schema(self): + return {"x": {"dtype": "torch.float32", "shape": (4,), "is_nested": False, "is_non_tensor": False}} + + def test_pending_last_beyond_tensor_does_not_raise(self): + from transfer_queue.controller import DataPartitionStatus + + p = DataPartitionStatus(partition_id="train@oob_0") + # Final batch announced at insert with HIGH indexes (e.g. 6,7), but the + # tensor has only been grown for the earlier batch (indexes 0,1). + p.global_indexes.update([6, 7]) + p.actual_sample_count = 8 + p.pending_last_indexes.update([6, 7]) + p.has_pending_last = True + p.pending_last_fields.update(["x"]) + + # Earlier batch's notify: produces indexes 0,1 only. This calls + # _maybe_mark_production_completed internally; indexes 6,7 are beyond the + # tensor at this point — must NOT raise, must NOT mark completed. + ok = p.update_production_status( + global_indices=[0, 1], + field_names=["x"], + field_schema=self._schema(), + ) + assert ok is True + assert p.production_completed is False + + # Now the final batch's own notify lands (indexes 6,7) -> tensor grows -> + # completion flips True. 
+ ok = p.update_production_status( + global_indices=[6, 7], + field_names=["x"], + field_schema=self._schema(), + ) + assert ok is True + assert p.production_completed is True + + +class TestCheckProductionCompleted: + """Producer-side admission gate: check_production_completed reflects + production_completed only (no consumption), unlike is_stream_drained.""" + + def _schema(self): + return {"x": {"dtype": "torch.float32", "shape": (4,), "is_nested": False, "is_non_tensor": False}} + + def test_completed_independent_of_consumption(self): + from transfer_queue.controller import DataPartitionStatus + + p = DataPartitionStatus(partition_id="train@gate_0") + # Not completed before is_last data lands. + p.global_indexes.update([0, 1]) + p.actual_sample_count = 2 + p.pending_last_indexes.update([0, 1]) + p.has_pending_last = True + p.pending_last_fields.update(["x"]) + assert p.production_completed is False + + # Produce the final batch -> completed True, with ZERO consumption. + p.update_production_status([0, 1], ["x"], field_schema=self._schema()) + assert p.production_completed is True + # is_stream_drained still False (nothing consumed) — the two gates differ. + assert p.is_stream_drained("actor_train") is False + + def test_notify_before_is_last_flag_recheck_flips_completed(self): + """Regression: the is_last batch's production notify can arrive BEFORE the + insert that sets has_pending_last (get_meta and put_data are separate RPCs). + + Order: data is produced (update_production_status) while has_pending_last is + still False → that notify's completion check is a no-op. Then the is_last + insert sets the flag but, without a re-check, no further notify fires → + production_completed would never flip → drain deadlock. The fix re-runs the + completion check right after the insert sets the flag. 
+ """ + from transfer_queue.controller import DataPartitionStatus + + p = DataPartitionStatus(partition_id="train@race_0") + + # 1) Producer's NOTIFY arrives first: data for the final batch is marked + # ready while has_pending_last is still False (no-op completion check). + p.update_production_status([0, 1], ["x"], field_schema=self._schema()) + p.actual_sample_count = 2 + assert p.production_completed is False # flag not set yet → not completed + + # 2) The is_last insert now sets the flag. Simulate the controller's + # insert-path: set flag, then re-run the completion check (the fix). + p.global_indexes.update([0, 1]) + p.pending_last_indexes.update([0, 1]) + p.has_pending_last = True + p.pending_last_fields.update(["x"]) + p._maybe_mark_production_completed() # re-check after flag set + + # The data was already ready, so the re-check must flip completion True. + assert p.production_completed is True + + +class TestStreamingDrainBackfillFields: + """Regression for the step-8 hang: downstream consumers (advantages, ref/ + actor_fwd) backfill EXTRA fields into the same partition. production_completed + must only require the PRODUCER's fields on the is_last samples — not the + backfilled columns the producer never writes.""" + + def _producer_schema(self): + return { + "tokens": {"dtype": "torch.int32", "shape": (8,), "is_nested": False, "is_non_tensor": False}, + "rewards": {"dtype": "torch.float32", "shape": (1,), "is_nested": False, "is_non_tensor": False}, + } + + def _adv_schema(self): + return { + "advantages": {"dtype": "torch.float32", "shape": (8,), "is_nested": False, "is_non_tensor": False}, + "returns": {"dtype": "torch.float32", "shape": (8,), "is_nested": False, "is_non_tensor": False}, + } + + def test_backfilled_fields_do_not_block_completion(self): + from transfer_queue.controller import DataPartitionStatus + + p = DataPartitionStatus(partition_id="train@bf_0") + # Producer inserts the (only, final) batch declaring its own fields. 
+ p.global_indexes.update([0, 1]) + p.actual_sample_count = 2 + p.pending_last_indexes.update([0, 1]) + p.has_pending_last = True + p.pending_last_fields.update(["tokens", "rewards"]) + + # Producer writes its fields -> completion should flip True even though + # downstream fields (advantages/returns) have NOT been written yet. + p.update_production_status([0, 1], ["tokens", "rewards"], field_schema=self._producer_schema()) + assert p.production_completed is True + + def test_completion_not_blocked_when_adv_backfill_grows_columns_first(self): + from transfer_queue.controller import DataPartitionStatus + + p = DataPartitionStatus(partition_id="train@bf_1") + p.global_indexes.update([0, 1]) + p.actual_sample_count = 2 + p.pending_last_indexes.update([0, 1]) + p.has_pending_last = True + p.pending_last_fields.update(["tokens", "rewards"]) + + # Producer writes its fields first -> completed True. + p.update_production_status([0, 1], ["tokens", "rewards"], field_schema=self._producer_schema()) + assert p.production_completed is True + + # A later advantages backfill adds advantages/returns columns. The producer + # samples are 0 on those new columns, but completion already (correctly) + # latched True and must stay True. + p.update_production_status([0, 1], ["advantages", "returns"], field_schema=self._adv_schema()) + assert p.production_completed is True + + +class TestStreamDrainedConsumptionUndersized: + """Regression: is_stream_drained must not index past a lazily-sized per-task + consumption tensor (crash: 'index 56 out of bounds for size 56'). 
The tensor + grows lazily; once production_completed there may be no further production + notify to expand it, so is_stream_drained must ensure capacity itself.""" + + def _schema(self): + return {"x": {"dtype": "torch.float32", "shape": (4,), "is_nested": False, "is_non_tensor": False}} + + def test_drained_with_undersized_consumption_tensor(self): + from transfer_queue.controller import DataPartitionStatus + + p = DataPartitionStatus(partition_id="train@undersize_0") + # Produce + complete a partition spanning indexes 0..7 (8 samples). + idxs = list(range(8)) + p.global_indexes.update(idxs) + p.actual_sample_count = 8 + p.pending_last_indexes.update(idxs) + p.has_pending_last = True + p.pending_last_fields.update(["x"]) + p.update_production_status(idxs, ["x"], field_schema=self._schema()) + assert p.production_completed is True + + # Force the task's consumption tensor to be SMALLER than max(active)+1, + # mimicking a tensor that was sized before later samples were activated. + import torch + + p.consumption_status["actor_train"] = torch.zeros(3, dtype=torch.int8) + + # Must not raise (previously IndexError), and not be drained (nothing consumed). + assert p.is_stream_drained("actor_train") is False + + # After consuming all, drained becomes True (capacity ensured internally). + p.mark_consumed("actor_train", idxs) + assert p.is_stream_drained("actor_train") is True diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 21782837..0d42e3d1 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -1420,6 +1420,68 @@ def test_tail_flush_on_production_done_drains_all(self): assert sorted(drained) == [0, 1, 2, 3] assert len(drained) == len(set(drained)) + def test_eos_drains_all_across_rounds_within_budget(self): + """Regression for the step-50 drain hang AND the EOS OOM. 
+ + At end-of-stream every produced sample must eventually be delivered (and + marked consumed) — but NO single micro-batch may exceed the token budget + (popping a whole bucket at once built a 17-sample mb → OOM). So a dp whose + bucket holds several budgets of residue is drained across SUCCESSIVE + batch_index rounds, each mb ≤ budget. This models the real consumer: each + dp advances its OWN batch_index on every non-empty fetch and stops when + the global partition is fully consumed. + """ + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=2, balance_unit_multiplier=4) + all_indexes = list(range(8)) # 4 GRPO groups + # Long per-sample length so each budget slice fits exactly ONE sample, + # forcing multi-round drain of any multi-sample bucket. + per_sample_len = 900 + partition = self._partition_with_uniform_lengths(all_indexes, per_sample_len) + dp_size = 2 + token_budget = 1000 # one 900-token sample fits; two (1800) do not + + # All dps step through batch_index in lockstep (the sampler atomically + # prepares every dp's slice for a batch_index on first touch); we advance + # to the next batch_index once a round has been served, and stop when a + # whole round yields nothing. + drained: list[int] = [] + for batch_index in range(20): # bounded; must finish well within + # Model the controller: consumed samples are filtered OUT of the ready + # pool (scan_data_status excludes consumption_status==1), so already + # drained indexes are never re-offered. + ready_now = [i for i in all_indexes if i not in drained] + round_total = 0 + for dp_rank in range(dp_size): + sampled, consumed = sampler.sample( + ready_now, + 0, + task_name="actor", + partition_id="p0", + dp_rank=dp_rank, + dp_size=dp_size, + batch_index=batch_index, + partition=partition, + token_budget=token_budget, + production_done=True, + ) + assert sampled == consumed + if sampled: + # OOM guard: each delivered mb must respect the token budget. 
+ mb_tokens = len(sampled) * per_sample_len + assert mb_tokens <= token_budget, f"oversized mb: {len(sampled)} samples = {mb_tokens} tok" + drained.extend(sampled) + round_total += len(sampled) + if round_total == 0: + break # whole round empty → fully drained + + # Every produced sample consumed exactly once across the multi-round drain. + assert sorted(drained) == all_indexes, f"orphaned: {set(all_indexes) - set(drained)}" + assert len(drained) == len(set(drained)) + # Nothing left bucketed / assigned-but-unconsumed. + for bucket in sampler._buckets[("p0", "actor")].values(): + assert bucket == [], f"residue left in bucket: {bucket}" + assert sampler._assigned_global[("p0", "actor")] == set() + def test_tail_flush_assigns_all_to_buckets(self): """First production_done call must assign all ready samples to buckets.""" sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=2, balance_unit_multiplier=4) @@ -1449,6 +1511,65 @@ def test_tail_flush_assigns_all_to_buckets(self): popped = set(range(4)) - assigned assert (assigned | popped) == {0, 1, 2, 3} + def test_eos_waits_for_downstream_fields_before_empty_cache(self): + """EOS waits for downstream fields before caching an empty result.""" + from transfer_queue.controller import DataPartitionStatus + + def schema(field_name: str) -> dict[str, dict[str, Any]]: + return {field_name: {"dtype": "torch.float32", "shape": (4,), "is_nested": False, "is_non_tensor": False}} + + sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=1) + partition = DataPartitionStatus(partition_id="train_0") + partition.global_indexes.update([0, 1]) + partition.actual_sample_count = 2 + partition.pending_last_indexes.update([0, 1]) + partition.has_pending_last = True + partition.pending_last_fields.update(["tokens"]) + partition.set_custom_meta({0: {"total_lengths": 5}, 1: {"total_lengths": 5}}) + partition.update_production_status([0, 1], ["tokens"], schema("tokens")) + assert partition.production_completed is True + + data_fields = 
["tokens", "advantages"] + production_done = partition.are_unconsumed_fields_ready("actor_train", data_fields) + assert production_done is False + + sampled, consumed = sampler.sample( + [], + 0, + task_name="actor_train", + partition_id="train_0", + dp_rank=0, + dp_size=1, + batch_index=0, + partition=partition, + token_budget=10, + production_done=production_done, + ) + assert sampled == [] + assert consumed == [] + assert sampler._states == {} + + partition.update_production_status([0, 1], ["advantages"], schema("advantages")) + ready = partition.scan_data_status(data_fields, "actor_train") + assert ready == [0, 1] + production_done = partition.are_unconsumed_fields_ready("actor_train", data_fields) + assert production_done is True + + sampled, consumed = sampler.sample( + ready, + 0, + task_name="actor_train", + partition_id="train_0", + dp_rank=0, + dp_size=1, + batch_index=0, + partition=partition, + token_budget=10, + production_done=production_done, + ) + assert sampled == [0, 1] + assert consumed == [0, 1] + def test_no_complete_group_without_production_done_returns_empty(self): """No complete GRPO group + not end-of-stream → return empty (wait for more).""" sampler = StreamingTokenBudgetSampler(n_samples_per_prompt=2, balance_unit_multiplier=4) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index b8e32f2b..171638ae 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -165,6 +165,7 @@ async def async_get_meta( task_name: Optional[str] = None, sampling_config: Optional[dict[str, Any]] = None, token_budget: Optional[int] = None, + is_last: bool = False, socket: Optional[zmq.asyncio.Socket] = None, ) -> BatchMeta: """Asynchronously fetch data metadata from the controller via ZMQ. @@ -179,6 +180,8 @@ async def async_get_meta( - 'insert': Internal usage - should not be used by users task_name: Optional task name associated with the request sampling_config: Optional sampling configuration for custom samplers. 
+ is_last: In insert mode, mark this as the final batch of the partition so + the controller can detect end-of-stream without a preset total count. socket: ZMQ async socket for message transmission (injected by decorator) Returns: @@ -229,6 +232,10 @@ async def async_get_meta( "task_name": task_name, "sampling_config": sampling_config, "token_budget": token_budget, + # is_last: producer marks the final insert batch of a partition so + # the controller can flag end-of-stream without a preset total. Only + # meaningful in mode="insert". + "is_last": is_last, }, ) @@ -327,6 +334,7 @@ async def async_put( metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None, custom_meta: Optional[list[dict[str, Any]]] = None, + is_last: bool = False, ) -> BatchMeta: """Asynchronously write data to storage units based on metadata. @@ -352,6 +360,9 @@ async def async_put( metadata: Records the metadata of a batch of data samples, containing index and storage unit information. If None, metadata will be auto-generated. partition_id: Target data partition id (required if metadata is not provided) + is_last: Mark this as the final insert batch of the partition (streaming + end-of-stream). Only honored on the insert path (metadata is None); + ignored with a warning when explicit metadata is provided. Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved @@ -413,7 +424,13 @@ async def async_put( batch_size=data.batch_size[0], partition_id=partition_id, mode="insert", + is_last=is_last, ) + elif is_last: + # is_last only applies to the insert path (auto-generated metadata). + # A backfill put with explicit metadata writes into existing samples + # and must not announce end-of-stream. 
+ logger.warning(f"[{self.client_id}]: is_last=True ignored on put with explicit metadata.") if not metadata or metadata.size == 0: raise ValueError("metadata cannot be none or empty") @@ -698,6 +715,99 @@ async def async_get_consumption_status( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e + @dynamic_socket(socket_name="request_handle_socket") + async def async_check_stream_drained( + self, + task_name: str, + partition_id: str, + socket: Optional[zmq.asyncio.Socket] = None, + ) -> bool: + """Streaming end-of-stream test (no preset global batch). + + Returns True iff the producer announced completion for the partition AND + every actually-inserted sample has been consumed by ``task_name``. Use this + instead of :meth:`async_check_consumption_status` on the dynamic-batch + streaming path, where a tensor-wide ``.all()`` check is unreliable + (unactivated pre-allocated rows never reach 1). + + Args: + task_name: Name of the task to check consumption for + partition_id: Partition id to check + socket: ZMQ async socket for message transmission (injected by decorator) + + Returns: + bool: True if the partition is fully produced and fully consumed. 
+ """ + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECK_STREAM_DRAINED, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={ + "partition_id": partition_id, + "task_name": task_name, + }, + ) + + try: + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + + if response_msg.request_type == ZMQRequestType.CHECK_STREAM_DRAINED_RESPONSE: + return bool(response_msg.body.get("drained", False)) + raise RuntimeError( + f"[{self.client_id}]: Failed to check stream drained from controller {self._controller.id}: " + f"{response_msg.body.get('message', 'Unknown error')}" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in check_stream_drained: {str(e)}") from e + + @dynamic_socket(socket_name="request_handle_socket") + async def async_check_production_completed( + self, + partition_id: str, + socket: Optional[zmq.asyncio.Socket] = None, + ) -> bool: + """Producer-side completion test (no preset global batch, no consumption). + + Returns True iff the partition's producer has declared the final batch via + ``is_last`` AND its data is ready. Use this instead of + :meth:`async_check_production_status` (a tensor-wide ``.all()`` that relied + on pre-allocating the partition to ``global_batch_size``) as the + weight-update / training-admission gate on the dynamic-batch path. + + Args: + partition_id: Partition id to check + socket: ZMQ async socket for message transmission (injected by decorator) + + Returns: + bool: True if the partition's production is complete. 
+ """ + assert socket is not None + request_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECK_PRODUCTION_COMPLETED, # type: ignore[arg-type] + sender_id=self.client_id, + receiver_id=self._controller.id, + body={ + "partition_id": partition_id, + }, + ) + + try: + await socket.send_multipart(request_msg.serialize()) + response_serialized = await socket.recv_multipart(copy=False) + response_msg = ZMQMessage.deserialize(response_serialized) + + if response_msg.request_type == ZMQRequestType.CHECK_PRODUCTION_COMPLETED_RESPONSE: + return bool(response_msg.body.get("completed", False)) + raise RuntimeError( + f"[{self.client_id}]: Failed to check production completed from controller " + f"{self._controller.id}: {response_msg.body.get('message', 'Unknown error')}" + ) + except Exception as e: + raise RuntimeError(f"[{self.client_id}]: Error in check_production_completed: {str(e)}") from e + @dynamic_socket(socket_name="request_handle_socket") async def async_get_production_status( self, @@ -1195,6 +1305,8 @@ def wrapper(*args, **kwargs): self._get_consumption_status = _make_sync(self.async_get_consumption_status) self._get_production_status = _make_sync(self.async_get_production_status) self._check_consumption_status = _make_sync(self.async_check_consumption_status) + self._check_stream_drained = _make_sync(self.async_check_stream_drained) + self._check_production_completed = _make_sync(self.async_check_production_completed) self._check_production_status = _make_sync(self.async_check_production_status) self._get_partition_list = _make_sync(self.async_get_partition_list) self._set_custom_meta = _make_sync(self.async_set_custom_meta) @@ -1299,7 +1411,11 @@ def set_custom_meta(self, metadata: BatchMeta) -> None: return self._set_custom_meta(metadata=metadata) def put( - self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None + self, + data: TensorDict, + metadata: Optional[BatchMeta] = None, + partition_id: Optional[str] = 
None, + is_last: bool = False, ) -> BatchMeta: """Synchronously write data to storage units based on metadata. @@ -1318,6 +1434,8 @@ def put( metadata: Records the metadata of a batch of data samples, containing index and storage unit information. If None, metadata will be auto-generated. partition_id: Target data partition id (required if metadata is not provided) + is_last: Mark this as the final insert batch of the partition (streaming + end-of-stream). Only honored on the insert path (metadata is None). Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved @@ -1356,7 +1474,7 @@ def put( >>> # This will create metadata in "insert" mode internally. >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) """ - return self._put(data=data, metadata=metadata, partition_id=partition_id) + return self._put(data=data, metadata=metadata, partition_id=partition_id, is_last=is_last) def get_data(self, metadata: BatchMeta) -> TensorDict: """Synchronously fetch data from storage units and organize into TensorDict. @@ -1484,6 +1602,37 @@ def check_consumption_status(self, task_name: str, partition_id: str) -> bool: """ return self._check_consumption_status(task_name=task_name, partition_id=partition_id) + def check_stream_drained(self, task_name: str, partition_id: str) -> bool: + """Synchronously check streaming end-of-stream for a (partition, task). + + Returns True iff the producer announced completion AND every actually- + inserted sample has been consumed by ``task_name``. Preferred over + :meth:`check_consumption_status` on the dynamic-batch streaming path. + + Args: + task_name: Name of the task to check consumption for + partition_id: Partition id to check + + Returns: + bool: True if the partition is fully produced and fully consumed. 
+ """ + return self._check_stream_drained(task_name=task_name, partition_id=partition_id) + + def check_production_completed(self, partition_id: str) -> bool: + """Synchronously check producer-side completion for a partition. + + Returns True iff the producer declared the final batch via ``is_last`` AND + its data is ready. Preferred over :meth:`check_production_status` as the + weight-update / training-admission gate on the dynamic-batch path. + + Args: + partition_id: Partition id to check + + Returns: + bool: True if the partition's production is complete. + """ + return self._check_production_completed(partition_id=partition_id) + def check_production_status(self, data_fields: list[str], partition_id: str) -> bool: """Synchronously check if all samples for a partition are ready (produced) for consumption. diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index dd6b62d7..fc732724 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -365,6 +365,16 @@ class DataPartitionStatus: keys_mapping: dict[str, int] = field(default_factory=dict) # key -> global_idx revert_keys_mapping: dict[int, str] = field(default_factory=dict) # global_idx -> key + # ── Streaming end-of-stream tracking (no preset global batch) ────────────── + # The producer marks the final batch via is_last; drain ends when every + # actually-inserted sample is consumed, instead of relying on the partition + # being pre-allocated to global_batch_size and a tensor-wide .all(). 
+ production_completed: bool = False # set once the is_last batch's data is ready + actual_sample_count: int = 0 # samples actually inserted (sum of insert batch_sizes) + pending_last_indexes: set[int] = field(default_factory=set) # is_last batch indexes + has_pending_last: bool = False # is_last announced, completion not yet flipped + pending_last_fields: set[str] = field(default_factory=set) # producer fields from is_last batch + # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. # No need to strictly lock for every read/write operation since freshness is not critical. data_status_lock: Lock = field(default_factory=Lock) @@ -571,12 +581,51 @@ def update_production_status( # Save these global_indexes self.global_indexes.update(global_indices) + # Streaming end-of-stream: the producer announced the final batch at + # insert time (pending_last_indexes); flip production_completed only now + # that those samples are actually marked ready, so consumers never see + # "completed" before the last batch's data has landed. + self._maybe_mark_production_completed() + return True except Exception as e: logger.error(f"Error updating production status for partition {self.partition_id}: {e}") return False + def _maybe_mark_production_completed(self) -> None: + """Flip ``production_completed`` once every sample of the is_last batch + (``pending_last_indexes``) is produced for the producer's own fields. + + No-op until the producer marked a final batch (``has_pending_last``). + Only the producer's declared fields (``pending_last_fields``) are required + — not downstream-backfilled columns (advantages, ref_log_probs) that the + producer never writes. Locked: reached from the NOTIFY_DATA_UPDATE thread + and the GET_META insert path concurrently. 
+ """ + with self.data_status_lock: + if self.production_completed or not self.has_pending_last: + return + if not self.pending_last_indexes: + return + if self.production_status is None or self.total_fields_num == 0: + return + idx = sorted(self.pending_last_indexes) + # is_last indexes may exceed the tensor if an earlier batch's notify + # arrives before the final batch's data grows it → not ready yet. + if idx[-1] >= self.allocated_samples_num: + return + col_indices = [self.field_name_mapping[f] for f in self.pending_last_fields if f in self.field_name_mapping] + if not col_indices: + return + rows = self.production_status[torch.tensor(idx)][:, torch.tensor(sorted(col_indices))] + if bool((rows == 1).all().item()): + self.production_completed = True + logger.debug( + f"Partition {self.partition_id}: production_completed " + f"(actual_sample_count={self.actual_sample_count})" + ) + def _update_field_metadata( self, global_indexes: list[int], @@ -619,8 +668,18 @@ def mark_consumed(self, task_name: str, global_indices: list[int]): try: _, consumption_status = self.get_consumption_status(task_name, mask=False) - if consumption_status.numel() > 0 and global_indices: - consumption_status[global_indices] = 1 + if global_indices: + # mask=False does not expand the tensor; grow it here so a write to + # a high index never silently drops a consumed mark (→ drain deadlock). + required_len = max(global_indices) + 1 + if consumption_status.shape[0] < required_len: + with self.data_status_lock: + grown = torch.zeros(required_len, dtype=consumption_status.dtype) + grown[: consumption_status.shape[0]] = consumption_status + self.consumption_status[task_name] = grown + consumption_status = grown + if consumption_status.numel() > 0: + consumption_status[global_indices] = 1 except Exception as e: logger.error( f"Error marking samples consumed for partition {self.partition_id}, task {task_name}: {e}. 
" @@ -630,6 +689,71 @@ def mark_consumed(self, task_name: str, global_indices: list[int]): # ==================== Consumption Status Interface ==================== + def is_stream_drained(self, task_name: str) -> bool: + """Streaming end-of-stream test: producer done (``production_completed``) + AND every actually-inserted sample consumed by ``task_name``. + + Counts consumed samples over the activated ``global_indexes`` (not a + tensor-wide ``.all()``), so unactivated pre-allocated rows never block + drain and index allocation need not be contiguous. + """ + if not self.production_completed: + return False + if self.actual_sample_count <= 0: + return False + if task_name not in self.consumption_status: + return False + active = sorted(self.global_indexes) + if len(active) < self.actual_sample_count: + return False + cons = self.consumption_status[task_name] + if cons.numel() == 0: + return False + # The per-task consumption tensor grows lazily and may be shorter than + # max(active)+1; pad a read-only copy (zeros = unconsumed) so indexing is + # safe without mutating stored state (mark_consumed grows it on writes). 
+ required_len = active[-1] + 1 + if cons.shape[0] < required_len: + padded = torch.zeros(required_len, dtype=cons.dtype) + padded[: cons.shape[0]] = cons + cons = padded + idx = torch.tensor(active, dtype=torch.long) + consumed = int((cons[idx] == 1).sum().item()) + return consumed >= self.actual_sample_count + + def are_unconsumed_fields_ready(self, task_name: str, field_names: list[str]) -> bool: + """Return True once producer EOS and requested fields are ready.""" + if not self.production_completed: + return False + active = sorted(self.global_indexes) + if len(active) < self.actual_sample_count: + return False + if not field_names: + return True + if any(field_name not in self.field_name_mapping for field_name in field_names): + return False + + cons = self.consumption_status.get(task_name) + if cons is None or cons.numel() == 0: + unconsumed = active + else: + required_len = active[-1] + 1 if active else 0 + if cons.shape[0] < required_len: + padded = torch.zeros(required_len, dtype=cons.dtype) + padded[: cons.shape[0]] = cons + cons = padded + active_consumed = (cons[torch.tensor(active, dtype=torch.long)] == 1).tolist() + unconsumed = [idx for idx, consumed in zip(active, active_consumed, strict=False) if not consumed] + + if not unconsumed: + return True + if unconsumed[-1] >= self.allocated_samples_num: + return False + + col_indices = [self.field_name_mapping[field_name] for field_name in field_names] + rows = self.production_status[torch.tensor(unconsumed, dtype=torch.long)][:, torch.tensor(col_indices)] + return bool((rows == 1).all().item()) + def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Tensor, Tensor]: """ Get or create consumption status for a specific task. 
@@ -1194,6 +1318,34 @@ def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Opt return partition.get_consumption_status(task_name, mask=True) + def check_stream_drained(self, partition_id: str, task_name: str) -> bool: + """Streaming end-of-stream test for a (partition, task). + + Returns True iff the producer announced completion (``production_completed``) + AND every actually-inserted sample of the partition has been consumed by + ``task_name``. Used in place of ``check_consumption_status`` on the + no-preset-global-batch path, where a tensor-wide ``.all()`` would either + fire early (dynamic growth) or never (unactivated pre-allocated rows). + """ + partition = self._get_partition(partition_id) + if not partition: + return False + return partition.is_stream_drained(task_name) + + def check_production_completed(self, partition_id: str) -> bool: + """Producer-side completion test for a partition (no consumption involved). + + Returns True iff the partition exists and its producer has declared the + final batch via ``is_last`` AND that batch's data is ready + (``production_completed``). This is the no-preset-global-batch replacement + for the ``production_status.all()`` admission gate, which relied on the + partition being pre-allocated to exactly ``global_batch_size``. + """ + partition = self._get_partition(partition_id) + if not partition: + return False + return bool(partition.production_completed) + def get_production_status( self, partition_id: str, data_fields: list[str] ) -> tuple[Optional[Tensor], Optional[Tensor]]: @@ -1263,6 +1415,7 @@ def get_metadata( batch_size: int | None = None, sampling_config: Optional[dict[str, Any]] = None, token_budget: int | None = None, + is_last: bool = False, *args, **kwargs, ) -> BatchMeta: @@ -1323,6 +1476,23 @@ def get_metadata( # register global_indexes in partition partition.global_indexes.update(batch_global_indexes) + # Streaming end-of-stream bookkeeping. 
Only insert puts reach here, so + # actual_sample_count counts each sample once (backfill puts carry + # metadata and skip this branch). + with partition.data_status_lock: + partition.actual_sample_count += batch_size + if is_last: + partition.pending_last_indexes.update(batch_global_indexes) + partition.has_pending_last = True + if data_fields: + partition.pending_last_fields.update(data_fields) + + if is_last: + # get_meta(insert) and the data's NOTIFY are separate RPCs; if the + # notify arrived first its completion check was a no-op, so re-check + # now that has_pending_last is set (no-op if data not ready yet). + partition._maybe_mark_production_completed() + return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode) if mode == "fetch": @@ -1366,19 +1536,9 @@ def get_metadata( # "batch_index": batch_index, "partition_id": partition_id}). # Passing them again here causes a "multiple values for # keyword argument" TypeError. - # production_done: every pre-allocated sample of this - # partition has been produced for the requested fields, - # so no more data is coming. The streaming sampler uses - # this to trigger its end-of-stream tail flush (dump all - # remaining ready samples even if they can't form a full - # balance unit), preventing a tail livelock under DP>1. - production_done = False - try: - _, prod = partition.get_production_status_for_fields(data_fields, mask=True) - if prod is not None and prod.numel() > 0: - production_done = bool((prod == 1).all().item()) - except Exception: - production_done = False + # Downstream fields may be backfilled after producer EOS; + # only cache EOS once the requested fields are ready. + production_done = partition.are_unconsumed_fields_ready(task_name, data_fields) batch_global_indexes, consumed_indexes = self.sampler( ready_for_consume_indexes, @@ -1393,21 +1553,10 @@ def get_metadata( if batch_global_indexes: break - # Sampler returned nothing. 
Two possibilities: - # 1) Partition still has unconsumed samples but not enough for a - # balance round → wait for more production. - # 2) Partition is fully consumed (production complete AND every - # allocated sample marked consumed) → signal end-of-stream. - if partition is not None: - _, cons = partition.get_consumption_status(task_name, mask=True) - if ( - partition.production_status is not None - and cons.numel() > 0 - and bool((cons == 1).all().item()) - ): - # True end-of-stream: everything produced has been - # consumed. Signal drain completion. - return BatchMeta.empty() + # Sampler empty: either still producing (wait), or fully drained + # (producer done AND all inserted samples consumed) → end-of-stream. + if partition is not None and partition.is_stream_drained(task_name): + return BatchMeta.empty() # Bounded wait: short backoff inside the handler so other # consumers' GET_META requests aren't starved. When the @@ -1955,6 +2104,7 @@ def _process_request(self): task_name=params.get("task_name"), sampling_config=params.get("sampling_config", {}), token_budget=params.get("token_budget"), + is_last=params.get("is_last", False), ) response_msg = ZMQMessage.create( @@ -2053,6 +2203,34 @@ def _process_request(self): }, ) + elif request_msg.request_type == ZMQRequestType.CHECK_STREAM_DRAINED: + with perf_monitor.measure(op_type="CHECK_STREAM_DRAINED"): + params = request_msg.body + drained = self.check_stream_drained(params["partition_id"], params["task_name"]) + response_msg = ZMQMessage.create( + request_type=ZMQRequestType.CHECK_STREAM_DRAINED_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={ + "partition_id": params["partition_id"], + "drained": drained, + }, + ) + + elif request_msg.request_type == ZMQRequestType.CHECK_PRODUCTION_COMPLETED: + with perf_monitor.measure(op_type="CHECK_PRODUCTION_COMPLETED"): + params = request_msg.body + completed = self.check_production_completed(params["partition_id"]) + response_msg 
= ZMQMessage.create( + request_type=ZMQRequestType.CHECK_PRODUCTION_COMPLETED_RESPONSE, + sender_id=self.controller_id, + receiver_id=request_msg.sender_id, + body={ + "partition_id": params["partition_id"], + "completed": completed, + }, + ) + elif request_msg.request_type == ZMQRequestType.RESET_CONSUMPTION: with perf_monitor.measure(op_type="RESET_CONSUMPTION"): # Handle reset consumption status request diff --git a/transfer_queue/sampler/streaming_token_budget_sampler.py b/transfer_queue/sampler/streaming_token_budget_sampler.py index 9ea827e5..7a67a3d9 100644 --- a/transfer_queue/sampler/streaming_token_budget_sampler.py +++ b/transfer_queue/sampler/streaming_token_budget_sampler.py @@ -154,6 +154,8 @@ def sample( # PP-stage cache: when multiple PP stages request the same # (partition_id, task_name, dp_rank, batch_index), return the # cached result from the first call so all stages see identical data. + # At EOS, a DP cache miss may still need one more prepare to drain + # residue left by another DP rank's earlier prepare. if batch_index is not None: cached = self._states.get(partition_id, {}).get(task_name, {}).get(dp_rank, {}).get(batch_index, None) if cached is not None: @@ -165,6 +167,9 @@ def sample( batch_index, ) return cached + if production_done: + # Drop the batch-wide guard so prepare can fill this DP slot. + self._states.get(partition_id, {}).get(task_name, {}).get(0, {}).pop(batch_index, None) if partition is None: raise ValueError("StreamingTokenBudgetSampler requires partition kwarg from the controller") @@ -307,14 +312,18 @@ def _bucket_tokens(dp_i: int) -> int: self._flush_tail(available_ready, dp_size, partition, buckets, assigned, resolved_lengths) available_ready = [i for i in available_ready if i not in assigned] - # ── Phase 3: slice one mb per dp and commit independently ──────── - # Each dp pops up to a token-budget slice from its own bucket and - # caches it independently. 
We do NOT require all dps to be non-empty - # (no all-or-nothing rollback — that livelocked at the tail): cross-dp - # mb-count differences are handled by the iterator's dummy-pad barrier. + # ── Phase 3: slice one budget-sized mb per dp ──────────────────────── + # Pop only a token-budget slice (never the whole bucket — that builds an + # oversized mb and OOMs on long sequences); residue drains over successive + # batch_index rounds as the consumer advances each dp on non-empty fetches. + # At EOS, cache an explicit empty result for an already-empty dp so the + # dp=0 ``already`` guard stays set and batch_index isn't recomputed/frozen + # (which would strand other dps' residue). for dp_i in range(dp_size): bucket = buckets.setdefault(dp_i, []) if not bucket: + if is_eos: + self._cache_result(partition_id, task_name, dp_i, batch_index, ([], [])) continue sel_count = self._select_up_to_budget(bucket, resolved_lengths, token_budget) sel_count = max(sel_count, 1) # always make progress diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 8afbb480..072c2379 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -83,6 +83,17 @@ class ZMQRequestType(ExplicitEnum): RESET_CONSUMPTION = "RESET_CONSUMPTION" RESET_CONSUMPTION_RESPONSE = "RESET_CONSUMPTION_RESPONSE" + # CHECK_STREAM_DRAINED (streaming end-of-stream: production_completed AND + # all actually-inserted samples consumed by the task) + CHECK_STREAM_DRAINED = "CHECK_STREAM_DRAINED" + CHECK_STREAM_DRAINED_RESPONSE = "CHECK_STREAM_DRAINED_RESPONSE" + + # CHECK_PRODUCTION_COMPLETED (producer-side only: the partition's producer has + # declared the final batch via is_last AND its data is ready — independent of + # any consumption. Used as the weight-update / training-admission gate.) 
+ CHECK_PRODUCTION_COMPLETED = "CHECK_PRODUCTION_COMPLETED" + CHECK_PRODUCTION_COMPLETED_RESPONSE = "CHECK_PRODUCTION_COMPLETED_RESPONSE" + # GET_PRODUCTION GET_PRODUCTION = "GET_PRODUCTION" PRODUCTION_RESPONSE = "PRODUCTION_RESPONSE"