From c8eba5ff69dcccf3671cf75cd6dbaaf83044b06e Mon Sep 17 00:00:00 2001 From: aoshen524 Date: Mon, 16 Feb 2026 03:30:58 +0000 Subject: [PATCH 1/8] feat(vision): add Vision DP for parallel ViT computation across Ulysses SP ranks Distribute whole images across Ulysses SP ranks for parallelized ViT computation, reducing ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x ViT memory reduction). Key changes: - Add roll/utils/context_parallel/vision_dp.py with image distribution utilities, GatherVisionEmbeddings autograd function, and model-agnostic VisionTransformer wrapper - Add apply_vision_dp_patch() in monkey_patch.py for Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3-VL-MoE VisionTransformer classes - Integrate into DeepSpeed strategy (both inference and training workers) - Add 17 unit tests covering all utility functions, edge cases, and integration workflows Ported from verl (https://github.com/verl-project/verl/pull/5230). Co-Authored-By: Claude Opus 4.6 --- .../strategy/deepspeed_strategy.py | 4 +- roll/utils/context_parallel/__init__.py | 17 +- roll/utils/context_parallel/monkey_patch.py | 98 +++++ roll/utils/context_parallel/vision_dp.py | 352 ++++++++++++++++++ tests/utils/test_vision_dp_on_cpu.py | 235 ++++++++++++ 5 files changed, 703 insertions(+), 3 deletions(-) create mode 100644 roll/utils/context_parallel/vision_dp.py create mode 100644 tests/utils/test_vision_dp_on_cpu.py diff --git a/roll/distributed/strategy/deepspeed_strategy.py b/roll/distributed/strategy/deepspeed_strategy.py index 58b7e1b4..f240ecfe 100644 --- a/roll/distributed/strategy/deepspeed_strategy.py +++ b/roll/distributed/strategy/deepspeed_strategy.py @@ -23,7 +23,7 @@ from roll.third_party.deepspeed.model_update import DeepSpeedWeightUpdater from roll.third_party.deepspeed.offload_states_patch import bind_deepspeed_offload_states_func from roll.utils.collective import collective -from roll.utils.context_parallel import get_ulysses_group, set_upg_manager +from roll.utils.context_parallel import 
apply_vision_dp_patch, get_ulysses_group, set_upg_manager from roll.utils.deepspeed_utils import get_optimizer_grouped_parameters from roll.utils.functionals import append_to_dict, entropy_from_logits, log_probs_from_logits from roll.utils.constants import IGNORE_INDEX @@ -69,6 +69,7 @@ def initialize(self, model_provider): if (cp_size := self.worker_config.model_args.ulysses_size) > 1: if current_platform.apply_ulysses_patch() is not None: set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size) + apply_vision_dp_patch() else: cp_size = 1 @@ -332,6 +333,7 @@ def initialize(self, model_provider): if (cp_size := self.worker_config.model_args.ulysses_size) > 1: current_platform.apply_ulysses_patch() set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size) + apply_vision_dp_patch() self.worker.rank_info.dp_rank = global_rank // cp_size self.worker.rank_info.dp_size = world_size // cp_size diff --git a/roll/utils/context_parallel/__init__.py b/roll/utils/context_parallel/__init__.py index 8112b8d2..cd3f0101 100644 --- a/roll/utils/context_parallel/__init__.py +++ b/roll/utils/context_parallel/__init__.py @@ -1,4 +1,17 @@ from roll.utils.context_parallel.globals import get_ulysses_group, set_upg_manager -from roll.utils.context_parallel.monkey_patch import apply_ulysses_patch, unapply_ulysses_patch +from roll.utils.context_parallel.monkey_patch import ( + apply_ulysses_patch, + apply_vision_dp_patch, + unapply_ulysses_patch, + unapply_vision_dp_patch, +) -__all__ = ["set_upg_manager", "get_ulysses_group", "apply_ulysses_patch", "unapply_ulysses_patch"] + +__all__ = [ + "set_upg_manager", + "get_ulysses_group", + "apply_ulysses_patch", + "apply_vision_dp_patch", + "unapply_ulysses_patch", + "unapply_vision_dp_patch", +] diff --git a/roll/utils/context_parallel/monkey_patch.py b/roll/utils/context_parallel/monkey_patch.py index a98ec66d..bf668139 100644 --- a/roll/utils/context_parallel/monkey_patch.py +++ 
b/roll/utils/context_parallel/monkey_patch.py @@ -13,6 +13,9 @@ else: old_update_causal_mask = None +# Store original vision forwards for unapply +_original_vision_forwards = {} + def apply_ulysses_patch(): from .ulysses_attention import _flash_attention_forward, _update_causal_mask @@ -35,6 +38,100 @@ def apply_ulysses_patch(): return patch_info +def apply_vision_dp_patch(): + """Patch VisionTransformer.forward for Vision Data Parallel. + + Distributes whole images across Ulysses SP ranks for parallelized ViT computation. + Each rank processes 1/sp_size of images, then all-gathers embeddings. + + This reduces ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x reduction). + """ + from .vision_dp import create_dp_vision_forward + + # Patch Qwen2-VL VisionTransformer + try: + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel + + original = Qwen2VisionTransformerPretrainedModel.forward + _original_vision_forwards["qwen2_vl"] = original + Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original) + logger.info("Monkey patch Qwen2VisionTransformerPretrainedModel.forward for Vision DP") + except ImportError as e: + logger.debug(f"Qwen2-VL not available for Vision DP patch: {e}") + + # Patch Qwen2.5-VL VisionTransformer + try: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + ) + + original = Qwen2_5_VisionTransformerPretrainedModel.forward + _original_vision_forwards["qwen2_5_vl"] = original + Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original) + logger.info("Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward for Vision DP") + except ImportError as e: + logger.debug(f"Qwen2.5-VL not available for Vision DP patch: {e}") + + # Patch Qwen3-VL VisionModel + try: + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel + + original = Qwen3VLVisionModel.forward + 
_original_vision_forwards["qwen3_vl"] = original + Qwen3VLVisionModel.forward = create_dp_vision_forward(original) + logger.info("Monkey patch Qwen3VLVisionModel.forward for Vision DP") + except ImportError as e: + logger.debug(f"Qwen3-VL not available for Vision DP patch: {e}") + + # Patch Qwen3-VL-MoE VisionModel + try: + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel + + original = Qwen3VLMoeVisionModel.forward + _original_vision_forwards["qwen3_vl_moe"] = original + Qwen3VLMoeVisionModel.forward = create_dp_vision_forward(original) + logger.info("Monkey patch Qwen3VLMoeVisionModel.forward for Vision DP") + except ImportError as e: + logger.debug(f"Qwen3-VL-MoE not available for Vision DP patch: {e}") + + +def unapply_vision_dp_patch(): + """Restore original VisionTransformer.forward methods.""" + if "qwen2_vl" in _original_vision_forwards: + try: + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel + + Qwen2VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_vl") + except ImportError: + pass + + if "qwen2_5_vl" in _original_vision_forwards: + try: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + ) + + Qwen2_5_VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_5_vl") + except ImportError: + pass + + if "qwen3_vl" in _original_vision_forwards: + try: + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel + + Qwen3VLVisionModel.forward = _original_vision_forwards.pop("qwen3_vl") + except ImportError: + pass + + if "qwen3_vl_moe" in _original_vision_forwards: + try: + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel + + Qwen3VLMoeVisionModel.forward = _original_vision_forwards.pop("qwen3_vl_moe") + except ImportError: + pass + + def unapply_ulysses_patch(): global old_flash_attention_forward, 
old_update_causal_mask ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = old_flash_attention_forward @@ -47,3 +144,4 @@ def unapply_ulysses_patch(): unapply_hf_flash_attention_ulysses_patch() except Exception: pass + unapply_vision_dp_patch() diff --git a/roll/utils/context_parallel/vision_dp.py b/roll/utils/context_parallel/vision_dp.py new file mode 100644 index 00000000..0f0c0116 --- /dev/null +++ b/roll/utils/context_parallel/vision_dp.py @@ -0,0 +1,352 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Alibaba Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Vision Data Parallel utilities for distributing ViT computation across Ulysses SP ranks. + +Ported from verl (https://github.com/verl-project/verl/pull/5230). + +Strategy: Distribute whole images across DP ranks, not patches within images. +This avoids breaking cu_seqlens semantics while parallelizing ViT computation. + +Key difference from text SP: +- Text SP: Split sequence within attention layers, all-to-all per layer +- Vision DP: Split images across ranks, all_gather once at the end +""" + +import torch +import torch.distributed as dist +from torch.autograd import Function + +from roll.utils.context_parallel.globals import get_ulysses_group, get_ulysses_size + + +def get_image_patch_counts(grid_thw: torch.Tensor) -> list[int]: + """Compute number of patches per image from grid_thw. + + Args: + grid_thw: Tensor of shape (num_images, 3) where each row is [t, h, w]. 
+ + Returns: + List of patch counts per image. + """ + if grid_thw.numel() == 0: + return [] + return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist() + + +def get_image_embedding_counts(grid_thw: torch.Tensor, spatial_merge_size: int = 1) -> list[int]: + """Compute number of embeddings per image after spatial merging. + + Args: + grid_thw: Tensor of shape (num_images, 3) where each row is [t, h, w]. + spatial_merge_size: Spatial merge factor (typically 2 for Qwen-VL). + + Returns: + List of embedding counts per image. + """ + if grid_thw.numel() == 0: + return [] + if spatial_merge_size == 1: + return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist() + t = grid_thw[:, 0] + h = grid_thw[:, 1] // spatial_merge_size + w = grid_thw[:, 2] // spatial_merge_size + return (t * h * w).tolist() + + +def assign_images_to_dp_ranks( + patch_counts: list[int], + dp_size: int, +) -> tuple[list[list[int]], list[int]]: + """Assign whole images to DP ranks using contiguous distribution. + + Rank 0 gets images [0, 1, ...], rank 1 gets next chunk, etc. + This ensures no reordering is needed after all-gather. + + Args: + patch_counts: Number of patches per image. + dp_size: Number of DP ranks. 
+ + Returns: + Tuple of (image_assignments, rank_loads) where: + - image_assignments[rank] = list of image indices assigned to that rank + - rank_loads[rank] = total patches assigned to that rank + """ + num_images = len(patch_counts) + if num_images == 0: + return [[] for _ in range(dp_size)], [0] * dp_size + + image_assignments: list[list[int]] = [[] for _ in range(dp_size)] + rank_loads = [0] * dp_size + + base_size = num_images // dp_size + remainder = num_images % dp_size + + start = 0 + for rank in range(dp_size): + chunk_size = base_size + (1 if rank < remainder else 0) + end = start + chunk_size + for img_idx in range(start, end): + image_assignments[rank].append(img_idx) + rank_loads[rank] += patch_counts[img_idx] + start = end + + return image_assignments, rank_loads + + +def prepare_local_vision_inputs( + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + image_assignments: list[list[int]], + dp_rank: int, +) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """Extract pixel values and grid_thw for this DP rank's assigned images. + + Args: + pixel_values: All pixel values concatenated, shape (total_patches, dim). + grid_thw: Grid dimensions per image, shape (num_images, 3). + image_assignments: Per-rank image index assignments. + dp_rank: This rank's index in the DP group. + + Returns: + Tuple of (local_pixel_values, local_grid_thw, local_indices). 
+ """ + local_indices = image_assignments[dp_rank] + + if len(local_indices) == 0: + return ( + torch.empty( + (0, pixel_values.shape[1]) if pixel_values.dim() > 1 else (0,), + dtype=pixel_values.dtype, + device=pixel_values.device, + ), + torch.empty((0, 3), dtype=grid_thw.dtype, device=grid_thw.device), + [], + ) + + patch_counts = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist() + cumsum = [0] + for c in patch_counts: + cumsum.append(cumsum[-1] + c) + + local_patches = [] + local_grids = [] + for idx in local_indices: + start, end = cumsum[idx], cumsum[idx + 1] + local_patches.append(pixel_values[start:end]) + local_grids.append(grid_thw[idx : idx + 1]) + + local_pixel_values = torch.cat(local_patches, dim=0) + local_grid_thw = torch.cat(local_grids, dim=0) + + expected_patches = sum(patch_counts[idx] for idx in local_indices) + assert local_pixel_values.shape[0] == expected_patches + + return local_pixel_values, local_grid_thw, local_indices + + +class GatherVisionEmbeddings(Function): + """All-gather vision embeddings with gradient support. + + Contiguous assignment means simple concat without reordering. + Backward: scales gradients by dp_size to compensate for partial processing. 
+ """ + + @staticmethod + def forward(ctx, local_embeddings, dp_group, grad_scaler=True): + ctx.grad_scaler = grad_scaler + dp_size = dist.get_world_size(dp_group) + dp_rank = dist.get_rank(dp_group) + ctx.dp_size = dp_size + + if dp_size == 1: + return local_embeddings + + local_count = torch.tensor( + [local_embeddings.shape[0]], dtype=torch.long, device=local_embeddings.device + ) + all_counts = [torch.zeros_like(local_count) for _ in range(dp_size)] + dist.all_gather(all_counts, local_count, group=dp_group) + all_counts = [c.item() for c in all_counts] + ctx.all_counts = all_counts + ctx.dp_rank = dp_rank + + max_count = max(all_counts) if all_counts else 0 + if max_count == 0: + return local_embeddings + + hidden_size = local_embeddings.shape[1] if local_embeddings.dim() > 1 else 1 + ctx.hidden_size = hidden_size + + if local_embeddings.shape[0] < max_count: + pad_size = max_count - local_embeddings.shape[0] + padding = torch.zeros( + (pad_size, hidden_size), + dtype=local_embeddings.dtype, + device=local_embeddings.device, + ) + local_padded = torch.cat([local_embeddings, padding], dim=0) + else: + local_padded = local_embeddings + + gathered = [torch.empty_like(local_padded) for _ in range(dp_size)] + dist.all_gather(gathered, local_padded, group=dp_group) + + result_chunks = [gathered[r][: all_counts[r]] for r in range(dp_size)] + result = torch.cat(result_chunks, dim=0) + return result + + @staticmethod + def backward(ctx, grad_output): + dp_size = ctx.dp_size + grad_scaler = ctx.grad_scaler + + if dp_size == 1: + return grad_output, None, None + + all_counts = ctx.all_counts + dp_rank = ctx.dp_rank + + if grad_scaler: + grad_output = grad_output * dp_size + + start = sum(all_counts[:dp_rank]) + end = start + all_counts[dp_rank] + local_grad = grad_output[start:end] + return local_grad, None, None + + +def gather_vision_embeddings(local_embeddings, dp_group=None, grad_scaler=True): + """All-gather vision embeddings from all DP ranks. 
+ + Args: + local_embeddings: This rank's vision embeddings. + dp_group: Process group for all-gather. Defaults to Ulysses group. + grad_scaler: Whether to scale gradients in backward pass. + + Returns: + All-gathered embeddings concatenated across ranks. + """ + dp_group = get_ulysses_group() if dp_group is None else dp_group + if dp_group is None or dist.get_world_size(dp_group) == 1: + return local_embeddings + return GatherVisionEmbeddings.apply(local_embeddings, dp_group, grad_scaler) + + +def create_dp_vision_forward(original_forward): + """Wrap VisionTransformer.forward for Vision DP. + + Model-agnostic wrapper for any VisionTransformer with + ``forward(self, hidden_states, grid_thw, **kwargs) -> Tensor`` signature. + + When Ulysses SP size > 1, distributes images across SP ranks and + all-gathers the embeddings after ViT computation. + + Args: + original_forward: The original VisionTransformer.forward method. + + Returns: + Wrapped forward method with Vision DP support. + """ + + def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): + dp_size = get_ulysses_size() + if dp_size is None or dp_size <= 1: + return original_forward(self, hidden_states, grid_thw, **kwargs) + + dp_group = get_ulysses_group() + dp_rank = dist.get_rank(dp_group) + + # Step 1: Get image assignment + patch_counts = get_image_patch_counts(grid_thw) + total_patches = sum(patch_counts) + assert hidden_states.shape[0] == total_patches + + spatial_merge_size = 1 + if hasattr(self, "merger") and hasattr(self.merger, "spatial_merge_size"): + spatial_merge_size = self.merger.spatial_merge_size + elif hasattr(self, "spatial_merge_size"): + spatial_merge_size = self.spatial_merge_size + + embedding_counts = get_image_embedding_counts(grid_thw, spatial_merge_size) + total_embeddings = sum(embedding_counts) + + image_assignments, rank_loads = assign_images_to_dp_ranks(patch_counts, dp_size) + + # Step 2: Extract local inputs + local_pixels, local_grid_thw, local_indices = 
prepare_local_vision_inputs( + hidden_states, grid_thw, image_assignments, dp_rank + ) + + # Step 3: Process local images + if local_pixels.shape[0] > 0: + local_embeddings = original_forward(self, local_pixels, local_grid_thw, **kwargs) + else: + # Determine hidden_size for empty tensor + if hasattr(self, "merger") and hasattr(self.merger, "ln_q"): + ln_q = self.merger.ln_q + if hasattr(ln_q, "normalized_shape"): + hidden_size = ln_q.normalized_shape[0] + elif hasattr(ln_q, "weight"): + hidden_size = ln_q.weight.shape[0] + else: + raise RuntimeError( + "Cannot determine hidden_size from merger.ln_q: " + "no 'normalized_shape' or 'weight' attribute found" + ) + elif hasattr(self, "out_hidden_size"): + hidden_size = self.out_hidden_size + elif hasattr(self, "config") and hasattr(self.config, "hidden_size"): + hidden_size = self.config.hidden_size + else: + raise RuntimeError( + "Cannot determine hidden_size for empty Vision DP output. " + "Expected one of: self.merger.ln_q, self.out_hidden_size, self.config.hidden_size" + ) + + local_embeddings = torch.empty( + (0, hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # Handle Qwen3-VL which returns (embeddings, deepstack_embeddings) + deepstack_outputs = None + if isinstance(local_embeddings, tuple): + local_embeddings, deepstack_outputs = local_embeddings[0], local_embeddings[1:] + + # Step 4: All-gather + all_embeddings = gather_vision_embeddings(local_embeddings, dp_group) + assert all_embeddings.shape[0] == total_embeddings + + if deepstack_outputs is not None: + # All-gather deepstack embeddings too + gathered_deepstack = [] + for ds_emb in deepstack_outputs: + if isinstance(ds_emb, list): + # List of tensors (one per deepstack layer) + gathered_list = [] + for single_emb in ds_emb: + gathered_list.append(gather_vision_embeddings(single_emb, dp_group)) + gathered_deepstack.append(gathered_list) + elif isinstance(ds_emb, torch.Tensor): + 
gathered_deepstack.append(gather_vision_embeddings(ds_emb, dp_group)) + else: + gathered_deepstack.append(ds_emb) + return (all_embeddings, *gathered_deepstack) + + return all_embeddings + + return dp_vision_forward diff --git a/tests/utils/test_vision_dp_on_cpu.py b/tests/utils/test_vision_dp_on_cpu.py new file mode 100644 index 00000000..86b6410a --- /dev/null +++ b/tests/utils/test_vision_dp_on_cpu.py @@ -0,0 +1,235 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Alibaba Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for Vision Data Parallel utilities. +Ported from verl (https://github.com/verl-project/verl/pull/5230). 
+""" + +import pytest +import torch + +from roll.utils.context_parallel.vision_dp import ( + assign_images_to_dp_ranks, + get_image_patch_counts, + prepare_local_vision_inputs, +) + + +class TestGetImagePatchCounts: + """Tests for get_image_patch_counts function.""" + + def test_basic_patch_counts(self): + grid_thw = torch.tensor([ + [2, 4, 4], # 2*4*4 = 32 + [1, 2, 2], # 1*2*2 = 4 + [1, 8, 8], # 1*8*8 = 64 + ]) + counts = get_image_patch_counts(grid_thw) + assert counts == [32, 4, 64] + + def test_single_image(self): + grid_thw = torch.tensor([[1, 4, 4]]) # 16 patches + counts = get_image_patch_counts(grid_thw) + assert counts == [16] + + def test_empty_input(self): + grid_thw = torch.empty((0, 3), dtype=torch.long) + counts = get_image_patch_counts(grid_thw) + assert counts == [] + + def test_video_frames(self): + grid_thw = torch.tensor([[4, 4, 4]]) # 4 frames, 4*4 patches each = 64 + counts = get_image_patch_counts(grid_thw) + assert counts == [64] + + +class TestAssignImagesToDpRanks: + """Tests for assign_images_to_dp_ranks function.""" + + def test_balanced_assignment(self): + patch_counts = [100, 100, 100, 100] + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) + assert len(assignments[0]) == 2 + assert len(assignments[1]) == 2 + assert loads[0] == 200 + assert loads[1] == 200 + + def test_imbalanced_images(self): + patch_counts = [500, 100, 100, 100] + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) + total_assigned = sum(len(a) for a in assignments) + assert total_assigned == 4 + assert 0 in assignments[0] or 0 in assignments[1] + + def test_fewer_images_than_ranks(self): + patch_counts = [100, 200] + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) + non_empty_ranks = sum(1 for a in assignments if len(a) > 0) + assert non_empty_ranks == 2 + all_assigned = set() + for a in assignments: + all_assigned.update(a) + assert all_assigned == {0, 1} + + def test_empty_input(self): + 
patch_counts = [] + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) + assert all(len(a) == 0 for a in assignments) + assert all(load == 0 for load in loads) + + def test_single_rank(self): + patch_counts = [100, 200, 300] + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=1) + assert assignments == [[0, 1, 2]] + assert loads == [600] + + def test_equal_images_equal_size(self): + patch_counts = [100, 100, 100, 100, 100, 100] # 6 images + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=3) + assert all(len(a) == 2 for a in assignments) + assert all(load == 200 for load in loads) + + def test_image_order_preserved(self): + patch_counts = [10, 20, 30, 40, 50] + assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size=2) + for rank_assignment in assignments: + assert rank_assignment == sorted(rank_assignment) + + +class TestPrepareLocalVisionInputs: + """Tests for prepare_local_vision_inputs function.""" + + def test_basic_extraction(self): + pixel_values = torch.randn(100, 768) + grid_thw = torch.tensor([ + [1, 6, 6], # 36 patches (indices 0-35) + [1, 8, 8], # 64 patches (indices 36-99) + ]) + image_assignments = [[0], [1]] + + local_pix, local_grid, local_indices = prepare_local_vision_inputs( + pixel_values, grid_thw, image_assignments, dp_rank=0 + ) + assert local_pix.shape[0] == 36 + assert local_grid.shape[0] == 1 + assert local_indices == [0] + assert torch.allclose(local_pix, pixel_values[:36]) + + local_pix, local_grid, local_indices = prepare_local_vision_inputs( + pixel_values, grid_thw, image_assignments, dp_rank=1 + ) + assert local_pix.shape[0] == 64 + assert local_grid.shape[0] == 1 + assert local_indices == [1] + assert torch.allclose(local_pix, pixel_values[36:100]) + + def test_multiple_images_per_rank(self): + pixel_values = torch.randn(200, 768) + grid_thw = torch.tensor([ + [1, 5, 10], # 50 patches + [1, 5, 10], # 50 patches + [1, 5, 10], # 50 patches + [1, 5, 10], # 50 patches 
+ ]) + image_assignments = [[0, 2], [1, 3]] + + local_pix, local_grid, local_indices = prepare_local_vision_inputs( + pixel_values, grid_thw, image_assignments, dp_rank=0 + ) + assert local_pix.shape[0] == 100 + assert local_grid.shape[0] == 2 + assert local_indices == [0, 2] + expected = torch.cat([pixel_values[0:50], pixel_values[100:150]], dim=0) + assert torch.allclose(local_pix, expected) + + def test_empty_rank(self): + pixel_values = torch.randn(100, 768) + grid_thw = torch.tensor([[1, 10, 10]]) + image_assignments = [[0], []] + + local_pix, local_grid, local_indices = prepare_local_vision_inputs( + pixel_values, grid_thw, image_assignments, dp_rank=1 + ) + assert local_pix.shape[0] == 0 + assert local_grid.shape[0] == 0 + assert local_indices == [] + + def test_grid_thw_preserved(self): + pixel_values = torch.randn(150, 768) + grid_thw = torch.tensor([ + [1, 5, 5], # 25 patches + [2, 5, 5], # 50 patches + [3, 5, 5], # 75 patches + ]) + image_assignments = [[0, 2], [1]] + + _, local_grid, _ = prepare_local_vision_inputs( + pixel_values, grid_thw, image_assignments, dp_rank=0 + ) + assert local_grid.shape == (2, 3) + assert torch.equal(local_grid[0], grid_thw[0]) + assert torch.equal(local_grid[1], grid_thw[2]) + + +class TestIntegration: + """Integration tests combining multiple functions.""" + + def test_full_workflow(self): + grid_thw = torch.tensor([ + [1, 4, 4], # 16 patches + [1, 8, 8], # 64 patches + [1, 4, 4], # 16 patches + [1, 6, 6], # 36 patches + [1, 4, 4], # 16 patches + ]) + total_patches = 16 + 64 + 16 + 36 + 16 # 148 + pixel_values = torch.randn(total_patches, 768) + + patch_counts = get_image_patch_counts(grid_thw) + assert patch_counts == [16, 64, 16, 36, 16] + + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) + all_assigned = [] + for a in assignments: + all_assigned.extend(a) + assert sorted(all_assigned) == [0, 1, 2, 3, 4] + + total_local_patches = 0 + for rank in range(2): + local_pix, local_grid, local_indices = 
prepare_local_vision_inputs( + pixel_values, grid_thw, assignments, dp_rank=rank + ) + expected_patches = sum(patch_counts[i] for i in local_indices) + assert local_pix.shape[0] == expected_patches + assert local_grid.shape[0] == len(local_indices) + total_local_patches += local_pix.shape[0] + + assert total_local_patches == total_patches + + def test_same_size_images(self): + num_images = 50 + patch_per_image = 64 + grid_thw = torch.tensor([[1, 8, 8]] * num_images) + total_patches = num_images * patch_per_image + _ = torch.randn(total_patches, 768) + + patch_counts = get_image_patch_counts(grid_thw) + assert all(c == 64 for c in patch_counts) + + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) + for rank in range(4): + assert 12 <= len(assignments[rank]) <= 13 + for load in loads: + assert load in [768, 832] From 1b13eafe2c7dc7297b3c81260f9c1220eef1bfa8 Mon Sep 17 00:00:00 2001 From: aoshen524 Date: Tue, 24 Feb 2026 13:54:20 +0000 Subject: [PATCH 2/8] fix(vision_dp): fix gradient routing, load balancing, and efficiency issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address reviewer comments (same fixes as verl PR #5230 and AReaL PR #929): 1. **Gradient routing fix (critical)**: Replace `grad_scaler * dp_size` with `all_reduce(SUM)` in GatherVisionEmbeddings.backward() to aggregate partial sequence gradients before slicing. Fixes silent gradient loss when vision tokens span multiple sequence shard boundaries. 2. **Load-balanced assignment**: Replace count-based chunking with greedy contiguous bin-packing that balances total patch load across ranks. 3. **Remove unnecessary all_gather**: Pass pre-computed `all_counts` from caller instead of doing all_gather in forward. 4. **Idempotency guard**: Extract `_patch_vision_class()` helper with `_vision_dp_patched` attribute check. Add `_unapply_vision_class()` to properly clear the flag on unapply. 5. 
**Remove Qwen3-VL-MoE dead code**: Remove unreachable qwen3_vl_moe blocks from apply/unapply (not yet in transformers vl_model_mappings). 6. **GPU→CPU sync optimization**: Move `grid_thw.cpu()` to dp_vision_forward entry point to avoid repeated `.tolist()` GPU→CPU syncs. 7. **Tensor slicing**: Replace Python loop + list append in prepare_local_vision_inputs with contiguous tensor slice using cumsum. 8. **Test improvements**: Rename tests, add load balancing test, add gather_none_group test, use parametrize. Co-Authored-By: Claude Opus 4.6 --- roll/utils/context_parallel/monkey_patch.py | 92 +++---- roll/utils/context_parallel/vision_dp.py | 143 +++++++---- tests/utils/test_vision_dp_on_cpu.py | 267 +++++++++----------- 3 files changed, 252 insertions(+), 250 deletions(-) diff --git a/roll/utils/context_parallel/monkey_patch.py b/roll/utils/context_parallel/monkey_patch.py index bf668139..499005d5 100644 --- a/roll/utils/context_parallel/monkey_patch.py +++ b/roll/utils/context_parallel/monkey_patch.py @@ -38,6 +38,19 @@ def apply_ulysses_patch(): return patch_info +def _patch_vision_class(cls, key, class_name): + """Patch a single VisionTransformer class with Vision DP, with idempotency guard.""" + from .vision_dp import create_dp_vision_forward + + if getattr(cls, "_vision_dp_patched", False): + return + original = cls.forward + _original_vision_forwards[key] = original + cls.forward = create_dp_vision_forward(original) + cls._vision_dp_patched = True + logger.info(f"Monkey patch {class_name}.forward for Vision DP") + + def apply_vision_dp_patch(): """Patch VisionTransformer.forward for Vision Data Parallel. @@ -45,17 +58,13 @@ def apply_vision_dp_patch(): Each rank processes 1/sp_size of images, then all-gathers embeddings. This reduces ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x reduction). + Safe to call multiple times -- each class is only patched once. 
""" - from .vision_dp import create_dp_vision_forward - # Patch Qwen2-VL VisionTransformer try: from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel - original = Qwen2VisionTransformerPretrainedModel.forward - _original_vision_forwards["qwen2_vl"] = original - Qwen2VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original) - logger.info("Monkey patch Qwen2VisionTransformerPretrainedModel.forward for Vision DP") + _patch_vision_class(Qwen2VisionTransformerPretrainedModel, "qwen2_vl", "Qwen2VisionTransformerPretrainedModel") except ImportError as e: logger.debug(f"Qwen2-VL not available for Vision DP patch: {e}") @@ -65,10 +74,9 @@ def apply_vision_dp_patch(): Qwen2_5_VisionTransformerPretrainedModel, ) - original = Qwen2_5_VisionTransformerPretrainedModel.forward - _original_vision_forwards["qwen2_5_vl"] = original - Qwen2_5_VisionTransformerPretrainedModel.forward = create_dp_vision_forward(original) - logger.info("Monkey patch Qwen2_5_VisionTransformerPretrainedModel.forward for Vision DP") + _patch_vision_class( + Qwen2_5_VisionTransformerPretrainedModel, "qwen2_5_vl", "Qwen2_5_VisionTransformerPretrainedModel" + ) except ImportError as e: logger.debug(f"Qwen2.5-VL not available for Vision DP patch: {e}") @@ -76,60 +84,42 @@ def apply_vision_dp_patch(): try: from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel - original = Qwen3VLVisionModel.forward - _original_vision_forwards["qwen3_vl"] = original - Qwen3VLVisionModel.forward = create_dp_vision_forward(original) - logger.info("Monkey patch Qwen3VLVisionModel.forward for Vision DP") + _patch_vision_class(Qwen3VLVisionModel, "qwen3_vl", "Qwen3VLVisionModel") except ImportError as e: logger.debug(f"Qwen3-VL not available for Vision DP patch: {e}") - # Patch Qwen3-VL-MoE VisionModel - try: - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeVisionModel - original = Qwen3VLMoeVisionModel.forward - 
_original_vision_forwards["qwen3_vl_moe"] = original - Qwen3VLMoeVisionModel.forward = create_dp_vision_forward(original) - logger.info("Monkey patch Qwen3VLMoeVisionModel.forward for Vision DP") - except ImportError as e: - logger.debug(f"Qwen3-VL-MoE not available for Vision DP patch: {e}") +def _unapply_vision_class(cls, key): + """Restore a single VisionTransformer class, clearing the idempotency flag.""" + if key in _original_vision_forwards: + cls.forward = _original_vision_forwards.pop(key) + cls._vision_dp_patched = False def unapply_vision_dp_patch(): """Restore original VisionTransformer.forward methods.""" - if "qwen2_vl" in _original_vision_forwards: - try: - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel - - Qwen2VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_vl") - except ImportError: - pass - - if "qwen2_5_vl" in _original_vision_forwards: - try: - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VisionTransformerPretrainedModel, - ) + try: + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel - Qwen2_5_VisionTransformerPretrainedModel.forward = _original_vision_forwards.pop("qwen2_5_vl") - except ImportError: - pass + _unapply_vision_class(Qwen2VisionTransformerPretrainedModel, "qwen2_vl") + except ImportError: + pass - if "qwen3_vl" in _original_vision_forwards: - try: - from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel + try: + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + ) - Qwen3VLVisionModel.forward = _original_vision_forwards.pop("qwen3_vl") - except ImportError: - pass + _unapply_vision_class(Qwen2_5_VisionTransformerPretrainedModel, "qwen2_5_vl") + except ImportError: + pass - if "qwen3_vl_moe" in _original_vision_forwards: - try: - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import 
Qwen3VLMoeVisionModel + try: + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel - Qwen3VLMoeVisionModel.forward = _original_vision_forwards.pop("qwen3_vl_moe") - except ImportError: - pass + _unapply_vision_class(Qwen3VLVisionModel, "qwen3_vl") + except ImportError: + pass def unapply_ulysses_patch(): diff --git a/roll/utils/context_parallel/vision_dp.py b/roll/utils/context_parallel/vision_dp.py index 0f0c0116..7ebff21f 100644 --- a/roll/utils/context_parallel/vision_dp.py +++ b/roll/utils/context_parallel/vision_dp.py @@ -20,9 +20,17 @@ Strategy: Distribute whole images across DP ranks, not patches within images. This avoids breaking cu_seqlens semantics while parallelizing ViT computation. -Key difference from text SP: -- Text SP: Split sequence within attention layers, all-to-all per layer -- Vision DP: Split images across ranks, all_gather once at the end +Key design choices: +- Image-level distribution (not patch-level): avoids breaking ViT's internal + cu_seqlens tracking +- Contiguous assignment: rank 0 gets images [0,1,...], rank 1 gets next chunk, etc. + No reordering needed after all-gather. +- Gradient sync in backward: all_reduce(SUM) across SP ranks before slicing to + recover the complete gradient for each image. Without this, gradients from + vision tokens in other ranks' sequence shards would be lost. +- No additional gradient scaling needed: the all_reduce aggregates partial + sequence gradients, making each rank's ViT backward equivalent to the non-DP + baseline. FSDP's dp_sp reduce-scatter then handles DP averaging as usual. """ import torch @@ -70,10 +78,12 @@ def assign_images_to_dp_ranks( patch_counts: list[int], dp_size: int, ) -> tuple[list[list[int]], list[int]]: - """Assign whole images to DP ranks using contiguous distribution. + """Assign whole images to DP ranks using load-balanced contiguous distribution. - Rank 0 gets images [0, 1, ...], rank 1 gets next chunk, etc. 
- This ensures no reordering is needed after all-gather. + The algorithm uses greedy contiguous bin-packing: + - Images are assigned in order (contiguous) to preserve ordering after gather + - Split points are chosen to balance total patch load across ranks + - Each rank gets at least one image when num_images >= dp_size Args: patch_counts: Number of patches per image. @@ -91,17 +101,34 @@ def assign_images_to_dp_ranks( image_assignments: list[list[int]] = [[] for _ in range(dp_size)] rank_loads = [0] * dp_size - base_size = num_images // dp_size - remainder = num_images % dp_size - - start = 0 + remaining_patches = sum(patch_counts) + img_idx = 0 for rank in range(dp_size): - chunk_size = base_size + (1 if rank < remainder else 0) - end = start + chunk_size - for img_idx in range(start, end): + remaining_ranks = dp_size - rank + remaining_images = num_images - img_idx + + if remaining_images <= 0: + break + + # Dynamic target: distribute remaining patches evenly among remaining ranks + target = remaining_patches / remaining_ranks + + # Must leave at least 1 image for each remaining rank + max_images = remaining_images - (remaining_ranks - 1) + + # Greedily add images until we reach the target load or hit the max + count = 0 + while img_idx < num_images and count < max_images: image_assignments[rank].append(img_idx) rank_loads[rank] += patch_counts[img_idx] - start = end + img_idx += 1 + count += 1 + + # Stop early once we've reached the target (always take at least 1) + if rank_loads[rank] >= target: + break + + remaining_patches -= rank_loads[rank] return image_assignments, rank_loads @@ -136,23 +163,32 @@ def prepare_local_vision_inputs( [], ) - patch_counts = (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist() - cumsum = [0] - for c in patch_counts: - cumsum.append(cumsum[-1] + c) + # local_indices are contiguous (e.g. 
[2, 3, 4]), so use tensor slicing + first_img_idx = local_indices[0] + last_img_idx = local_indices[-1] + + # Compute patch offsets using cumsum + patch_counts = get_image_patch_counts(grid_thw) + patch_counts_tensor = torch.tensor(patch_counts, device=grid_thw.device, dtype=torch.long) + offsets = torch.cat( + ( + torch.tensor([0], device=grid_thw.device, dtype=torch.long), + torch.cumsum(patch_counts_tensor, dim=0), + ) + ) - local_patches = [] - local_grids = [] - for idx in local_indices: - start, end = cumsum[idx], cumsum[idx + 1] - local_patches.append(pixel_values[start:end]) - local_grids.append(grid_thw[idx : idx + 1]) + start_patch = offsets[first_img_idx].item() + end_patch = offsets[last_img_idx + 1].item() - local_pixel_values = torch.cat(local_patches, dim=0) - local_grid_thw = torch.cat(local_grids, dim=0) + local_pixel_values = pixel_values[start_patch:end_patch] + local_grid_thw = grid_thw[first_img_idx : last_img_idx + 1] - expected_patches = sum(patch_counts[idx] for idx in local_indices) - assert local_pixel_values.shape[0] == expected_patches + expected_patches = end_patch - start_patch + assert local_pixel_values.shape[0] == expected_patches, ( + f"[Vision DP] Local patch count mismatch: " + f"extracted={local_pixel_values.shape[0]}, expected={expected_patches}, " + f"local_indices={local_indices}" + ) return local_pixel_values, local_grid_thw, local_indices @@ -161,28 +197,22 @@ class GatherVisionEmbeddings(Function): """All-gather vision embeddings with gradient support. Contiguous assignment means simple concat without reordering. - Backward: scales gradients by dp_size to compensate for partial processing. + Backward: all_reduce(SUM) to aggregate gradients from all sequence shards, + then slice to extract this rank's image gradients. 
""" @staticmethod - def forward(ctx, local_embeddings, dp_group, grad_scaler=True): - ctx.grad_scaler = grad_scaler + def forward(ctx, local_embeddings, dp_group, all_counts: list[int]): dp_size = dist.get_world_size(dp_group) dp_rank = dist.get_rank(dp_group) ctx.dp_size = dp_size + ctx.dp_group = dp_group + ctx.all_counts = all_counts + ctx.dp_rank = dp_rank if dp_size == 1: return local_embeddings - local_count = torch.tensor( - [local_embeddings.shape[0]], dtype=torch.long, device=local_embeddings.device - ) - all_counts = [torch.zeros_like(local_count) for _ in range(dp_size)] - dist.all_gather(all_counts, local_count, group=dp_group) - all_counts = [c.item() for c in all_counts] - ctx.all_counts = all_counts - ctx.dp_rank = dp_rank - max_count = max(all_counts) if all_counts else 0 if max_count == 0: return local_embeddings @@ -211,16 +241,19 @@ def forward(ctx, local_embeddings, dp_group, grad_scaler=True): @staticmethod def backward(ctx, grad_output): dp_size = ctx.dp_size - grad_scaler = ctx.grad_scaler if dp_size == 1: return grad_output, None, None all_counts = ctx.all_counts dp_rank = ctx.dp_rank + dp_group = ctx.dp_group - if grad_scaler: - grad_output = grad_output * dp_size + # Aggregate gradient contributions from all SP ranks. + # Each rank only has non-zero grad for vision tokens in its own + # sequence shard. Summing across ranks recovers the complete + # gradient for every image before we slice by image assignment. + dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=dp_group) start = sum(all_counts[:dp_rank]) end = start + all_counts[dp_rank] @@ -228,13 +261,13 @@ def backward(ctx, grad_output): return local_grad, None, None -def gather_vision_embeddings(local_embeddings, dp_group=None, grad_scaler=True): +def gather_vision_embeddings(local_embeddings, dp_group, all_counts: list[int]): """All-gather vision embeddings from all DP ranks. Args: local_embeddings: This rank's vision embeddings. dp_group: Process group for all-gather. 
Defaults to Ulysses group. - grad_scaler: Whether to scale gradients in backward pass. + all_counts: Pre-computed embedding counts per rank (avoids an all_gather). Returns: All-gathered embeddings concatenated across ranks. @@ -242,7 +275,7 @@ def gather_vision_embeddings(local_embeddings, dp_group=None, grad_scaler=True): dp_group = get_ulysses_group() if dp_group is None else dp_group if dp_group is None or dist.get_world_size(dp_group) == 1: return local_embeddings - return GatherVisionEmbeddings.apply(local_embeddings, dp_group, grad_scaler) + return GatherVisionEmbeddings.apply(local_embeddings, dp_group, all_counts) def create_dp_vision_forward(original_forward): @@ -269,8 +302,12 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): dp_group = get_ulysses_group() dp_rank = dist.get_rank(dp_group) + # Move grid_thw to CPU once to avoid repeated GPU->CPU syncs in + # metadata helpers (grid_thw is a tiny [num_images, 3] tensor). + grid_thw_cpu = grid_thw.cpu() + # Step 1: Get image assignment - patch_counts = get_image_patch_counts(grid_thw) + patch_counts = get_image_patch_counts(grid_thw_cpu) total_patches = sum(patch_counts) assert hidden_states.shape[0] == total_patches @@ -280,10 +317,10 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): elif hasattr(self, "spatial_merge_size"): spatial_merge_size = self.spatial_merge_size - embedding_counts = get_image_embedding_counts(grid_thw, spatial_merge_size) + embedding_counts = get_image_embedding_counts(grid_thw_cpu, spatial_merge_size) total_embeddings = sum(embedding_counts) - image_assignments, rank_loads = assign_images_to_dp_ranks(patch_counts, dp_size) + image_assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size) # Step 2: Extract local inputs local_pixels, local_grid_thw, local_indices = prepare_local_vision_inputs( @@ -328,7 +365,9 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): local_embeddings, deepstack_outputs = local_embeddings[0], 
local_embeddings[1:] # Step 4: All-gather - all_embeddings = gather_vision_embeddings(local_embeddings, dp_group) + # Compute per-rank embedding counts locally (grid_thw is replicated on all ranks) + all_counts = [sum(embedding_counts[i] for i in image_assignments[r]) for r in range(dp_size)] + all_embeddings = gather_vision_embeddings(local_embeddings, dp_group, all_counts) assert all_embeddings.shape[0] == total_embeddings if deepstack_outputs is not None: @@ -339,10 +378,10 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): # List of tensors (one per deepstack layer) gathered_list = [] for single_emb in ds_emb: - gathered_list.append(gather_vision_embeddings(single_emb, dp_group)) + gathered_list.append(gather_vision_embeddings(single_emb, dp_group, all_counts)) gathered_deepstack.append(gathered_list) elif isinstance(ds_emb, torch.Tensor): - gathered_deepstack.append(gather_vision_embeddings(ds_emb, dp_group)) + gathered_deepstack.append(gather_vision_embeddings(ds_emb, dp_group, all_counts)) else: gathered_deepstack.append(ds_emb) return (all_embeddings, *gathered_deepstack) diff --git a/tests/utils/test_vision_dp_on_cpu.py b/tests/utils/test_vision_dp_on_cpu.py index 86b6410a..5efd6e42 100644 --- a/tests/utils/test_vision_dp_on_cpu.py +++ b/tests/utils/test_vision_dp_on_cpu.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Unit tests for Vision Data Parallel utilities. +Unit tests for Vision Data Parallel utilities (CPU-only, no distributed). + Ported from verl (https://github.com/verl-project/verl/pull/5230). 
""" @@ -22,178 +23,157 @@ from roll.utils.context_parallel.vision_dp import ( assign_images_to_dp_ranks, + gather_vision_embeddings, + get_image_embedding_counts, get_image_patch_counts, prepare_local_vision_inputs, ) class TestGetImagePatchCounts: - """Tests for get_image_patch_counts function.""" - - def test_basic_patch_counts(self): - grid_thw = torch.tensor([ - [2, 4, 4], # 2*4*4 = 32 - [1, 2, 2], # 1*2*2 = 4 - [1, 8, 8], # 1*8*8 = 64 - ]) - counts = get_image_patch_counts(grid_thw) - assert counts == [32, 4, 64] - - def test_single_image(self): - grid_thw = torch.tensor([[1, 4, 4]]) # 16 patches - counts = get_image_patch_counts(grid_thw) - assert counts == [16] - - def test_empty_input(self): - grid_thw = torch.empty((0, 3), dtype=torch.long) - counts = get_image_patch_counts(grid_thw) + @pytest.mark.parametrize( + "grid_thw,expected", + [ + ([[2, 4, 4], [1, 2, 2], [1, 8, 8]], [32, 4, 64]), + ([[1, 4, 4]], [16]), + ([[4, 4, 4]], [64]), + ], + ids=["multi-image", "single-image", "video-frames"], + ) + def test_patch_counts_various_grids_correct_products(self, grid_thw, expected): + counts = get_image_patch_counts(torch.tensor(grid_thw)) + assert counts == expected + + def test_patch_counts_empty_input_returns_empty_list(self): + counts = get_image_patch_counts(torch.empty((0, 3), dtype=torch.long)) assert counts == [] - def test_video_frames(self): - grid_thw = torch.tensor([[4, 4, 4]]) # 4 frames, 4*4 patches each = 64 - counts = get_image_patch_counts(grid_thw) - assert counts == [64] - -class TestAssignImagesToDpRanks: - """Tests for assign_images_to_dp_ranks function.""" +class TestGetImageEmbeddingCounts: + @pytest.mark.parametrize( + "grid_thw,merge_size,expected", + [ + ([[1, 8, 8]], 1, [64]), + ([[1, 8, 8]], 2, [16]), + ([[1, 6, 6], [1, 4, 4]], 2, [9, 4]), + ], + ids=["no-merge", "merge-2", "multi-image-merge"], + ) + def test_embedding_counts_with_merge_size_correct(self, grid_thw, merge_size, expected): + counts = 
get_image_embedding_counts(torch.tensor(grid_thw), merge_size) + assert counts == expected - def test_balanced_assignment(self): - patch_counts = [100, 100, 100, 100] - assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) - assert len(assignments[0]) == 2 - assert len(assignments[1]) == 2 - assert loads[0] == 200 - assert loads[1] == 200 - def test_imbalanced_images(self): - patch_counts = [500, 100, 100, 100] - assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) - total_assigned = sum(len(a) for a in assignments) - assert total_assigned == 4 - assert 0 in assignments[0] or 0 in assignments[1] +class TestAssignImagesToDpRanks: + @pytest.mark.parametrize( + "patch_counts,dp_size,expected_all_assigned", + [ + ([100, 100, 100, 100], 2, True), + ([100, 200, 300], 1, True), + ([100, 100, 100, 100, 100, 100], 3, True), + ], + ids=["balanced-2ranks", "single-rank", "balanced-3ranks"], + ) + def test_assign_all_images_distributed(self, patch_counts, dp_size, expected_all_assigned): + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size) + all_assigned = [] + for a in assignments: + all_assigned.extend(a) + assert sorted(all_assigned) == list(range(len(patch_counts))) + assert sum(loads) == sum(patch_counts) - def test_fewer_images_than_ranks(self): - patch_counts = [100, 200] - assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) - non_empty_ranks = sum(1 for a in assignments if len(a) > 0) - assert non_empty_ranks == 2 + def test_assign_fewer_images_than_ranks_all_assigned(self): + assignments, loads = assign_images_to_dp_ranks([100, 200], dp_size=4) + non_empty = sum(1 for a in assignments if len(a) > 0) + assert non_empty == 2 all_assigned = set() for a in assignments: all_assigned.update(a) assert all_assigned == {0, 1} - def test_empty_input(self): - patch_counts = [] - assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) + def test_assign_empty_input_returns_empty(self): + 
assignments, loads = assign_images_to_dp_ranks([], dp_size=4) assert all(len(a) == 0 for a in assignments) assert all(load == 0 for load in loads) - def test_single_rank(self): - patch_counts = [100, 200, 300] - assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=1) - assert assignments == [[0, 1, 2]] - assert loads == [600] - - def test_equal_images_equal_size(self): - patch_counts = [100, 100, 100, 100, 100, 100] # 6 images - assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=3) - assert all(len(a) == 2 for a in assignments) - assert all(load == 200 for load in loads) - - def test_image_order_preserved(self): - patch_counts = [10, 20, 30, 40, 50] - assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size=2) + def test_assign_image_order_preserved_contiguous(self): + assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2) for rank_assignment in assignments: assert rank_assignment == sorted(rank_assignment) + def test_assign_load_balanced_unequal_patches(self): + """With unequal patch counts, greedy balancing should reduce imbalance.""" + patch_counts = [4096, 256, 256, 256] + assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) + all_assigned = [] + for a in assignments: + all_assigned.extend(a) + assert sorted(all_assigned) == [0, 1, 2, 3] + max_load = max(loads) + min_load = min(load for load in loads if load > 0) + assert max_load / min_load < 8.0 -class TestPrepareLocalVisionInputs: - """Tests for prepare_local_vision_inputs function.""" - def test_basic_extraction(self): +class TestPrepareLocalVisionInputs: + def test_prepare_two_images_splits_correctly(self): pixel_values = torch.randn(100, 768) - grid_thw = torch.tensor([ - [1, 6, 6], # 36 patches (indices 0-35) - [1, 8, 8], # 64 patches (indices 36-99) - ]) + grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]]) # 36 + 64 = 100 image_assignments = [[0], [1]] - local_pix, local_grid, local_indices = prepare_local_vision_inputs( - 
pixel_values, grid_thw, image_assignments, dp_rank=0 - ) - assert local_pix.shape[0] == 36 - assert local_grid.shape[0] == 1 - assert local_indices == [0] - assert torch.allclose(local_pix, pixel_values[:36]) - - local_pix, local_grid, local_indices = prepare_local_vision_inputs( - pixel_values, grid_thw, image_assignments, dp_rank=1 - ) - assert local_pix.shape[0] == 64 - assert local_grid.shape[0] == 1 - assert local_indices == [1] - assert torch.allclose(local_pix, pixel_values[36:100]) - - def test_multiple_images_per_rank(self): + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) + assert pix.shape[0] == 36 + assert grid.shape[0] == 1 + assert indices == [0] + assert torch.allclose(pix, pixel_values[:36]) + + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1) + assert pix.shape[0] == 64 + assert grid.shape[0] == 1 + assert indices == [1] + assert torch.allclose(pix, pixel_values[36:100]) + + def test_prepare_multiple_contiguous_images_per_rank(self): pixel_values = torch.randn(200, 768) - grid_thw = torch.tensor([ - [1, 5, 10], # 50 patches - [1, 5, 10], # 50 patches - [1, 5, 10], # 50 patches - [1, 5, 10], # 50 patches - ]) - image_assignments = [[0, 2], [1, 3]] - - local_pix, local_grid, local_indices = prepare_local_vision_inputs( - pixel_values, grid_thw, image_assignments, dp_rank=0 - ) - assert local_pix.shape[0] == 100 - assert local_grid.shape[0] == 2 - assert local_indices == [0, 2] - expected = torch.cat([pixel_values[0:50], pixel_values[100:150]], dim=0) - assert torch.allclose(local_pix, expected) - - def test_empty_rank(self): + grid_thw = torch.tensor([[1, 5, 10]] * 4) # 4 x 50 patches + image_assignments = [[0, 1], [2, 3]] + + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) + assert pix.shape[0] == 100 + assert grid.shape[0] == 2 + assert indices == [0, 1] + assert torch.allclose(pix, 
pixel_values[:100]) + + def test_prepare_empty_rank_returns_empty(self): pixel_values = torch.randn(100, 768) grid_thw = torch.tensor([[1, 10, 10]]) image_assignments = [[0], []] - local_pix, local_grid, local_indices = prepare_local_vision_inputs( - pixel_values, grid_thw, image_assignments, dp_rank=1 - ) - assert local_pix.shape[0] == 0 - assert local_grid.shape[0] == 0 - assert local_indices == [] + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1) + assert pix.shape[0] == 0 + assert grid.shape[0] == 0 + assert indices == [] - def test_grid_thw_preserved(self): + def test_prepare_grid_thw_preserved(self): pixel_values = torch.randn(150, 768) - grid_thw = torch.tensor([ - [1, 5, 5], # 25 patches - [2, 5, 5], # 50 patches - [3, 5, 5], # 75 patches - ]) - image_assignments = [[0, 2], [1]] - - _, local_grid, _ = prepare_local_vision_inputs( - pixel_values, grid_thw, image_assignments, dp_rank=0 - ) + grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]]) # 25 + 50 + 75 + image_assignments = [[0, 1], [2]] + + _, local_grid, _ = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0) assert local_grid.shape == (2, 3) assert torch.equal(local_grid[0], grid_thw[0]) - assert torch.equal(local_grid[1], grid_thw[2]) + assert torch.equal(local_grid[1], grid_thw[1]) + + +class TestGatherVisionEmbeddings: + def test_gather_none_group_returns_input(self): + embeddings = torch.randn(10, 64) + result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10]) + assert torch.equal(result, embeddings) class TestIntegration: - """Integration tests combining multiple functions.""" - - def test_full_workflow(self): - grid_thw = torch.tensor([ - [1, 4, 4], # 16 patches - [1, 8, 8], # 64 patches - [1, 4, 4], # 16 patches - [1, 6, 6], # 36 patches - [1, 4, 4], # 16 patches - ]) + def test_full_workflow_all_patches_covered(self): + grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 4, 4], [1, 6, 
6], [1, 4, 4]]) total_patches = 16 + 64 + 16 + 36 + 16 # 148 pixel_values = torch.randn(total_patches, 768) @@ -208,27 +188,20 @@ def test_full_workflow(self): total_local_patches = 0 for rank in range(2): - local_pix, local_grid, local_indices = prepare_local_vision_inputs( - pixel_values, grid_thw, assignments, dp_rank=rank - ) - expected_patches = sum(patch_counts[i] for i in local_indices) - assert local_pix.shape[0] == expected_patches - assert local_grid.shape[0] == len(local_indices) - total_local_patches += local_pix.shape[0] + pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, assignments, dp_rank=rank) + expected = sum(patch_counts[i] for i in indices) + assert pix.shape[0] == expected + assert grid.shape[0] == len(indices) + total_local_patches += pix.shape[0] assert total_local_patches == total_patches - def test_same_size_images(self): + def test_same_size_images_4_ranks_balanced(self): num_images = 50 - patch_per_image = 64 grid_thw = torch.tensor([[1, 8, 8]] * num_images) - total_patches = num_images * patch_per_image - _ = torch.randn(total_patches, 768) - patch_counts = get_image_patch_counts(grid_thw) - assert all(c == 64 for c in patch_counts) - assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4) + for rank in range(4): assert 12 <= len(assignments[rank]) <= 13 for load in loads: From e941e679344a064efcee1f894203b2923f3f3f2a Mon Sep 17 00:00:00 2001 From: aoshen524 Date: Tue, 3 Mar 2026 22:47:16 +0900 Subject: [PATCH 3/8] refactor(vision_dp): simplify docstrings, fix empty-rank backward, add contiguous guard - Trim verbose docstrings to concise one-liners - Delete dead store ctx.hidden_size (written in forward, never read in backward) - Simplify hidden_size detection: self.config.out_hidden_size - Add requires_grad_() for empty rank to participate in backward all_reduce - Add .contiguous() guard before all_reduce (NCCL requirement) - Reuse get_image_patch_counts in spatial_merge_size==1 path 
Co-Authored-By: Claude Opus 4.6 --- roll/utils/context_parallel/vision_dp.py | 141 ++++++----------------- 1 file changed, 37 insertions(+), 104 deletions(-) diff --git a/roll/utils/context_parallel/vision_dp.py b/roll/utils/context_parallel/vision_dp.py index 7ebff21f..ec0dcfad 100644 --- a/roll/utils/context_parallel/vision_dp.py +++ b/roll/utils/context_parallel/vision_dp.py @@ -15,22 +15,11 @@ """ Vision Data Parallel utilities for distributing ViT computation across Ulysses SP ranks. -Ported from verl (https://github.com/verl-project/verl/pull/5230). +Distribute whole images across SP ranks, not patches within images. +Each rank runs ViT on its assigned images, then all-gather combines embeddings. +Backward all_reduce(SUM) recovers complete gradients before slicing by assignment. -Strategy: Distribute whole images across DP ranks, not patches within images. -This avoids breaking cu_seqlens semantics while parallelizing ViT computation. - -Key design choices: -- Image-level distribution (not patch-level): avoids breaking ViT's internal - cu_seqlens tracking -- Contiguous assignment: rank 0 gets images [0,1,...], rank 1 gets next chunk, etc. - No reordering needed after all-gather. -- Gradient sync in backward: all_reduce(SUM) across SP ranks before slicing to - recover the complete gradient for each image. Without this, gradients from - vision tokens in other ranks' sequence shards would be lost. -- No additional gradient scaling needed: the all_reduce aggregates partial - sequence gradients, making each rank's ViT backward equivalent to the non-DP - baseline. FSDP's dp_sp reduce-scatter then handles DP averaging as usual. +Ported from verl (https://github.com/verl-project/verl/pull/5230). """ import torch @@ -41,33 +30,18 @@ def get_image_patch_counts(grid_thw: torch.Tensor) -> list[int]: - """Compute number of patches per image from grid_thw. - - Args: - grid_thw: Tensor of shape (num_images, 3) where each row is [t, h, w]. 
- - Returns: - List of patch counts per image. - """ + """Return [t*h*w for each image] from a [num_images, 3] grid_thw tensor.""" if grid_thw.numel() == 0: return [] return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist() def get_image_embedding_counts(grid_thw: torch.Tensor, spatial_merge_size: int = 1) -> list[int]: - """Compute number of embeddings per image after spatial merging. - - Args: - grid_thw: Tensor of shape (num_images, 3) where each row is [t, h, w]. - spatial_merge_size: Spatial merge factor (typically 2 for Qwen-VL). - - Returns: - List of embedding counts per image. - """ + """Return per-image embedding counts after spatial merging: t * (h/merge) * (w/merge).""" if grid_thw.numel() == 0: return [] if spatial_merge_size == 1: - return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist() + return get_image_patch_counts(grid_thw) t = grid_thw[:, 0] h = grid_thw[:, 1] // spatial_merge_size w = grid_thw[:, 2] // spatial_merge_size @@ -78,21 +52,10 @@ def assign_images_to_dp_ranks( patch_counts: list[int], dp_size: int, ) -> tuple[list[list[int]], list[int]]: - """Assign whole images to DP ranks using load-balanced contiguous distribution. - - The algorithm uses greedy contiguous bin-packing: - - Images are assigned in order (contiguous) to preserve ordering after gather - - Split points are chosen to balance total patch load across ranks - - Each rank gets at least one image when num_images >= dp_size - - Args: - patch_counts: Number of patches per image. - dp_size: Number of DP ranks. + """Assign whole images to DP ranks via greedy contiguous bin-packing. - Returns: - Tuple of (image_assignments, rank_loads) where: - - image_assignments[rank] = list of image indices assigned to that rank - - rank_loads[rank] = total patches assigned to that rank + Returns (image_assignments, rank_patch_counts). Images are kept contiguous + so the gather result needs no reordering. 
""" num_images = len(patch_counts) if num_images == 0: @@ -141,14 +104,7 @@ def prepare_local_vision_inputs( ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: """Extract pixel values and grid_thw for this DP rank's assigned images. - Args: - pixel_values: All pixel values concatenated, shape (total_patches, dim). - grid_thw: Grid dimensions per image, shape (num_images, 3). - image_assignments: Per-rank image index assignments. - dp_rank: This rank's index in the DP group. - - Returns: - Tuple of (local_pixel_values, local_grid_thw, local_indices). + Exploits contiguous assignment: a single slice instead of per-image cat. """ local_indices = image_assignments[dp_rank] @@ -218,7 +174,6 @@ def forward(ctx, local_embeddings, dp_group, all_counts: list[int]): return local_embeddings hidden_size = local_embeddings.shape[1] if local_embeddings.dim() > 1 else 1 - ctx.hidden_size = hidden_size if local_embeddings.shape[0] < max_count: pad_size = max_count - local_embeddings.shape[0] @@ -249,29 +204,20 @@ def backward(ctx, grad_output): dp_rank = ctx.dp_rank dp_group = ctx.dp_group - # Aggregate gradient contributions from all SP ranks. - # Each rank only has non-zero grad for vision tokens in its own - # sequence shard. Summing across ranks recovers the complete - # gradient for every image before we slice by image assignment. - dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=dp_group) + # all_reduce(SUM) aggregates partial gradients from all SP ranks: + # each rank only has non-zero grad for vision tokens in its sequence shard. + # NCCL all_reduce requires contiguous tensors — defensive guard. 
+ grad = grad_output.contiguous() + dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=dp_group) start = sum(all_counts[:dp_rank]) end = start + all_counts[dp_rank] - local_grad = grad_output[start:end] + local_grad = grad[start:end] return local_grad, None, None def gather_vision_embeddings(local_embeddings, dp_group, all_counts: list[int]): - """All-gather vision embeddings from all DP ranks. - - Args: - local_embeddings: This rank's vision embeddings. - dp_group: Process group for all-gather. Defaults to Ulysses group. - all_counts: Pre-computed embedding counts per rank (avoids an all_gather). - - Returns: - All-gathered embeddings concatenated across ranks. - """ + """All-gather vision embeddings from all DP ranks with gradient support.""" dp_group = get_ulysses_group() if dp_group is None else dp_group if dp_group is None or dist.get_world_size(dp_group) == 1: return local_embeddings @@ -279,19 +225,18 @@ def gather_vision_embeddings(local_embeddings, dp_group, all_counts: list[int]): def create_dp_vision_forward(original_forward): - """Wrap VisionTransformer.forward for Vision DP. - - Model-agnostic wrapper for any VisionTransformer with - ``forward(self, hidden_states, grid_thw, **kwargs) -> Tensor`` signature. - - When Ulysses SP size > 1, distributes images across SP ranks and - all-gathers the embeddings after ViT computation. - - Args: - original_forward: The original VisionTransformer.forward method. - - Returns: - Wrapped forward method with Vision DP support. + """Wrap VisionTransformer.forward for Vision DP (Data Parallel across SP ranks). + + Strategy: + 1. Distribute whole images to SP ranks (not patches within images) + 2. Each rank processes its assigned images independently + 3. All-gather embeddings at the end (contiguous assignment, no reordering) + + Gradient correctness: after all-gather in forward, each SP rank's inputs_embeds + contains vision tokens from ALL images. But Ulysses gives each rank only its + sequence shard. 
In backward, each rank only has non-zero gradient for vision + tokens in its own shard. The all_reduce(SUM) in GatherVisionEmbeddings.backward + aggregates partial gradients from all ranks, recovering the complete gradient. """ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): @@ -331,26 +276,12 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): if local_pixels.shape[0] > 0: local_embeddings = original_forward(self, local_pixels, local_grid_thw, **kwargs) else: - # Determine hidden_size for empty tensor - if hasattr(self, "merger") and hasattr(self.merger, "ln_q"): - ln_q = self.merger.ln_q - if hasattr(ln_q, "normalized_shape"): - hidden_size = ln_q.normalized_shape[0] - elif hasattr(ln_q, "weight"): - hidden_size = ln_q.weight.shape[0] - else: - raise RuntimeError( - "Cannot determine hidden_size from merger.ln_q: " - "no 'normalized_shape' or 'weight' attribute found" - ) - elif hasattr(self, "out_hidden_size"): - hidden_size = self.out_hidden_size - elif hasattr(self, "config") and hasattr(self.config, "hidden_size"): - hidden_size = self.config.hidden_size - else: + # This rank has no images, create empty tensor with correct hidden size + hidden_size = getattr(getattr(self, "config", None), "out_hidden_size", None) + if hidden_size is None: raise RuntimeError( - "Cannot determine hidden_size for empty Vision DP output. " - "Expected one of: self.merger.ln_q, self.out_hidden_size, self.config.hidden_size" + f"Cannot determine hidden_size: self.config.out_hidden_size not found. 
" + f"Model type: {type(self).__name__}" ) local_embeddings = torch.empty( @@ -358,6 +289,8 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): dtype=hidden_states.dtype, device=hidden_states.device, ) + # Empty rank must participate in autograd for backward all_reduce + local_embeddings.requires_grad_() # Handle Qwen3-VL which returns (embeddings, deepstack_embeddings) deepstack_outputs = None From d46bb1acbcc1b7bf83a5f9fcf621116309e05b0b Mon Sep 17 00:00:00 2001 From: aoshen524 Date: Wed, 4 Mar 2026 00:10:16 +0900 Subject: [PATCH 4/8] fix(vision_dp): fix Qwen3-VL deepstack NCCL deadlock on empty ranks Replace isinstance(tuple) check with model attribute detection (hasattr deepstack_merger_list). Empty ranks now create matching empty deepstack tensors and participate in all-gather, preventing NCCL deadlock when num_images < dp_size. Co-Authored-By: Claude Opus 4.6 --- roll/utils/context_parallel/vision_dp.py | 45 ++++++++++++++---------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/roll/utils/context_parallel/vision_dp.py b/roll/utils/context_parallel/vision_dp.py index ec0dcfad..3a3588bf 100644 --- a/roll/utils/context_parallel/vision_dp.py +++ b/roll/utils/context_parallel/vision_dp.py @@ -272,6 +272,10 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): hidden_states, grid_thw, image_assignments, dp_rank ) + # Detect Qwen3-VL deepstack: model attribute, not return type, + # because empty ranks don't call original_forward and can't inspect the return. 
+ has_deepstack = hasattr(self, "deepstack_merger_list") + # Step 3: Process local images if local_pixels.shape[0] > 0: local_embeddings = original_forward(self, local_pixels, local_grid_thw, **kwargs) @@ -292,10 +296,21 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): # Empty rank must participate in autograd for backward all_reduce local_embeddings.requires_grad_() - # Handle Qwen3-VL which returns (embeddings, deepstack_embeddings) - deepstack_outputs = None - if isinstance(local_embeddings, tuple): - local_embeddings, deepstack_outputs = local_embeddings[0], local_embeddings[1:] + # Unpack Qwen3-VL deepstack: forward returns (embeddings, list[3 × Tensor]) + local_deepstack = None + if has_deepstack: + if isinstance(local_embeddings, tuple): + local_embeddings, local_deepstack = local_embeddings[0], local_embeddings[1] + else: + # Empty rank: create matching empty deepstack tensors + num_deepstack = len(self.deepstack_merger_list) + h = local_embeddings.shape[1] + local_deepstack = [ + torch.empty( + (0, h), dtype=hidden_states.dtype, device=hidden_states.device + ) + for _ in range(num_deepstack) + ] # Step 4: All-gather # Compute per-rank embedding counts locally (grid_thw is replicated on all ranks) @@ -303,21 +318,13 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): all_embeddings = gather_vision_embeddings(local_embeddings, dp_group, all_counts) assert all_embeddings.shape[0] == total_embeddings - if deepstack_outputs is not None: - # All-gather deepstack embeddings too - gathered_deepstack = [] - for ds_emb in deepstack_outputs: - if isinstance(ds_emb, list): - # List of tensors (one per deepstack layer) - gathered_list = [] - for single_emb in ds_emb: - gathered_list.append(gather_vision_embeddings(single_emb, dp_group, all_counts)) - gathered_deepstack.append(gathered_list) - elif isinstance(ds_emb, torch.Tensor): - gathered_deepstack.append(gather_vision_embeddings(ds_emb, dp_group, all_counts)) - else: - 
gathered_deepstack.append(ds_emb) - return (all_embeddings, *gathered_deepstack) + # Step 5: All-gather deepstack embeddings (all ranks must participate) + if local_deepstack is not None: + gathered_deepstack = [ + gather_vision_embeddings(ds, dp_group, all_counts) + for ds in local_deepstack + ] + return all_embeddings, gathered_deepstack return all_embeddings From 1d3b9ae2143882388f6b2dc1e60f021526ecf7ea Mon Sep 17 00:00:00 2001 From: aoshen524 Date: Wed, 4 Mar 2026 00:57:42 +0900 Subject: [PATCH 5/8] feat(vision_dp): add vision_dp flag to gate Vision DP patching Add `vision_dp: bool = False` to ModelArguments and gate apply_vision_dp_patch() calls in both DeepSpeedInferStrategy and DeepSpeedTrainStrategy behind it. Vision DP is now opt-in. Co-Authored-By: Claude Opus 4.6 --- roll/configs/model_args.py | 4 ++++ roll/distributed/strategy/deepspeed_strategy.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/roll/configs/model_args.py b/roll/configs/model_args.py index c9b8b844..63664587 100644 --- a/roll/configs/model_args.py +++ b/roll/configs/model_args.py @@ -119,6 +119,10 @@ class ModelArguments(LoraArguments): default=1, metadata={"help": "The group size for Ulysses attention."}, ) + vision_dp: bool = field( + default=False, + metadata={"help": "Enable Vision DP: distribute ViT across Ulysses SP ranks."}, + ) def __post_init__(self): def split_arg(arg): diff --git a/roll/distributed/strategy/deepspeed_strategy.py b/roll/distributed/strategy/deepspeed_strategy.py index f240ecfe..498b3c6f 100644 --- a/roll/distributed/strategy/deepspeed_strategy.py +++ b/roll/distributed/strategy/deepspeed_strategy.py @@ -69,7 +69,8 @@ def initialize(self, model_provider): if (cp_size := self.worker_config.model_args.ulysses_size) > 1: if current_platform.apply_ulysses_patch() is not None: set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size) - apply_vision_dp_patch() + if self.worker_config.model_args.vision_dp: + 
apply_vision_dp_patch() else: cp_size = 1 @@ -333,7 +334,8 @@ def initialize(self, model_provider): if (cp_size := self.worker_config.model_args.ulysses_size) > 1: current_platform.apply_ulysses_patch() set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size) - apply_vision_dp_patch() + if self.worker_config.model_args.vision_dp: + apply_vision_dp_patch() self.worker.rank_info.dp_rank = global_rank // cp_size self.worker.rank_info.dp_size = world_size // cp_size From f466599455eb650655df4ba9ead53fc06c3a3da2 Mon Sep 17 00:00:00 2001 From: aoshen524 Date: Wed, 4 Mar 2026 23:01:45 +0900 Subject: [PATCH 6/8] fix(vision_dp): fix tautological assertion + improve test coverage - Replace `expected_patches = end_patch - start_patch` (always-true by Python slicing) with independent cross-check via `get_image_patch_counts(local_grid_thw)` in prepare_local_vision_inputs() - Rename tests to `test_<what>_<condition>_<expected>()` convention - Add missing tests: embedding_counts empty, contiguous coverage, gather same-storage Co-Authored-By: Claude Opus 4.6 --- roll/utils/context_parallel/vision_dp.py | 4 ++- tests/utils/test_vision_dp_on_cpu.py | 33 +++++++++++++++++++----- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/roll/utils/context_parallel/vision_dp.py b/roll/utils/context_parallel/vision_dp.py index 3a3588bf..889193d5 100644 --- a/roll/utils/context_parallel/vision_dp.py +++ b/roll/utils/context_parallel/vision_dp.py @@ -139,7 +139,9 @@ def prepare_local_vision_inputs( local_pixel_values = pixel_values[start_patch:end_patch] local_grid_thw = grid_thw[first_img_idx : last_img_idx + 1] - expected_patches = end_patch - start_patch + # Cross-check: verify extracted slice matches independently computed patch counts + independent_counts = get_image_patch_counts(local_grid_thw) + expected_patches = sum(independent_counts) assert local_pixel_values.shape[0] == expected_patches, ( f"[Vision DP] Local patch count mismatch: " f"extracted={local_pixel_values.shape[0]}, 
expected={expected_patches}, " diff --git a/tests/utils/test_vision_dp_on_cpu.py b/tests/utils/test_vision_dp_on_cpu.py index 5efd6e42..becc3703 100644 --- a/tests/utils/test_vision_dp_on_cpu.py +++ b/tests/utils/test_vision_dp_on_cpu.py @@ -16,6 +16,7 @@ Unit tests for Vision Data Parallel utilities (CPU-only, no distributed). Ported from verl (https://github.com/verl-project/verl/pull/5230). +Test naming convention: test___() """ import pytest @@ -63,6 +64,10 @@ def test_embedding_counts_with_merge_size_correct(self, grid_thw, merge_size, ex counts = get_image_embedding_counts(torch.tensor(grid_thw), merge_size) assert counts == expected + def test_embedding_counts_empty_input_returns_empty_list(self): + counts = get_image_embedding_counts(torch.empty((0, 3), dtype=torch.long)) + assert counts == [] + class TestAssignImagesToDpRanks: @pytest.mark.parametrize( @@ -74,7 +79,7 @@ class TestAssignImagesToDpRanks: ], ids=["balanced-2ranks", "single-rank", "balanced-3ranks"], ) - def test_assign_all_images_distributed(self, patch_counts, dp_size, expected_all_assigned): + def test_assign_all_images_distributed_correctly(self, patch_counts, dp_size, expected_all_assigned): assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size) all_assigned = [] for a in assignments: @@ -91,7 +96,7 @@ def test_assign_fewer_images_than_ranks_all_assigned(self): all_assigned.update(a) assert all_assigned == {0, 1} - def test_assign_empty_input_returns_empty(self): + def test_assign_empty_input_returns_empty_lists(self): assignments, loads = assign_images_to_dp_ranks([], dp_size=4) assert all(len(a) == 0 for a in assignments) assert all(load == 0 for load in loads) @@ -101,7 +106,7 @@ def test_assign_image_order_preserved_contiguous(self): for rank_assignment in assignments: assert rank_assignment == sorted(rank_assignment) - def test_assign_load_balanced_unequal_patches(self): + def test_assign_load_balanced_unequal_patches_reduces_imbalance(self): """With unequal patch 
counts, greedy balancing should reduce imbalance.""" patch_counts = [4096, 256, 256, 256] assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2) @@ -113,6 +118,16 @@ def test_assign_load_balanced_unequal_patches(self): min_load = min(load for load in loads if load > 0) assert max_load / min_load < 8.0 + def test_assign_contiguous_coverage_all_dp_sizes(self): + """All images are covered exactly once across ranks for various dp_size.""" + patch_counts = [10, 20, 30, 40, 50, 60, 70] + for dp_size in [1, 2, 3, 4, 7]: + assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size) + all_indices = [] + for a in assignments: + all_indices.extend(a) + assert sorted(all_indices) == list(range(len(patch_counts))) + class TestPrepareLocalVisionInputs: def test_prepare_two_images_splits_correctly(self): @@ -143,7 +158,7 @@ def test_prepare_multiple_contiguous_images_per_rank(self): assert indices == [0, 1] assert torch.allclose(pix, pixel_values[:100]) - def test_prepare_empty_rank_returns_empty(self): + def test_prepare_empty_rank_returns_empty_tensors(self): pixel_values = torch.randn(100, 768) grid_thw = torch.tensor([[1, 10, 10]]) image_assignments = [[0], []] @@ -153,7 +168,7 @@ def test_prepare_empty_rank_returns_empty(self): assert grid.shape[0] == 0 assert indices == [] - def test_prepare_grid_thw_preserved(self): + def test_prepare_local_inputs_grid_thw_values_preserved(self): pixel_values = torch.randn(150, 768) grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]]) # 25 + 50 + 75 image_assignments = [[0, 1], [2]] @@ -165,11 +180,17 @@ def test_prepare_grid_thw_preserved(self): class TestGatherVisionEmbeddings: - def test_gather_none_group_returns_input(self): + def test_gather_embeddings_none_group_returns_input_unchanged(self): embeddings = torch.randn(10, 64) result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10]) assert torch.equal(result, embeddings) + def test_gather_embeddings_none_group_same_storage(self): + 
"""Single-rank group should short-circuit and return same tensor (not a copy).""" + embeddings = torch.randn(10, 64) + result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10]) + assert result.data_ptr() == embeddings.data_ptr() + class TestIntegration: def test_full_workflow_all_patches_covered(self): From 5906a4f35ca6dd4b92fff62c31c2c2837514b46a Mon Sep 17 00:00:00 2001 From: aoshen524 Date: Wed, 4 Mar 2026 23:16:02 +0900 Subject: [PATCH 7/8] refactor(vision_dp): align error handling with verl PR #5230 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sync shared utility functions with verl's stricter error handling: - get_image_patch_counts/get_image_embedding_counts: empty grid_thw raises ValueError instead of returning [] - assign_images_to_dp_ranks: validate dp_size > 0, empty patch_counts raises ValueError instead of returning empty lists - prepare_local_vision_inputs: add dp_rank bounds check, use tensor-ops for offset computation (avoid Python-list round-trip), add int() cast - GatherVisionEmbeddings.forward: dp_size<=1 raises RuntimeError, validate all_counts length, max_count==0 raises RuntimeError - GatherVisionEmbeddings.backward: assert dp_size>1, add CUDA check - dp_vision_forward: sp_size<=1 raises RuntimeError, use GatherVisionEmbeddings.apply() directly, add detailed assert messages - Update tests to match: empty→raises, add dp_size/dp_rank validation Co-Authored-By: Claude Opus 4.6 --- roll/utils/context_parallel/vision_dp.py | 126 +++++++++++++++++------ tests/utils/test_vision_dp_on_cpu.py | 41 ++++++-- 2 files changed, 127 insertions(+), 40 deletions(-) diff --git a/roll/utils/context_parallel/vision_dp.py b/roll/utils/context_parallel/vision_dp.py index 889193d5..8a070789 100644 --- a/roll/utils/context_parallel/vision_dp.py +++ b/roll/utils/context_parallel/vision_dp.py @@ -32,16 +32,19 @@ def get_image_patch_counts(grid_thw: torch.Tensor) -> list[int]: """Return [t*h*w for each 
image] from a [num_images, 3] grid_thw tensor.""" if grid_thw.numel() == 0: - return [] + raise ValueError("grid_thw is empty — Vision DP should only be called when images are present") return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist() def get_image_embedding_counts(grid_thw: torch.Tensor, spatial_merge_size: int = 1) -> list[int]: """Return per-image embedding counts after spatial merging: t * (h/merge) * (w/merge).""" if grid_thw.numel() == 0: - return [] + raise ValueError("grid_thw is empty — Vision DP should only be called when images are present") + if spatial_merge_size == 1: return get_image_patch_counts(grid_thw) + + # Apply spatial merging: h and w are divided by spatial_merge_size t = grid_thw[:, 0] h = grid_thw[:, 1] // spatial_merge_size w = grid_thw[:, 2] // spatial_merge_size @@ -57,9 +60,12 @@ def assign_images_to_dp_ranks( Returns (image_assignments, rank_patch_counts). Images are kept contiguous so the gather result needs no reordering. """ + if dp_size <= 0: + raise ValueError(f"dp_size must be positive, got {dp_size}") + num_images = len(patch_counts) if num_images == 0: - return [[] for _ in range(dp_size)], [0] * dp_size + raise ValueError("patch_counts is empty — Vision DP should only be called when images are present") image_assignments: list[list[int]] = [[] for _ in range(dp_size)] rank_loads = [0] * dp_size @@ -106,6 +112,12 @@ def prepare_local_vision_inputs( Exploits contiguous assignment: a single slice instead of per-image cat. 
""" + if dp_rank < 0 or dp_rank >= len(image_assignments): + raise ValueError( + f"dp_rank={dp_rank} out of range for image_assignments with " + f"{len(image_assignments)} ranks" + ) + local_indices = image_assignments[dp_rank] if len(local_indices) == 0: @@ -123,18 +135,17 @@ def prepare_local_vision_inputs( first_img_idx = local_indices[0] last_img_idx = local_indices[-1] - # Compute patch offsets using cumsum - patch_counts = get_image_patch_counts(grid_thw) - patch_counts_tensor = torch.tensor(patch_counts, device=grid_thw.device, dtype=torch.long) + # Compute patch offsets using cumsum (grid_thw may be on CPU or GPU) + patch_counts = grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2] offsets = torch.cat( ( - torch.tensor([0], device=grid_thw.device, dtype=torch.long), - torch.cumsum(patch_counts_tensor, dim=0), + torch.zeros(1, device=grid_thw.device, dtype=patch_counts.dtype), + torch.cumsum(patch_counts, dim=0), ) ) - start_patch = offsets[first_img_idx].item() - end_patch = offsets[last_img_idx + 1].item() + start_patch = int(offsets[first_img_idx].item()) + end_patch = int(offsets[last_img_idx + 1].item()) local_pixel_values = pixel_values[start_patch:end_patch] local_grid_thw = grid_thw[first_img_idx : last_img_idx + 1] @@ -152,31 +163,52 @@ def prepare_local_vision_inputs( class GatherVisionEmbeddings(Function): - """All-gather vision embeddings with gradient support. + """ + All-gather vision embeddings with gradient support. - Contiguous assignment means simple concat without reordering. + Since images are assigned contiguously (rank 0 gets [0,1], rank 1 gets [2,3], etc.), + we can simply concat gathered results without reordering. + + Forward: all_gather + remove padding + concat Backward: all_reduce(SUM) to aggregate gradients from all sequence shards, - then slice to extract this rank's image gradients. 
+ then slice to extract this rank's image gradients """ @staticmethod - def forward(ctx, local_embeddings, dp_group, all_counts: list[int]): + def forward( + ctx, + local_embeddings: torch.Tensor, + dp_group, + all_counts: list[int], + ) -> torch.Tensor: dp_size = dist.get_world_size(dp_group) + if dp_size <= 1: + raise RuntimeError( + "GatherVisionEmbeddings.forward called with dp_size=1. " + "Caller should short-circuit before reaching here." + ) dp_rank = dist.get_rank(dp_group) ctx.dp_size = dp_size ctx.dp_group = dp_group ctx.all_counts = all_counts ctx.dp_rank = dp_rank - if dp_size == 1: - return local_embeddings + if not all_counts or len(all_counts) != dp_size: + raise ValueError( + f"all_counts length ({len(all_counts) if all_counts else 0}) " + f"must equal dp_size ({dp_size})" + ) - max_count = max(all_counts) if all_counts else 0 + max_count = max(all_counts) if max_count == 0: - return local_embeddings + raise RuntimeError( + "all_counts are all zero — Vision DP gather should not be called " + "when no images are present" + ) hidden_size = local_embeddings.shape[1] if local_embeddings.dim() > 1 else 1 + # Pad to same length for all_gather if local_embeddings.shape[0] < max_count: pad_size = max_count - local_embeddings.shape[0] padding = torch.zeros( @@ -188,19 +220,23 @@ def forward(ctx, local_embeddings, dp_group, all_counts: list[int]): else: local_padded = local_embeddings + # All-gather gathered = [torch.empty_like(local_padded) for _ in range(dp_size)] dist.all_gather(gathered, local_padded, group=dp_group) + # Remove padding and concat (no reordering needed - contiguous assignment) result_chunks = [gathered[r][: all_counts[r]] for r in range(dp_size)] result = torch.cat(result_chunks, dim=0) + return result @staticmethod def backward(ctx, grad_output): dp_size = ctx.dp_size - - if dp_size == 1: - return grad_output, None, None + assert dp_size > 1, ( + f"GatherVisionEmbeddings.backward reached with dp_size={dp_size}. 
" + "Forward should never be called with dp_size<=1." + ) all_counts = ctx.all_counts dp_rank = ctx.dp_rank @@ -208,13 +244,22 @@ def backward(ctx, grad_output): # all_reduce(SUM) aggregates partial gradients from all SP ranks: # each rank only has non-zero grad for vision tokens in its sequence shard. - # NCCL all_reduce requires contiguous tensors — defensive guard. + if not grad_output.is_cuda: + raise RuntimeError( + "GatherVisionEmbeddings.backward requires CUDA tensors (NCCL backend). " + f"Got device={grad_output.device}" + ) + # NCCL all_reduce requires contiguous tensors. In the real training path + # (masked_scatter_backward → view), grad is already contiguous (no-op). + # Kept as defensive guard against upstream autograd changes. grad = grad_output.contiguous() dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=dp_group) + # Extract gradients for this rank's images (contiguous slice) start = sum(all_counts[:dp_rank]) end = start + all_counts[dp_rank] local_grad = grad[start:end] + return local_grad, None, None @@ -244,7 +289,10 @@ def create_dp_vision_forward(original_forward): def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): dp_size = get_ulysses_size() if dp_size is None or dp_size <= 1: - return original_forward(self, hidden_states, grid_thw, **kwargs) + raise RuntimeError( + f"sp_size={dp_size}, Vision DP should not be active — " + "monkey-patch is only applied when sp_size > 1" + ) dp_group = get_ulysses_group() dp_rank = dist.get_rank(dp_group) @@ -253,17 +301,25 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): # metadata helpers (grid_thw is a tiny [num_images, 3] tensor). 
grid_thw_cpu = grid_thw.cpu() - # Step 1: Get image assignment + # Step 1: Get image assignment based on patch counts patch_counts = get_image_patch_counts(grid_thw_cpu) total_patches = sum(patch_counts) - assert hidden_states.shape[0] == total_patches + assert hidden_states.shape[0] == total_patches, ( + f"[Vision DP] Input patch count mismatch: " + f"hidden_states.shape[0]={hidden_states.shape[0]}, " + f"sum(grid_thw products)={total_patches}, " + f"grid_thw.shape={grid_thw.shape}" + ) + + # Get spatial_merge_size from merger (VLMs like Qwen use merger to reduce embeddings) spatial_merge_size = 1 if hasattr(self, "merger") and hasattr(self.merger, "spatial_merge_size"): spatial_merge_size = self.merger.spatial_merge_size elif hasattr(self, "spatial_merge_size"): spatial_merge_size = self.spatial_merge_size + # Calculate embedding counts (after merger) for gather verification embedding_counts = get_image_embedding_counts(grid_thw_cpu, spatial_merge_size) total_embeddings = sum(embedding_counts) @@ -314,16 +370,26 @@ def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): for _ in range(num_deepstack) ] - # Step 4: All-gather + # Step 4: All-gather (contiguous assignment, no reordering needed) # Compute per-rank embedding counts locally (grid_thw is replicated on all ranks) - all_counts = [sum(embedding_counts[i] for i in image_assignments[r]) for r in range(dp_size)] - all_embeddings = gather_vision_embeddings(local_embeddings, dp_group, all_counts) - assert all_embeddings.shape[0] == total_embeddings + all_counts = [ + sum(embedding_counts[i] for i in image_assignments[r]) + for r in range(dp_size) + ] + all_embeddings = GatherVisionEmbeddings.apply( + local_embeddings, dp_group, all_counts + ) + + assert all_embeddings.shape[0] == total_embeddings, ( + f"[Vision DP] Output embedding count mismatch: " + f"all_embeddings.shape[0]={all_embeddings.shape[0]}, " + f"expected={total_embeddings}" + ) # Step 5: All-gather deepstack embeddings (all ranks must 
participate) if local_deepstack is not None: gathered_deepstack = [ - gather_vision_embeddings(ds, dp_group, all_counts) + GatherVisionEmbeddings.apply(ds, dp_group, all_counts) for ds in local_deepstack ] return all_embeddings, gathered_deepstack diff --git a/tests/utils/test_vision_dp_on_cpu.py b/tests/utils/test_vision_dp_on_cpu.py index becc3703..0255013e 100644 --- a/tests/utils/test_vision_dp_on_cpu.py +++ b/tests/utils/test_vision_dp_on_cpu.py @@ -45,9 +45,9 @@ def test_patch_counts_various_grids_correct_products(self, grid_thw, expected): counts = get_image_patch_counts(torch.tensor(grid_thw)) assert counts == expected - def test_patch_counts_empty_input_returns_empty_list(self): - counts = get_image_patch_counts(torch.empty((0, 3), dtype=torch.long)) - assert counts == [] + def test_patch_counts_empty_input_raises_value_error(self): + with pytest.raises(ValueError, match="grid_thw is empty"): + get_image_patch_counts(torch.empty((0, 3), dtype=torch.long)) class TestGetImageEmbeddingCounts: @@ -64,9 +64,9 @@ def test_embedding_counts_with_merge_size_correct(self, grid_thw, merge_size, ex counts = get_image_embedding_counts(torch.tensor(grid_thw), merge_size) assert counts == expected - def test_embedding_counts_empty_input_returns_empty_list(self): - counts = get_image_embedding_counts(torch.empty((0, 3), dtype=torch.long)) - assert counts == [] + def test_embedding_counts_empty_input_raises_value_error(self): + with pytest.raises(ValueError, match="grid_thw is empty"): + get_image_embedding_counts(torch.empty((0, 3), dtype=torch.long)) class TestAssignImagesToDpRanks: @@ -96,10 +96,17 @@ def test_assign_fewer_images_than_ranks_all_assigned(self): all_assigned.update(a) assert all_assigned == {0, 1} - def test_assign_empty_input_returns_empty_lists(self): - assignments, loads = assign_images_to_dp_ranks([], dp_size=4) - assert all(len(a) == 0 for a in assignments) - assert all(load == 0 for load in loads) + def 
test_assign_empty_input_raises_value_error(self): + with pytest.raises(ValueError, match="patch_counts is empty"): + assign_images_to_dp_ranks([], dp_size=4) + + def test_assign_zero_dp_size_raises_value_error(self): + with pytest.raises(ValueError, match="dp_size must be positive"): + assign_images_to_dp_ranks([100], dp_size=0) + + def test_assign_negative_dp_size_raises_value_error(self): + with pytest.raises(ValueError, match="dp_size must be positive"): + assign_images_to_dp_ranks([100], dp_size=-1) def test_assign_image_order_preserved_contiguous(self): assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2) @@ -178,6 +185,20 @@ def test_prepare_local_inputs_grid_thw_values_preserved(self): assert torch.equal(local_grid[0], grid_thw[0]) assert torch.equal(local_grid[1], grid_thw[1]) + def test_prepare_out_of_range_dp_rank_raises_value_error(self): + pixel_values = torch.randn(100, 768) + grid_thw = torch.tensor([[1, 10, 10]]) + image_assignments = [[0]] + with pytest.raises(ValueError, match="dp_rank=1 out of range"): + prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1) + + def test_prepare_negative_dp_rank_raises_value_error(self): + pixel_values = torch.randn(100, 768) + grid_thw = torch.tensor([[1, 10, 10]]) + image_assignments = [[0]] + with pytest.raises(ValueError, match="dp_rank=-1 out of range"): + prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=-1) + class TestGatherVisionEmbeddings: def test_gather_embeddings_none_group_returns_input_unchanged(self): From e39b496bbc8129faf448aea83ff5e238393390ac Mon Sep 17 00:00:00 2001 From: aoshen524 Date: Thu, 5 Mar 2026 15:23:47 +0900 Subject: [PATCH 8/8] feat(vision_dp): add FSDP2 strategy support for Vision DP Call apply_vision_dp_patch() in fsdp2_strategy.py after set_upg_manager(), mirroring the existing pattern in deepspeed_strategy.py. This ensures Vision DP works correctly with FSDP2, not just DeepSpeed. 
Co-Authored-By: Claude Opus 4.6 --- roll/distributed/strategy/fsdp2_strategy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/roll/distributed/strategy/fsdp2_strategy.py b/roll/distributed/strategy/fsdp2_strategy.py index 389ff9cb..15940f1b 100644 --- a/roll/distributed/strategy/fsdp2_strategy.py +++ b/roll/distributed/strategy/fsdp2_strategy.py @@ -35,7 +35,7 @@ from roll.third_party.fsdp2.model_update import FSDP2WeightUpdater from roll.utils.checkpoint_manager import CheckpointManager, download_model from roll.utils.collective import collective -from roll.utils.context_parallel import get_ulysses_group, set_upg_manager +from roll.utils.context_parallel import apply_vision_dp_patch, get_ulysses_group, set_upg_manager from roll.utils.context_parallel.autograd_gather import ulysses_gather from roll.utils.context_parallel.rmpad_ulysses import ( gather_outputs_and_unpad, @@ -570,6 +570,8 @@ def _prepare_fsdp2_model( rank=global_rank, world_size=world_size, ) + if self.worker_config.model_args.vision_dp: + apply_vision_dp_patch() else: cp_size = 1