Skip to content

[trainer, rollout] feat: opt-in rollout-level dispatch for the V1 agent-loop trainer#6874

Open
huaiyizhao wants to merge 1 commit into
verl-project:mainfrom
huaiyizhao:feat/rollout-level-dispatch
Open

[trainer, rollout] feat: opt-in rollout-level dispatch for the V1 agent-loop trainer#6874
huaiyizhao wants to merge 1 commit into
verl-project:mainfrom
huaiyizhao:feat/rollout-level-dispatch

Conversation

@huaiyizhao

Copy link
Copy Markdown
Contributor

What does this PR do?

Adds an opt-in rollout-level dispatch model to the V1 agent-loop trainer, selected by trainer.v1.rollout_level_dispatch (default false).

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

By default the V1 agent-loop manager chunks a batch of prompts across workers and
each worker fans out a whole prompt's n GRPO sessions internally, pinning a prompt
(and a long-tail rollout) to a single worker process. This adds an opt-in
rollout-level dispatch model, selected by trainer.v1.rollout_level_dispatch
(default false, legacy behavior byte-for-byte):

- Manager dispatches one (prompt, session) rollout at a time, round-robin across
  the worker pool, so a prompt's sessions spread over workers and a long-tail
  rollout occupies only a single slot.
- Optional per-worker concurrency cap (trainer.v1.max_concurrent_rollouts_per_worker).
- ReplayBuffer gains a session-counting readiness model: each rollout writes a
  {uid}_sess{session_id} completion marker, and a prompt is sampleable once all n
  markers are present (success or failure). The legacy prompt-status readiness
  (pending/running/finished/failure) is kept and used when the flag is off.

Both readiness models are populated from one metadata sync (routed by whether the
prompt tag carries n vs status), so either path works without code changes.

Tests: tests/trainer/ppo/v1/test_replay_buffer_session_counting_on_cpu.py exercises
both models against a real TransferQueue on CPU (7/7).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an opt-in, rollout-level dispatch mechanism for the V1 AgentLoop and ReplayBuffer, allowing individual GRPO sessions of a prompt to be fanned out round-robin across the worker pool instead of being pinned to a single worker. This change is supported by a new session-counting readiness model in the ReplayBuffer and a comprehensive CPU-only test suite. The review feedback focuses on critical performance and robustness enhancements, specifically: cloning sliced PyTorch tensors and converting tensor elements to plain Python scalars to prevent massive serialization overhead when sending tasks via Ray, and using a more robust string splitting method (rsplit) to correctly extract UIDs that may contain underscores.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +69 to +85
def extract_sample(batch: TensorDict, i: int) -> dict:
"""Extract sample ``i`` from a batched TensorDict into a plain per-prompt dict.

Moved out of the worker so the manager can build per-prompt dicts before dispatching
individual rollout units. Mirrors the original per-key type handling.
"""
sample: dict[str, Any] = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
sample[k] = v[i]
elif isinstance(v, NonTensorStack):
sample[k] = v[i].data
elif isinstance(v, NonTensorData):
sample[k] = v.data
else:
logger.exception(f"Unsupported type {type(v)} for key {k}")
return sample

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Slicing a PyTorch tensor (e.g., v[i]) creates a view that shares the underlying storage with the original batched tensor. When Ray serializes this sliced tensor to send it to a worker, it serializes the entire underlying storage of the batch (e.g., all 1024 prompts) for each individual sample. This leads to massive memory overhead, high serialization latency, and potential out-of-memory (OOM) errors.

Calling .clone() on the sliced tensor allocates new, independent storage containing only the single element, preventing the serialization of the entire batch's storage.

Suggested change
def extract_sample(batch: TensorDict, i: int) -> dict:
"""Extract sample ``i`` from a batched TensorDict into a plain per-prompt dict.
Moved out of the worker so the manager can build per-prompt dicts before dispatching
individual rollout units. Mirrors the original per-key type handling.
"""
sample: dict[str, Any] = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
sample[k] = v[i]
elif isinstance(v, NonTensorStack):
sample[k] = v[i].data
elif isinstance(v, NonTensorData):
sample[k] = v.data
else:
logger.exception(f"Unsupported type {type(v)} for key {k}")
return sample
def extract_sample(batch: TensorDict, i: int) -> dict:
"""Extract sample ``i`` from a batched TensorDict into a plain per-prompt dict.
Moved out of the worker so the manager can build per-prompt dicts before dispatching
individual rollout units. Mirrors the original per-key type handling.
"""
sample: dict[str, Any] = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
sample[k] = v[i].clone()
elif isinstance(v, NonTensorStack):
sample[k] = v[i].data
elif isinstance(v, NonTensorData):
sample[k] = v.data
else:
logger.exception(f"Unsupported type {type(v)} for key {k}")
return sample

Comment on lines +52 to +66
def build_trajectory_info(step, index, validate) -> list[dict]:
"""Synchronous port of ``agent_loop.get_trajectory_info`` (pure CPU, no I/O).

Lives here so the manager can build trajectory info inline while fanning out rollout
units, without spinning up an event loop just to await the async original.
"""
trajectory_info = []
rollout_n = 0
for i in range(len(index)):
if i > 0 and index[i - 1] == index[i]:
rollout_n += 1
else:
rollout_n = 0
trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n, "validate": validate})
return trajectory_info

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

index[i] is a PyTorch tensor element that shares the underlying storage of the entire index tensor. Passing it directly inside the trajectory_info dictionary to Ray workers causes Ray to serialize the entire index tensor's storage for every single rollout task, leading to unnecessary serialization overhead.

Since index[i] is just a metadata scalar index, converting it to a plain Python int using .item() completely avoids PyTorch storage sharing and makes serialization extremely lightweight.

Suggested change
def build_trajectory_info(step, index, validate) -> list[dict]:
"""Synchronous port of ``agent_loop.get_trajectory_info`` (pure CPU, no I/O).
Lives here so the manager can build trajectory info inline while fanning out rollout
units, without spinning up an event loop just to await the async original.
"""
trajectory_info = []
rollout_n = 0
for i in range(len(index)):
if i > 0 and index[i - 1] == index[i]:
rollout_n += 1
else:
rollout_n = 0
trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n, "validate": validate})
return trajectory_info
def build_trajectory_info(step, index, validate) -> list[dict]:
"""Synchronous port of ``agent_loop.get_trajectory_info`` (pure CPU, no I/O).
Lives here so the manager can build trajectory info inline while fanning out rollout
units, without spinning up an event loop just to await the async original.
"""
trajectory_info = []
rollout_n = 0
step_val = step.item() if hasattr(step, "item") else step
for i in range(len(index)):
if i > 0 and index[i - 1] == index[i]:
rollout_n += 1
else:
rollout_n = 0
idx_val = index[i].item() if hasattr(index[i], "item") else index[i]
trajectory_info.append({"step": step_val, "sample_index": idx_val, "rollout_n": rollout_n, "validate": validate})
return trajectory_info

Comment on lines +176 to +180
elif tag.get("is_session", False):
# rollout-level per-session completion marker `{uid}_sess{session_id}`.
uid = key.split("_")[0]
self.session_done[partition_id][uid].add(tag["session_id"])
self.session_marker_keys[partition_id][uid].add(key)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using key.split("_")[0] to extract the uid from the session marker key (f"{uid}_sess{session_id}") is fragile and will fail if the uid itself contains underscores (e.g., my_custom_uid). In such cases, it will only extract the first part of the uid (e.g., my), causing a mismatch with the prompt keys in prompt_n and ultimately causing the replay buffer to hang indefinitely waiting for complete sessions.

Using key.rsplit("_sess", 1)[0] is robust and correctly extracts the full uid regardless of whether it contains underscores.

Suggested change
elif tag.get("is_session", False):
# rollout-level per-session completion marker `{uid}_sess{session_id}`.
uid = key.split("_")[0]
self.session_done[partition_id][uid].add(tag["session_id"])
self.session_marker_keys[partition_id][uid].add(key)
elif tag.get("is_session", False):
# rollout-level per-session completion marker `{uid}_sess{session_id}`.
uid = key.rsplit("_sess", 1)[0]
self.session_done[partition_id][uid].add(tag["session_id"])
self.session_marker_keys[partition_id][uid].add(key)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant