[trainer, rollout] feat: opt-in rollout-level dispatch for the V1 agent-loop trainer#6874
[trainer, rollout] feat: opt-in rollout-level dispatch for the V1 agent-loop trainer#6874huaiyizhao wants to merge 1 commit into
Conversation
By default the V1 agent-loop manager chunks a batch of prompts across workers and
each worker fans out a whole prompt's n GRPO sessions internally, pinning a prompt
(and a long-tail rollout) to a single worker process. This adds an opt-in
rollout-level dispatch model, selected by trainer.v1.rollout_level_dispatch
(default false, legacy behavior byte-for-byte):
- Manager dispatches one (prompt, session) rollout at a time, round-robin across
the worker pool, so a prompt's sessions spread over workers and a long-tail
rollout occupies only a single slot.
- Optional per-worker concurrency cap (trainer.v1.max_concurrent_rollouts_per_worker).
- ReplayBuffer gains a session-counting readiness model: each rollout writes a
{uid}_sess{session_id} completion marker, and a prompt is sampleable once all n
markers are present (success or failure). The legacy prompt-status readiness
(pending/running/finished/failure) is kept and used when the flag is off.
Both readiness models are populated from one metadata sync (routed by whether the
prompt tag carries n vs status), so either path works without code changes.
Tests: tests/trainer/ppo/v1/test_replay_buffer_session_counting_on_cpu.py exercises
both models against a real TransferQueue on CPU (7/7).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request introduces an opt-in, rollout-level dispatch mechanism for the V1 AgentLoop and ReplayBuffer, allowing individual GRPO sessions of a prompt to be fanned out round-robin across the worker pool instead of being pinned to a single worker. This change is supported by a new session-counting readiness model in the ReplayBuffer and a comprehensive CPU-only test suite. The review feedback focuses on critical performance and robustness enhancements, specifically: cloning sliced PyTorch tensors and converting tensor elements to plain Python scalars to prevent massive serialization overhead when sending tasks via Ray, and using a more robust string splitting method (rsplit) to correctly extract UIDs that may contain underscores.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def extract_sample(batch: TensorDict, i: int) -> dict: | ||
| """Extract sample ``i`` from a batched TensorDict into a plain per-prompt dict. | ||
|
|
||
| Moved out of the worker so the manager can build per-prompt dicts before dispatching | ||
| individual rollout units. Mirrors the original per-key type handling. | ||
| """ | ||
| sample: dict[str, Any] = {} | ||
| for k, v in batch.items(): | ||
| if isinstance(v, torch.Tensor): | ||
| sample[k] = v[i] | ||
| elif isinstance(v, NonTensorStack): | ||
| sample[k] = v[i].data | ||
| elif isinstance(v, NonTensorData): | ||
| sample[k] = v.data | ||
| else: | ||
| logger.exception(f"Unsupported type {type(v)} for key {k}") | ||
| return sample |
There was a problem hiding this comment.
Slicing a PyTorch tensor (e.g., v[i]) creates a view that shares the underlying storage with the original batched tensor. When Ray serializes this sliced tensor to send it to a worker, it serializes the entire underlying storage of the batch (e.g., all 1024 prompts) for each individual sample. This leads to massive memory overhead, high serialization latency, and potential out-of-memory (OOM) errors.
Calling .clone() on the sliced tensor allocates new, independent storage containing only the single element, preventing the serialization of the entire batch's storage.
| def extract_sample(batch: TensorDict, i: int) -> dict: | |
| """Extract sample ``i`` from a batched TensorDict into a plain per-prompt dict. | |
| Moved out of the worker so the manager can build per-prompt dicts before dispatching | |
| individual rollout units. Mirrors the original per-key type handling. | |
| """ | |
| sample: dict[str, Any] = {} | |
| for k, v in batch.items(): | |
| if isinstance(v, torch.Tensor): | |
| sample[k] = v[i] | |
| elif isinstance(v, NonTensorStack): | |
| sample[k] = v[i].data | |
| elif isinstance(v, NonTensorData): | |
| sample[k] = v.data | |
| else: | |
| logger.exception(f"Unsupported type {type(v)} for key {k}") | |
| return sample | |
| def extract_sample(batch: TensorDict, i: int) -> dict: | |
| """Extract sample ``i`` from a batched TensorDict into a plain per-prompt dict. | |
| Moved out of the worker so the manager can build per-prompt dicts before dispatching | |
| individual rollout units. Mirrors the original per-key type handling. | |
| """ | |
| sample: dict[str, Any] = {} | |
| for k, v in batch.items(): | |
| if isinstance(v, torch.Tensor): | |
| sample[k] = v[i].clone() | |
| elif isinstance(v, NonTensorStack): | |
| sample[k] = v[i].data | |
| elif isinstance(v, NonTensorData): | |
| sample[k] = v.data | |
| else: | |
| logger.exception(f"Unsupported type {type(v)} for key {k}") | |
| return sample |
| def build_trajectory_info(step, index, validate) -> list[dict]: | ||
| """Synchronous port of ``agent_loop.get_trajectory_info`` (pure CPU, no I/O). | ||
|
|
||
| Lives here so the manager can build trajectory info inline while fanning out rollout | ||
| units, without spinning up an event loop just to await the async original. | ||
| """ | ||
| trajectory_info = [] | ||
| rollout_n = 0 | ||
| for i in range(len(index)): | ||
| if i > 0 and index[i - 1] == index[i]: | ||
| rollout_n += 1 | ||
| else: | ||
| rollout_n = 0 | ||
| trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n, "validate": validate}) | ||
| return trajectory_info |
There was a problem hiding this comment.
index[i] is a PyTorch tensor element that shares the underlying storage of the entire index tensor. Passing it directly inside the trajectory_info dictionary to Ray workers causes Ray to serialize the entire index tensor's storage for every single rollout task, leading to unnecessary serialization overhead.
Since index[i] is just a metadata scalar index, converting it to a plain Python int using .item() completely avoids PyTorch storage sharing and makes serialization extremely lightweight.
| def build_trajectory_info(step, index, validate) -> list[dict]: | |
| """Synchronous port of ``agent_loop.get_trajectory_info`` (pure CPU, no I/O). | |
| Lives here so the manager can build trajectory info inline while fanning out rollout | |
| units, without spinning up an event loop just to await the async original. | |
| """ | |
| trajectory_info = [] | |
| rollout_n = 0 | |
| for i in range(len(index)): | |
| if i > 0 and index[i - 1] == index[i]: | |
| rollout_n += 1 | |
| else: | |
| rollout_n = 0 | |
| trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n, "validate": validate}) | |
| return trajectory_info | |
| def build_trajectory_info(step, index, validate) -> list[dict]: | |
| """Synchronous port of ``agent_loop.get_trajectory_info`` (pure CPU, no I/O). | |
| Lives here so the manager can build trajectory info inline while fanning out rollout | |
| units, without spinning up an event loop just to await the async original. | |
| """ | |
| trajectory_info = [] | |
| rollout_n = 0 | |
| step_val = step.item() if hasattr(step, "item") else step | |
| for i in range(len(index)): | |
| if i > 0 and index[i - 1] == index[i]: | |
| rollout_n += 1 | |
| else: | |
| rollout_n = 0 | |
| idx_val = index[i].item() if hasattr(index[i], "item") else index[i] | |
| trajectory_info.append({"step": step_val, "sample_index": idx_val, "rollout_n": rollout_n, "validate": validate}) | |
| return trajectory_info |
| elif tag.get("is_session", False): | ||
| # rollout-level per-session completion marker `{uid}_sess{session_id}`. | ||
| uid = key.split("_")[0] | ||
| self.session_done[partition_id][uid].add(tag["session_id"]) | ||
| self.session_marker_keys[partition_id][uid].add(key) |
There was a problem hiding this comment.
Using key.split("_")[0] to extract the uid from the session marker key (f"{uid}_sess{session_id}") is fragile and will fail if the uid itself contains underscores (e.g., my_custom_uid). In such cases, it will only extract the first part of the uid (e.g., my), causing a mismatch with the prompt keys in prompt_n and ultimately causing the replay buffer to hang indefinitely waiting for complete sessions.
Using key.rsplit("_sess", 1)[0] is robust and correctly extracts the full uid regardless of whether it contains underscores.
| elif tag.get("is_session", False): | |
| # rollout-level per-session completion marker `{uid}_sess{session_id}`. | |
| uid = key.split("_")[0] | |
| self.session_done[partition_id][uid].add(tag["session_id"]) | |
| self.session_marker_keys[partition_id][uid].add(key) | |
| elif tag.get("is_session", False): | |
| # rollout-level per-session completion marker `{uid}_sess{session_id}`. | |
| uid = key.rsplit("_sess", 1)[0] | |
| self.session_done[partition_id][uid].add(tag["session_id"]) | |
| self.session_marker_keys[partition_id][uid].add(key) |
What does this PR do?
Adds an opt-in rollout-level dispatch model to the V1 agent-loop trainer, selected by trainer.v1.rollout_level_dispatch (default false).
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,veomni,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,cfg,reward,fully_async,one_step_off,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)recipesubmodule, please also update the reference to the submodule commit viagit submodule update --remoteorcd recipe && git pull origin main.