From 5ff95b801c294428fab5ae7280b5286e773ece28 Mon Sep 17 00:00:00 2001 From: jianjunzhong Date: Thu, 18 Dec 2025 21:29:58 +0800 Subject: [PATCH 1/7] elegantly prevent data re-put within DP Signed-off-by: jianjunzhong --- verl/utils/transferqueue_utils.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index 206d51899b4..05bf37ceb8b 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -144,7 +144,22 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", fun return updated_batch_meta -def tqbridge(put_data: bool = True): +def _compute_need_collect(dispatch_mode: dict, args: list) -> bool: + from verl.single_controller.base.worker import Worker + + if dispatch_mode is None: + return True + + assert "collect_fn" in dispatch_mode.keys() + collect_fn_name = dispatch_mode["collect_fn"].func.__name__ + if collect_fn_name != "collect_lazy_compute_data_proto" or len(args) < 1 or not isinstance(args[0], Worker): + return True + + collect_mesh_name = dispatch_mode["collect_fn"].args[0] + return args[0]._Worker__collect_dp_rank[collect_mesh_name] + + +def tqbridge(dispatch_mode=None, put_data: bool = True): """Creates a decorator for bridging BatchMeta and DataProto. This decorator automatically handles conversions between `BatchMeta` and @@ -181,7 +196,8 @@ def inner(*args, **kwargs): args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args] kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()} output = func(*args, **kwargs) - if put_data: + need_collect = _compute_need_collect(dispatch_mode, args) + if put_data and need_collect: updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__) return updated_batch_meta else: @@ -203,7 +219,8 @@ async def async_inner(*args, **kwargs): for k, v in kwargs.items() } output = await func(*args, **kwargs) - if put_data: + need_collect = _compute_need_collect(dispatch_mode, args) + if put_data and need_collect: updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__) return updated_batchmeta return output From 6da262cce61f79415195ba8d95d758c1dc81ce6d Mon Sep 17 00:00:00 2001 From: jianjunzhong Date: Fri, 19 Dec 2025 13:21:02 +0800 Subject: [PATCH 2/7] fix Signed-off-by: jianjunzhong --- verl/single_controller/base/decorator.py | 2 +- verl/utils/transferqueue_utils.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 1fa0496eaaa..f764718ce04 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -442,7 +442,7 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki _check_execute_mode(execute_mode=execute_mode) def decorator(func): - func = tqbridge()(func) + func = tqbridge(dispatch_mode=dispatch_mode)(func) @wraps(func) def inner(*args, **kwargs): diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index 05bf37ceb8b..632a9f32f35 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -150,7 +150,7 @@ def _compute_need_collect(dispatch_mode: dict, args: list) -> bool: if dispatch_mode is None: return True - assert "collect_fn" in dispatch_mode.keys() + assert "collect_fn" in dispatch_mode.keys(), "collect_fn should be in dispatch_mode." collect_fn_name = dispatch_mode["collect_fn"].func.__name__ if collect_fn_name != "collect_lazy_compute_data_proto" or len(args) < 1 or not isinstance(args[0], Worker): return True @@ -170,6 +170,8 @@ def tqbridge(dispatch_mode=None, put_data: bool = True): simply calls the original function as-is). Args: + dispatch_mode: For controlling data collection logic. If None, + _compute_need_collect will always return True. put_data: Whether put the DataProto into Storage after func return. If True, after function execution, the output result will be updated to `BatchMeta` and `BatchMeta` will be returned; @@ -200,6 +202,8 @@ def inner(*args, **kwargs): if put_data and need_collect: updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__) return updated_batch_meta + elif not need_collect: + return BatchMeta.empty() else: return output @@ -223,7 +227,10 @@ async def async_inner(*args, **kwargs): if put_data and need_collect: updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__) return updated_batchmeta - return output + elif not need_collect: + return BatchMeta.empty() + else: + return output @wraps(func) def dummy_inner(*args, **kwargs): From a513570c8ada34a03024925bdf7f47d506427ed1 Mon Sep 17 00:00:00 2001 From: jianjunzhong Date: Fri, 19 Dec 2025 14:33:41 +0800 Subject: [PATCH 3/7] remove ugly codes Signed-off-by: jianjunzhong --- verl/single_controller/base/decorator.py | 25 ++++++------------------ verl/utils/transferqueue_utils.py | 19 ------------------ 2 files changed, 6 insertions(+), 38 deletions(-) diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 4ce1811d1bb..3022e3cda6e 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -217,7 +217,7 @@ def collect_dp_compute_data_proto(worker_group, output): return _concat_data_proto_or_future(output) -def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, collect_mask, *args, **kwargs): +def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs): import os from verl.single_controller.base.worker_group import WorkerGroup @@ -248,10 +248,6 @@ def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, colle local_dp_rank = dp_rank_mapping[i] transformed_v.append(v[local_dp_rank]) all_kwargs[k] = transformed_v - - # add kwargs determing whether to collect from this rank - all_kwargs["collect_from_rank"] = [collect_mask[i] for i in range(worker_group.world_size)] - return all_args, all_kwargs @@ -269,13 +265,13 @@ def collect_nd_compute(collect_mask: list[bool], worker_group, output): return output_in_dp -def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_group, collect_mask, *args, **kwargs): +def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs): splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(dp_size, *args, **kwargs) - return dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, collect_mask, *splitted_args, **splitted_kwargs) + return dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, *splitted_args, **splitted_kwargs) -def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output): - output = collect_nd_compute(collect_mask, worker_group, output) +def collect_nd_compute_dataproto(worker_group, output): + output = collect_nd_compute(worker_group, output) import ray from verl.protocol import DataProto @@ -297,12 +293,6 @@ def dispatch_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs): worker_group._dispatch_info[mesh_name] = worker_group._query_dispatch_info(mesh_name) assert len(worker_group._dispatch_info[mesh_name]) == worker_group.world_size - # the dispatch info is stored in the worker group - assert mesh_name in worker_group._dispatch_info - if mesh_name not in worker_group._collect_info: - worker_group._collect_info[mesh_name] = worker_group._query_collect_info(mesh_name) - assert len(worker_group._collect_info[mesh_name]) == worker_group.world_size - dp_rank_mapping = worker_group._dispatch_info[mesh_name] # a boolean of whether the dp_rank is used for collect @@ -324,11 +314,8 @@ def collect_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs): if mesh_name not in worker_group._collect_info: worker_group._collect_info[mesh_name] = worker_group._query_collect_info(mesh_name) assert len(worker_group._collect_info[mesh_name]) == worker_group.world_size - - # a boolean of whether the dp_rank is used for collect - collect_mask = worker_group._collect_info[mesh_name] # perform dispatch - return collect_nd_compute_dataproto(collect_mask, worker_group, *args, **kwargs) + return collect_nd_compute_dataproto(worker_group, *args, **kwargs) def make_nd_compute_dataproto_dispatch_fn(mesh_name): diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index d5853a18b07..632a9f32f35 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -195,13 +195,6 @@ def inner(*args, **kwargs): f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, " f"global_idx={batchmeta.global_indexes}" ) - - if "collect_from_rank" in kwargs: - collect_from_rank = kwargs["collect_from_rank"] - kwargs.pop("collect_from_rank") - else: - collect_from_rank = None - args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args] kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()} output = func(*args, **kwargs) @@ -224,14 +217,6 @@ async def async_inner(*args, **kwargs): f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, " f"global_idx={batchmeta.global_indexes}" ) - - if "collect_from_rank" in kwargs: - collect_from_rank = kwargs["collect_from_rank"] - print(f"{func.__name__} with TQ put={kwargs['collect_from_rank']}") - kwargs.pop("collect_from_rank") - else: - collect_from_rank = None - args = [await _async_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args] kwargs = { k: await _async_batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v @@ -249,14 +234,10 @@ async def async_inner(*args, **kwargs): @wraps(func) def dummy_inner(*args, **kwargs): - if "collect_from_rank" in kwargs: - kwargs.pop("collect_from_rank") return func(*args, **kwargs) @wraps(func) async def dummy_async_inner(*args, **kwargs): - if "collect_from_rank" in kwargs: - kwargs.pop("collect_from_rank") return await func(*args, **kwargs) wrapper_inner = inner if is_transferqueue_enabled else dummy_inner From 9989e8371925797e2deb8c893ea7b5c779741473 Mon Sep 17 00:00:00 2001 From: jianjunzhong Date: Sat, 20 Dec 2025 10:13:31 +0800 Subject: [PATCH 4/7] remove codes Signed-off-by: jianjunzhong --- verl/single_controller/base/decorator.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 3022e3cda6e..f764718ce04 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -270,8 +270,8 @@ def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_gr return dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, *splitted_args, **splitted_kwargs) -def collect_nd_compute_dataproto(worker_group, output): - output = collect_nd_compute(worker_group, output) +def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output): + output = collect_nd_compute(collect_mask, worker_group, output) import ray from verl.protocol import DataProto @@ -294,13 +294,9 @@ def dispatch_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs): assert len(worker_group._dispatch_info[mesh_name]) == worker_group.world_size dp_rank_mapping = worker_group._dispatch_info[mesh_name] - - # a boolean of whether the dp_rank is used for collect - collect_mask = worker_group._collect_info[mesh_name] - # perform dispatch dp_size = max(dp_rank_mapping) + 1 - return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, collect_mask, *args, **kwargs) + return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, *args, **kwargs) def collect_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs): @@ -314,8 +310,11 @@ def collect_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs): if mesh_name not in worker_group._collect_info: worker_group._collect_info[mesh_name] = worker_group._query_collect_info(mesh_name) assert len(worker_group._collect_info[mesh_name]) == worker_group.world_size + + # a boolean of whether the dp_rank is used for collect + collect_mask = worker_group._collect_info[mesh_name] # perform dispatch - return collect_nd_compute_dataproto(worker_group, *args, **kwargs) + return collect_nd_compute_dataproto(collect_mask, worker_group, *args, **kwargs) def make_nd_compute_dataproto_dispatch_fn(mesh_name): From ce78ce4c1cc5f954173e9e4d0ef9083023699797 Mon Sep 17 00:00:00 2001 From: jianjunzhong Date: Sat, 20 Dec 2025 11:27:53 +0800 Subject: [PATCH 5/7] optimize logic Signed-off-by: jianjunzhong --- verl/utils/transferqueue_utils.py | 33 +++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index 632a9f32f35..4b5e56bbf99 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -159,6 +159,15 @@ def _compute_need_collect(dispatch_mode: dict, args: list) -> bool: return args[0]._Worker__collect_dp_rank[collect_mesh_name] +def _postprocess_common(output, put_data, need_collect): + if put_data and not need_collect: + return BatchMeta.empty() + elif not put_data and not need_collect and isinstance(output, DataProto): + return DataProto() + else: + return output + + def tqbridge(dispatch_mode=None, put_data: bool = True): """Creates a decorator for bridging BatchMeta and DataProto. @@ -199,13 +208,11 @@ def inner(*args, **kwargs): kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()} output = func(*args, **kwargs) need_collect = _compute_need_collect(dispatch_mode, args) + updated_batch_meta = None if put_data and need_collect: updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__) return updated_batch_meta - elif not need_collect: - return BatchMeta.empty() - else: - return output + return _postprocess_common(output, put_data, need_collect, updated_batch_meta) @wraps(func) async def async_inner(*args, **kwargs): @@ -224,21 +231,27 @@ async def async_inner(*args, **kwargs): } output = await func(*args, **kwargs) need_collect = _compute_need_collect(dispatch_mode, args) + updated_batchmeta = None if put_data and need_collect: updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__) return updated_batchmeta - elif not need_collect: - return BatchMeta.empty() - else: - return output + return _postprocess_common(output, put_data, need_collect) @wraps(func) def dummy_inner(*args, **kwargs): - return func(*args, **kwargs) + output = func(*args, **kwargs) + need_collect = _compute_need_collect(dispatch_mode, args) + if not need_collect: + return DataProto() + return output @wraps(func) async def dummy_async_inner(*args, **kwargs): - return await func(*args, **kwargs) + output = await func(*args, **kwargs) + need_collect = _compute_need_collect(dispatch_mode, args) + if not need_collect: + return DataProto() + return output wrapper_inner = inner if is_transferqueue_enabled else dummy_inner wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner From 83810331a19b4fd8f9c83f1bee4ec284ec092bb0 Mon Sep 17 00:00:00 2001 From: jianjunzhong Date: Sat, 20 Dec 2025 11:44:08 +0800 Subject: [PATCH 6/7] fix Signed-off-by: jianjunzhong --- verl/utils/transferqueue_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index 4b5e56bbf99..cc496c15b28 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -145,9 +145,10 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", fun def _compute_need_collect(dispatch_mode: dict, args: list) -> bool: + from verl.single_controller.base.decorator import Dispatch from verl.single_controller.base.worker import Worker - if dispatch_mode is None: + if dispatch_mode is None or isinstance(dispatch_mode, Dispatch): return True assert "collect_fn" in dispatch_mode.keys(), "collect_fn should be in dispatch_mode." From e58fdd6fff8e37e7eb972906a3fc3c0219b6e8a4 Mon Sep 17 00:00:00 2001 From: jianjunzhong Date: Sat, 20 Dec 2025 16:22:07 +0800 Subject: [PATCH 7/7] add docstring Signed-off-by: jianjunzhong --- verl/utils/transferqueue_utils.py | 66 ++++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index cc496c15b28..4a35cbfb56c 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -18,7 +18,10 @@ import os import threading from functools import wraps -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from verl.single_controller.base.decorator import Dispatch from tensordict import TensorDict @@ -144,7 +147,31 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", fun return updated_batch_meta -def _compute_need_collect(dispatch_mode: dict, args: list) -> bool: +def _compute_need_collect(dispatch_mode: dict | "Dispatch", args: list) -> bool: + """Compute whether data collection is needed for the current worker. + + This function determines whether the current worker should collect data based on + the dispatch mode configuration and worker parameters. It's used to optimize + distributed data collection by ensuring only the appropriate rank collects data. + + Args: + dispatch_mode: Controls data collection logic for the current worker. Can be None, + a Dispatch instance, or a dict with 'collect_fn' key. If None or Dispatch, + always returns True (current worker should collect). If dict, checks + collect_fn for lazy compute optimization. + args: List of arguments passed to the function. Should contain a Worker instance + as the first argument when using lazy compute mode. + + Returns: + bool: True if data collection is needed, False otherwise. + + Note: + Only checks worker attributes when dispatch_mode is a dict with 'collect_fn', + the collect_fn is 'collect_lazy_compute_data_proto', and args[0] is a Worker. + Otherwise, returns True. For the lazy compute case, checks the worker's + data parallel rank for the mesh specified in collect_fn.args[0] to determine + if this worker should collect data. + """ from verl.single_controller.base.decorator import Dispatch from verl.single_controller.base.worker import Worker @@ -161,6 +188,34 @@ def _compute_need_collect(dispatch_mode: dict, args: list) -> bool: def _postprocess_common(output, put_data, need_collect): + """Common post-processing logic for function outputs in TransferQueue bridge. + + This function handles the final return value based on whether data should be + put into storage (put_data) and whether collection is needed (need_collect). + It ensures proper return types based on the execution context. + + Args: + output: The original output from the decorated function. Can be any type, + typically DataProto when working with transfer queues. + put_data: bool, indicating whether the output should be stored in TransferQueue. + If True, output will be converted to BatchMeta; if False, returned as-is + or converted to DataProto. + need_collect: bool, indicating whether this process needs to collect data. + If False and put_data is True, returns empty BatchMeta to avoid + redundant storage. + + Returns: + - BatchMeta.empty(): When put_data=True but need_collect=False, indicating + no data should be stored but BatchMeta structure is expected. + - DataProto(): When put_data=False, need_collect=False, and output is DataProto, + returning an empty DataProto. + - output: In all other cases, returns the original output unchanged. + + Note: + This function is used in the tqbridge decorator to normalize return values + across different execution paths and avoid redundant data operations in + distributed scenarios. + """ if put_data and not need_collect: return BatchMeta.empty() elif not put_data and not need_collect and isinstance(output, DataProto): @@ -169,7 +224,7 @@ def _postprocess_common(output, put_data, need_collect): return output -def tqbridge(dispatch_mode=None, put_data: bool = True): +def tqbridge(dispatch_mode: dict | "Dispatch" = None, put_data: bool = True): """Creates a decorator for bridging BatchMeta and DataProto. This decorator automatically handles conversions between `BatchMeta` and @@ -180,8 +235,9 @@ def tqbridge(dispatch_mode=None, put_data: bool = True): simply calls the original function as-is). Args: - dispatch_mode: For controlling data collection logic. If None, - _compute_need_collect will always return True. + dispatch_mode: Controls data collection behavior for the current worker. Passed to + _compute_need_collect to determine if current worker should collect data. + If None, _compute_need_collect returns True (current worker collects). put_data: Whether put the DataProto into Storage after func return. If True, after function execution, the output result will be updated to `BatchMeta` and `BatchMeta` will be returned;