[trainer, fully_async] feat: add streaming rollouter mode to the V1 PPO trainer#6868
[trainer, fully_async] feat: add streaming rollouter mode to the V1 PPO trainer#6868huaiyizhao wants to merge 5 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a new fully_async (streaming) mode for the V1 PPO trainer, allowing an autonomous background feeder to continuously stream prompts into the TransferQueue under an in-flight budget. This decouples rollout generation from training and overlaps the two processes. The changes include the core streaming feeder logic, the PPOTrainerFullyAsync implementation, thread-safety guards for the training dataloader, configuration updates, and comprehensive CPU unit/integration tests and E2E scripts. The review feedback highlights two key improvement opportunities: first, updating the parameter version only during actual weight synchronization steps to prevent underestimating off-policy staleness; second, refactoring the feeder's pause/resume mechanism using a threading.Condition to eliminate polling delays and avoid cumulative GPU idle time.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| with self._param_version_lock: | ||
| self._param_version = self.global_steps |
There was a problem hiding this comment.
Updating self._param_version at the end of every training step is incorrect because weight synchronization with the standalone rollout only occurs periodically (every parameter_sync_step steps). Tagging prompts with intermediate step numbers makes the trainer believe they were generated with fresher weights than they actually were, which underestimates off-policy staleness and can bypass the max_off_policy_threshold safety guard.
self._param_version should only be updated to self.global_steps when a weight synchronization actually occurs.
| with self._param_version_lock: | |
| self._param_version = self.global_steps | |
| if self.global_steps % self.parameter_sync_step == 0: | |
| with self._param_version_lock: | |
| self._param_version = self.global_steps |
| self._stop = threading.Event() | ||
| self._paused = threading.Event() # when set, the loop stops dispatching new prompts | ||
| self.error = False # set True if the feeder thread dies unexpectedly | ||
| self._thread: threading.Thread | None = None | ||
|
|
||
| def _loop(self): | ||
| while not self._stop.is_set(): | ||
| if self._paused.is_set(): | ||
| # paused (e.g. during a weight sync): do not dispatch; generation already in | ||
| # flight keeps running and is aborted+continued by the checkpoint engine. | ||
| self._stop.wait(self._poll_interval) | ||
| continue | ||
| try: | ||
| counts = self._count_inflight() | ||
| inflight = counts["pending"] + counts["running"] + counts["finished"] + counts["failure"] | ||
| if inflight < self._budget: | ||
| self._feed_one_batch(self._param_version()) | ||
| else: | ||
| # interruptible sleep: avoids a tight busy loop while the budget is full | ||
| self._stop.wait(self._poll_interval) | ||
| except StopIteration: | ||
| logger.info("Streaming feeder: dataset exhausted, stopping feeder") | ||
| break | ||
| except Exception: | ||
| logger.exception("Streaming feeder thread crashed") | ||
| self.error = True | ||
| break | ||
|
|
||
| def start(self): | ||
| """Start the background feeder thread.""" | ||
| self._stop.clear() | ||
| self.error = False | ||
| self._thread = threading.Thread(target=self._loop, name="streaming-rollout-feeder", daemon=True) | ||
| self._thread.start() | ||
|
|
||
| def stop(self, timeout: float = 30.0): | ||
| """Signal the feeder to stop and join the thread.""" | ||
| self._stop.set() | ||
| if self._thread is not None and self._thread.is_alive(): | ||
| self._thread.join(timeout=timeout) | ||
|
|
||
| def pause(self): | ||
| """Pause dispatching new prompts (e.g. during a weight sync). | ||
|
|
||
| Generation already in flight is unaffected (it is aborted+continued by the checkpoint | ||
| engine / FullyAsyncLLMServerClient). Idempotent. | ||
| """ | ||
| self._paused.set() | ||
|
|
||
| def resume(self): | ||
| """Resume dispatching after :meth:`pause`. Idempotent.""" | ||
| self._paused.clear() | ||
|
|
||
| @property | ||
| def paused(self) -> bool: | ||
| return self._paused.is_set() |
There was a problem hiding this comment.
Using self._paused.wait(self._poll_interval) or polling self._paused.is_set() with a sleep interval introduces an unnecessary delay (up to feeder_poll_interval, which defaults to 1.0s) when resuming the feeder after a weight synchronization. This can lead to significant cumulative GPU idle time during training.
Using a threading.Condition variable allows the feeder to be resumed instantly and efficiently without any polling or sleeping delays.
self._stop = threading.Event()
self._paused = False
self._cv = threading.Condition()
self.error = False # set True if the feeder thread dies unexpectedly
self._thread: threading.Thread | None = None
def _loop(self):
while not self._stop.is_set():
with self._cv:
while self._paused and not self._stop.is_set():
self._cv.wait()
if self._stop.is_set():
break
try:
counts = self._count_inflight()
inflight = counts["pending"] + counts["running"] + counts["finished"] + counts["failure"]
if inflight < self._budget:
self._feed_one_batch(self._param_version())
else:
# interruptible sleep: avoids a tight busy loop while the budget is full
self._stop.wait(self._poll_interval)
except StopIteration:
logger.info("Streaming feeder: dataset exhausted, stopping feeder")
break
except Exception:
logger.exception("Streaming feeder thread crashed")
self.error = True
break
def start(self):
"""Start the background feeder thread."""
self._stop.clear()
with self._cv:
self._paused = False
self.error = False
self._thread = threading.Thread(target=self._loop, name="streaming-rollout-feeder", daemon=True)
self._thread.start()
def stop(self, timeout: float = 30.0):
"""Signal the feeder to stop and join the thread."""
self._stop.set()
with self._cv:
self._cv.notify_all()
if self._thread is not None and self._thread.is_alive():
self._thread.join(timeout=timeout)
def pause(self):
"""Pause dispatching new prompts (e.g. during a weight sync).
Generation already in flight is unaffected (it is aborted+continued by the checkpoint
engine / FullyAsyncLLMServerClient). Idempotent.
"""
with self._cv:
self._paused = True
def resume(self):
"""Resume dispatching after :meth:`pause`. Idempotent."""
with self._cv:
self._paused = False
self._cv.notify_all()
@property
def paused(self) -> bool:
with self._cv:
return self._paused|
@huaiyizhao Thanks for you contribution. I have 2 questions:
|
30bb9e9 to
bf6b966
Compare
Add a fully_async trainer mode: an autonomous background feeder thread continuously streams prompts into TransferQueue (bounded by a staleness/in-flight budget) while step() only samples + trains, decoupling rollout production from training consumption. The feeder (thread loop, throttling, weight-sync pause/resume) is fully self-contained in trainer_fully_async.py and touches the base trainer only through its public state. trainer_base is left exactly as upstream: the subclass overrides _add_batch_to_generate so the base step()'s unconditional feed becomes a no-op once the feeder owns generation, owns its own dataloader lock, and overrides _save_checkpoint to serialize against the feeder thread. replay_buffer gains count_inflight + a 'none' staleness strategy (TIS-corrected streaming) and a no-op dead_prompt_keys hook. separate_async reads num_warmup_batches/parameter_sync_step from the active mode's config so fully_async uses its own cadence. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…mal-rollout discard
Add an opt-in alternative to the default prompt-level agent-loop worker, selected via the existing
agent_loop_manager_class / custom_sampler config injection (the prompt-level implementation is left
unchanged):
- RolloutAgentLoopManagerTQ/RolloutAgentLoopWorkerTQ (agent_loop_tq_rollout.py): dispatch one
(prompt, session) rollout at a time round-robin across the worker pool, so sibling sessions run
on different workers and a long-tail rollout never blocks the rest. Each session writes a
{uid}_sess{session_id} completion marker and offloads postprocess CPU work to a thread pool.
- SessionReplayBuffer (replay_buffer_session.py): readiness derived from per-session completion
markers; only prompts with >=1 successful session are sampleable; dead_prompt_keys() surfaces
all-failed prompts so the streaming feeder discards them and a fresh prompt takes the slot.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…real TQ Cover the default prompt-level ReplayBuffer (status readiness) and the opt-in rollout-level SessionReplayBuffer (session-counting readiness, abnormal-rollout discard via dead_prompt_keys, sample/clear cycle) against a real TransferQueue on a local Ray cluster, plus a buffer-level concurrent produce/consume steady-state check (bounded, no deadlock). The inlined feeder thread lives in PPOTrainerFullyAsync (pulls the GPU serving stack) and is covered by the GPU e2e scripts. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
62ab8fb to
6c27b8f
Compare
File-load agent_loop_tq_rollout with the GPU serving stack stubbed, covering what the GPU smoke run does not assert: build_trajectory_info / extract_sample / mm_token_feature_counts, the RolloutAgentLoopManagerTQ round-robin session fan-out (+ persistent dispatch cursor), and _execute_rollout writing a per-session success OR failure completion marker against a real TransferQueue. The failure-marker branch is the one a healthy GPU run never trips, yet the abnormal-rollout discard depends on it. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Register "Copyright 2026 Tencent Inc. and/or its affiliates" in the license-header allowlist and apply it to the files newly added by the streaming work. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
@wuxibin89 Hi.
|
What does this PR do?
Adds a new fully_async trainer mode to the V1 PPO trainer that decouples rollout generation from the training step, so generation and training overlap instead of running lock-step (one batch fed per step). This brings the streaming-rollouter capability of verl/experimental/fully_async_policy to the V1 trainer, while reusing the existing TransferQueue as the streaming channel (no separate MessageQueue / Rollouter actor).
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,veomni,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,cfg,reward,fully_async,one_step_off,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)recipesubmodule, please also update the reference to the submodule commit viagit submodule update --remoteorcd recipe && git pull origin main.