Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions verl/single_controller/base/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def collect_dp_compute_data_proto(worker_group, output):
return _concat_data_proto_or_future(output)


def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, collect_mask, *args, **kwargs):
def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs):
import os

from verl.single_controller.base.worker_group import WorkerGroup
Expand Down Expand Up @@ -248,10 +248,6 @@ def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, colle
local_dp_rank = dp_rank_mapping[i]
transformed_v.append(v[local_dp_rank])
all_kwargs[k] = transformed_v

# add kwargs determining whether to collect from this rank
all_kwargs["collect_from_rank"] = [collect_mask[i] for i in range(worker_group.world_size)]

return all_args, all_kwargs


Expand All @@ -269,9 +265,9 @@ def collect_nd_compute(collect_mask: list[bool], worker_group, output):
return output_in_dp


def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_group, collect_mask, *args, **kwargs):
def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs):
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(dp_size, *args, **kwargs)
return dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, collect_mask, *splitted_args, **splitted_kwargs)
return dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, *splitted_args, **splitted_kwargs)


def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output):
Expand All @@ -297,20 +293,10 @@ def dispatch_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
worker_group._dispatch_info[mesh_name] = worker_group._query_dispatch_info(mesh_name)
assert len(worker_group._dispatch_info[mesh_name]) == worker_group.world_size

# the dispatch info is stored in the worker group
assert mesh_name in worker_group._dispatch_info
if mesh_name not in worker_group._collect_info:
worker_group._collect_info[mesh_name] = worker_group._query_collect_info(mesh_name)
assert len(worker_group._collect_info[mesh_name]) == worker_group.world_size

dp_rank_mapping = worker_group._dispatch_info[mesh_name]

# a boolean of whether the dp_rank is used for collect
collect_mask = worker_group._collect_info[mesh_name]

# perform dispatch
dp_size = max(dp_rank_mapping) + 1
return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, collect_mask, *args, **kwargs)
return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, *args, **kwargs)


def collect_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
Expand Down Expand Up @@ -456,7 +442,7 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki
_check_execute_mode(execute_mode=execute_mode)

def decorator(func):
func = tqbridge()(func)
func = tqbridge(dispatch_mode=dispatch_mode)(func)

@wraps(func)
def inner(*args, **kwargs):
Expand Down
137 changes: 103 additions & 34 deletions verl/utils/transferqueue_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
import os
import threading
from functools import wraps
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
from verl.single_controller.base.decorator import Dispatch

from tensordict import TensorDict

Expand Down Expand Up @@ -144,7 +147,84 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", fun
return updated_batch_meta


def tqbridge(put_data: bool = True):
def _compute_need_collect(dispatch_mode: "dict | Dispatch | None", args: list) -> bool:
    """Decide whether the current worker should collect (store) its output.

    Args:
        dispatch_mode: ``None``, a ``Dispatch`` instance, or a dict that holds a
            ``"collect_fn"`` entry. For ``None`` or a ``Dispatch`` instance this
            function always returns True.
        args: Positional arguments passed to the decorated function. The
            per-worker lookup is only performed when ``args[0]`` is a ``Worker``.

    Returns:
        True unless all of the following hold, in which case the value stored in
        the worker's ``_Worker__collect_dp_rank`` mapping for the mesh name is
        returned instead:

        * ``dispatch_mode`` is a dict with a ``"collect_fn"`` entry,
        * the underlying collect function is named
          ``collect_lazy_compute_data_proto`` and has a bound mesh name as its
          first positional argument, and
        * ``args[0]`` is a ``Worker``.

    Raises:
        AssertionError: If ``dispatch_mode`` is a dict without ``"collect_fn"``.
    """
    # Imported lazily; the annotation above is a string so that neither the
    # ``|`` union nor the ``Dispatch`` forward reference is evaluated at
    # definition time.
    from verl.single_controller.base.decorator import Dispatch
    from verl.single_controller.base.worker import Worker

    if dispatch_mode is None or isinstance(dispatch_mode, Dispatch):
        return True

    assert "collect_fn" in dispatch_mode, "collect_fn should be in dispatch_mode."
    collect_fn = dispatch_mode["collect_fn"]

    # collect_fn is expected to be a functools.partial, but fall back to the
    # callable itself so a plain function does not raise AttributeError.
    base_fn = getattr(collect_fn, "func", collect_fn)
    bound_args = getattr(collect_fn, "args", ())

    # NOTE(review): matching on the function name is brittle -- renaming the
    # collect function or adding a sibling lazy collector silently disables
    # this check. An explicit marker attribute or a registry would be sturdier.
    is_lazy_collect = getattr(base_fn, "__name__", None) == "collect_lazy_compute_data_proto"
    if not is_lazy_collect or not bound_args:
        return True
    if len(args) < 1 or not isinstance(args[0], Worker):
        return True

    # The first bound argument of the partial is the mesh name.
    collect_mesh_name = bound_args[0]
    # Name-mangled access to the private ``__collect_dp_rank`` attribute of Worker.
    return args[0]._Worker__collect_dp_rank[collect_mesh_name]
Comment thread
0oshowero0 marked this conversation as resolved.


def _postprocess_common(output, put_data, need_collect):
"""Common post-processing logic for function outputs in TransferQueue bridge.

This function handles the final return value based on whether data should be
put into storage (put_data) and whether collection is needed (need_collect).
It ensures proper return types based on the execution context.

Args:
output: The original output from the decorated function. Can be any type,
typically DataProto when working with transfer queues.
put_data: bool, indicating whether the output should be stored in TransferQueue.
If True, output will be converted to BatchMeta; if False, returned as-is
or converted to DataProto.
need_collect: bool, indicating whether this process needs to collect data.
If False and put_data is True, returns empty BatchMeta to avoid
redundant storage.

Returns:
- BatchMeta.empty(): When put_data=True but need_collect=False, indicating
no data should be stored but BatchMeta structure is expected.
- DataProto(): When put_data=False, need_collect=False, and output is DataProto,
returning an empty DataProto.
- output: In all other cases, returns the original output unchanged.

Note:
This function is used in the tqbridge decorator to normalize return values
across different execution paths and avoid redundant data operations in
distributed scenarios.
"""
if put_data and not need_collect:
Comment thread
jianjunzhong marked this conversation as resolved.
return BatchMeta.empty()
elif not put_data and not need_collect and isinstance(output, DataProto):
return DataProto()
else:
return output


def tqbridge(dispatch_mode: dict | "Dispatch" = None, put_data: bool = True):
"""Creates a decorator for bridging BatchMeta and DataProto.

This decorator automatically handles conversions between `BatchMeta` and
Expand All @@ -155,6 +235,9 @@ def tqbridge(put_data: bool = True):
simply calls the original function as-is).

Args:
dispatch_mode: Controls data collection behavior for the current worker. Passed to
_compute_need_collect to determine if current worker should collect data.
If None, _compute_need_collect returns True (current worker collects).
put_data: Whether to put the DataProto into storage after func returns.
If True, after function execution, the output result will be
updated to `BatchMeta` and `BatchMeta` will be returned;
Expand All @@ -178,24 +261,15 @@ def inner(*args, **kwargs):
f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, "
f"global_idx={batchmeta.global_indexes}"
)

if "collect_from_rank" in kwargs:
collect_from_rank = kwargs["collect_from_rank"]
kwargs.pop("collect_from_rank")
else:
collect_from_rank = None

args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
output = func(*args, **kwargs)

if put_data and collect_from_rank:
need_collect = _compute_need_collect(dispatch_mode, args)
updated_batch_meta = None
if put_data and need_collect:
updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__)
return updated_batch_meta
elif collect_from_rank == False:
return BatchMeta()
else:
return output
return _postprocess_common(output, put_data, need_collect, updated_batch_meta)

@wraps(func)
async def async_inner(*args, **kwargs):
Expand All @@ -207,39 +281,34 @@ async def async_inner(*args, **kwargs):
f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, "
f"global_idx={batchmeta.global_indexes}"
)

if "collect_from_rank" in kwargs:
collect_from_rank = kwargs["collect_from_rank"]
print(f"{func.__name__} with TQ put={kwargs['collect_from_rank']}")
kwargs.pop("collect_from_rank")
else:
collect_from_rank = None

args = [await _async_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
kwargs = {
k: await _async_batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v
for k, v in kwargs.items()
}
output = await func(*args, **kwargs)

if put_data and collect_from_rank:
need_collect = _compute_need_collect(dispatch_mode, args)
updated_batchmeta = None
if put_data and need_collect:
updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__)
return updated_batchmeta
elif collect_from_rank == False:
return BatchMeta()
return output
return _postprocess_common(output, put_data, need_collect)

@wraps(func)
def dummy_inner(*args, **kwargs):
if "collect_from_rank" in kwargs:
kwargs.pop("collect_from_rank")
return func(*args, **kwargs)
output = func(*args, **kwargs)
need_collect = _compute_need_collect(dispatch_mode, args)
if not need_collect:
return DataProto()
return output

@wraps(func)
async def dummy_async_inner(*args, **kwargs):
if "collect_from_rank" in kwargs:
kwargs.pop("collect_from_rank")
return await func(*args, **kwargs)
output = await func(*args, **kwargs)
need_collect = _compute_need_collect(dispatch_mode, args)
if not need_collect:
return DataProto()
return output

wrapper_inner = inner if is_transferqueue_enabled else dummy_inner
wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner
Expand Down
Loading