diff --git a/roll/configs/model_args.py b/roll/configs/model_args.py index c9b8b844..63664587 100644 --- a/roll/configs/model_args.py +++ b/roll/configs/model_args.py @@ -119,6 +119,10 @@ class ModelArguments(LoraArguments): default=1, metadata={"help": "The group size for Ulysses attention."}, ) + vision_dp: bool = field( + default=False, + metadata={"help": "Enable Vision DP: distribute ViT across Ulysses SP ranks."}, + ) def __post_init__(self): def split_arg(arg): diff --git a/roll/distributed/strategy/deepspeed_strategy.py b/roll/distributed/strategy/deepspeed_strategy.py index 58b7e1b4..498b3c6f 100644 --- a/roll/distributed/strategy/deepspeed_strategy.py +++ b/roll/distributed/strategy/deepspeed_strategy.py @@ -23,7 +23,7 @@ from roll.third_party.deepspeed.model_update import DeepSpeedWeightUpdater from roll.third_party.deepspeed.offload_states_patch import bind_deepspeed_offload_states_func from roll.utils.collective import collective -from roll.utils.context_parallel import get_ulysses_group, set_upg_manager +from roll.utils.context_parallel import apply_vision_dp_patch, get_ulysses_group, set_upg_manager from roll.utils.deepspeed_utils import get_optimizer_grouped_parameters from roll.utils.functionals import append_to_dict, entropy_from_logits, log_probs_from_logits from roll.utils.constants import IGNORE_INDEX @@ -69,6 +69,8 @@ def initialize(self, model_provider): if (cp_size := self.worker_config.model_args.ulysses_size) > 1: if current_platform.apply_ulysses_patch() is not None: set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size) + if self.worker_config.model_args.vision_dp: + apply_vision_dp_patch() else: cp_size = 1 @@ -332,6 +334,8 @@ def initialize(self, model_provider): if (cp_size := self.worker_config.model_args.ulysses_size) > 1: current_platform.apply_ulysses_patch() set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size) + if self.worker_config.model_args.vision_dp: + apply_vision_dp_patch() 
self.worker.rank_info.dp_rank = global_rank // cp_size self.worker.rank_info.dp_size = world_size // cp_size diff --git a/roll/distributed/strategy/fsdp2_strategy.py b/roll/distributed/strategy/fsdp2_strategy.py index 389ff9cb..15940f1b 100644 --- a/roll/distributed/strategy/fsdp2_strategy.py +++ b/roll/distributed/strategy/fsdp2_strategy.py @@ -35,7 +35,7 @@ from roll.third_party.fsdp2.model_update import FSDP2WeightUpdater from roll.utils.checkpoint_manager import CheckpointManager, download_model from roll.utils.collective import collective -from roll.utils.context_parallel import get_ulysses_group, set_upg_manager +from roll.utils.context_parallel import apply_vision_dp_patch, get_ulysses_group, set_upg_manager from roll.utils.context_parallel.autograd_gather import ulysses_gather from roll.utils.context_parallel.rmpad_ulysses import ( gather_outputs_and_unpad, @@ -570,6 +570,8 @@ def _prepare_fsdp2_model( rank=global_rank, world_size=world_size, ) + if self.worker_config.model_args.vision_dp: + apply_vision_dp_patch() else: cp_size = 1 diff --git a/roll/utils/context_parallel/__init__.py b/roll/utils/context_parallel/__init__.py index 8112b8d2..cd3f0101 100644 --- a/roll/utils/context_parallel/__init__.py +++ b/roll/utils/context_parallel/__init__.py @@ -1,4 +1,17 @@ from roll.utils.context_parallel.globals import get_ulysses_group, set_upg_manager -from roll.utils.context_parallel.monkey_patch import apply_ulysses_patch, unapply_ulysses_patch +from roll.utils.context_parallel.monkey_patch import ( + apply_ulysses_patch, + apply_vision_dp_patch, + unapply_ulysses_patch, + unapply_vision_dp_patch, +) -__all__ = ["set_upg_manager", "get_ulysses_group", "apply_ulysses_patch", "unapply_ulysses_patch"] + +__all__ = [ + "set_upg_manager", + "get_ulysses_group", + "apply_ulysses_patch", + "apply_vision_dp_patch", + "unapply_ulysses_patch", + "unapply_vision_dp_patch", +] diff --git a/roll/utils/context_parallel/monkey_patch.py 
# Store original vision forwards for unapply
_original_vision_forwards = {}


def _patch_vision_class(cls, key, class_name):
    """Patch a single VisionTransformer class with Vision DP, with idempotency guard."""
    from .vision_dp import create_dp_vision_forward

    if getattr(cls, "_vision_dp_patched", False):
        return
    original = cls.forward
    _original_vision_forwards[key] = original
    cls.forward = create_dp_vision_forward(original)
    cls._vision_dp_patched = True
    logger.info(f"Monkey patch {class_name}.forward for Vision DP")


def apply_vision_dp_patch():
    """Patch VisionTransformer.forward for Vision Data Parallel.

    Distributes whole images across Ulysses SP ranks for parallelized ViT computation.
    Each rank processes 1/sp_size of images, then all-gathers embeddings.

    This reduces ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x reduction).
    Safe to call multiple times -- each class is only patched once.
    """
    # Patch Qwen2-VL VisionTransformer
    try:
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

        _patch_vision_class(Qwen2VisionTransformerPretrainedModel, "qwen2_vl", "Qwen2VisionTransformerPretrainedModel")
    except ImportError as e:
        logger.debug(f"Qwen2-VL not available for Vision DP patch: {e}")

    # Patch Qwen2.5-VL VisionTransformer
    try:
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
            Qwen2_5_VisionTransformerPretrainedModel,
        )

        _patch_vision_class(
            Qwen2_5_VisionTransformerPretrainedModel, "qwen2_5_vl", "Qwen2_5_VisionTransformerPretrainedModel"
        )
    except ImportError as e:
        logger.debug(f"Qwen2.5-VL not available for Vision DP patch: {e}")

    # Patch Qwen3-VL VisionModel
    try:
        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

        _patch_vision_class(Qwen3VLVisionModel, "qwen3_vl", "Qwen3VLVisionModel")
    except ImportError as e:
        logger.debug(f"Qwen3-VL not available for Vision DP patch: {e}")


def _unapply_vision_class(cls, key):
    """Restore a single VisionTransformer class, clearing the idempotency flag."""
    if key in _original_vision_forwards:
        cls.forward = _original_vision_forwards.pop(key)
        cls._vision_dp_patched = False


def unapply_vision_dp_patch():
    """Restore original VisionTransformer.forward methods."""
    try:
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

        _unapply_vision_class(Qwen2VisionTransformerPretrainedModel, "qwen2_vl")
    except ImportError:
        pass

    try:
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
            Qwen2_5_VisionTransformerPretrainedModel,
        )

        _unapply_vision_class(Qwen2_5_VisionTransformerPretrainedModel, "qwen2_5_vl")
    except ImportError:
        pass

    try:
        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

        _unapply_vision_class(Qwen3VLVisionModel, "qwen3_vl")
    except ImportError:
        pass
+""" + +import torch +import torch.distributed as dist +from torch.autograd import Function + +from roll.utils.context_parallel.globals import get_ulysses_group, get_ulysses_size + + +def get_image_patch_counts(grid_thw: torch.Tensor) -> list[int]: + """Return [t*h*w for each image] from a [num_images, 3] grid_thw tensor.""" + if grid_thw.numel() == 0: + raise ValueError("grid_thw is empty — Vision DP should only be called when images are present") + return (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).tolist() + + +def get_image_embedding_counts(grid_thw: torch.Tensor, spatial_merge_size: int = 1) -> list[int]: + """Return per-image embedding counts after spatial merging: t * (h/merge) * (w/merge).""" + if grid_thw.numel() == 0: + raise ValueError("grid_thw is empty — Vision DP should only be called when images are present") + + if spatial_merge_size == 1: + return get_image_patch_counts(grid_thw) + + # Apply spatial merging: h and w are divided by spatial_merge_size + t = grid_thw[:, 0] + h = grid_thw[:, 1] // spatial_merge_size + w = grid_thw[:, 2] // spatial_merge_size + return (t * h * w).tolist() + + +def assign_images_to_dp_ranks( + patch_counts: list[int], + dp_size: int, +) -> tuple[list[list[int]], list[int]]: + """Assign whole images to DP ranks via greedy contiguous bin-packing. + + Returns (image_assignments, rank_patch_counts). Images are kept contiguous + so the gather result needs no reordering. 
+ """ + if dp_size <= 0: + raise ValueError(f"dp_size must be positive, got {dp_size}") + + num_images = len(patch_counts) + if num_images == 0: + raise ValueError("patch_counts is empty — Vision DP should only be called when images are present") + + image_assignments: list[list[int]] = [[] for _ in range(dp_size)] + rank_loads = [0] * dp_size + + remaining_patches = sum(patch_counts) + img_idx = 0 + for rank in range(dp_size): + remaining_ranks = dp_size - rank + remaining_images = num_images - img_idx + + if remaining_images <= 0: + break + + # Dynamic target: distribute remaining patches evenly among remaining ranks + target = remaining_patches / remaining_ranks + + # Must leave at least 1 image for each remaining rank + max_images = remaining_images - (remaining_ranks - 1) + + # Greedily add images until we reach the target load or hit the max + count = 0 + while img_idx < num_images and count < max_images: + image_assignments[rank].append(img_idx) + rank_loads[rank] += patch_counts[img_idx] + img_idx += 1 + count += 1 + + # Stop early once we've reached the target (always take at least 1) + if rank_loads[rank] >= target: + break + + remaining_patches -= rank_loads[rank] + + return image_assignments, rank_loads + + +def prepare_local_vision_inputs( + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + image_assignments: list[list[int]], + dp_rank: int, +) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + """Extract pixel values and grid_thw for this DP rank's assigned images. + + Exploits contiguous assignment: a single slice instead of per-image cat. 
+ """ + if dp_rank < 0 or dp_rank >= len(image_assignments): + raise ValueError( + f"dp_rank={dp_rank} out of range for image_assignments with " + f"{len(image_assignments)} ranks" + ) + + local_indices = image_assignments[dp_rank] + + if len(local_indices) == 0: + return ( + torch.empty( + (0, pixel_values.shape[1]) if pixel_values.dim() > 1 else (0,), + dtype=pixel_values.dtype, + device=pixel_values.device, + ), + torch.empty((0, 3), dtype=grid_thw.dtype, device=grid_thw.device), + [], + ) + + # local_indices are contiguous (e.g. [2, 3, 4]), so use tensor slicing + first_img_idx = local_indices[0] + last_img_idx = local_indices[-1] + + # Compute patch offsets using cumsum (grid_thw may be on CPU or GPU) + patch_counts = grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2] + offsets = torch.cat( + ( + torch.zeros(1, device=grid_thw.device, dtype=patch_counts.dtype), + torch.cumsum(patch_counts, dim=0), + ) + ) + + start_patch = int(offsets[first_img_idx].item()) + end_patch = int(offsets[last_img_idx + 1].item()) + + local_pixel_values = pixel_values[start_patch:end_patch] + local_grid_thw = grid_thw[first_img_idx : last_img_idx + 1] + + # Cross-check: verify extracted slice matches independently computed patch counts + independent_counts = get_image_patch_counts(local_grid_thw) + expected_patches = sum(independent_counts) + assert local_pixel_values.shape[0] == expected_patches, ( + f"[Vision DP] Local patch count mismatch: " + f"extracted={local_pixel_values.shape[0]}, expected={expected_patches}, " + f"local_indices={local_indices}" + ) + + return local_pixel_values, local_grid_thw, local_indices + + +class GatherVisionEmbeddings(Function): + """ + All-gather vision embeddings with gradient support. + + Since images are assigned contiguously (rank 0 gets [0,1], rank 1 gets [2,3], etc.), + we can simply concat gathered results without reordering. 
+ + Forward: all_gather + remove padding + concat + Backward: all_reduce(SUM) to aggregate gradients from all sequence shards, + then slice to extract this rank's image gradients + """ + + @staticmethod + def forward( + ctx, + local_embeddings: torch.Tensor, + dp_group, + all_counts: list[int], + ) -> torch.Tensor: + dp_size = dist.get_world_size(dp_group) + if dp_size <= 1: + raise RuntimeError( + "GatherVisionEmbeddings.forward called with dp_size=1. " + "Caller should short-circuit before reaching here." + ) + dp_rank = dist.get_rank(dp_group) + ctx.dp_size = dp_size + ctx.dp_group = dp_group + ctx.all_counts = all_counts + ctx.dp_rank = dp_rank + + if not all_counts or len(all_counts) != dp_size: + raise ValueError( + f"all_counts length ({len(all_counts) if all_counts else 0}) " + f"must equal dp_size ({dp_size})" + ) + + max_count = max(all_counts) + if max_count == 0: + raise RuntimeError( + "all_counts are all zero — Vision DP gather should not be called " + "when no images are present" + ) + + hidden_size = local_embeddings.shape[1] if local_embeddings.dim() > 1 else 1 + + # Pad to same length for all_gather + if local_embeddings.shape[0] < max_count: + pad_size = max_count - local_embeddings.shape[0] + padding = torch.zeros( + (pad_size, hidden_size), + dtype=local_embeddings.dtype, + device=local_embeddings.device, + ) + local_padded = torch.cat([local_embeddings, padding], dim=0) + else: + local_padded = local_embeddings + + # All-gather + gathered = [torch.empty_like(local_padded) for _ in range(dp_size)] + dist.all_gather(gathered, local_padded, group=dp_group) + + # Remove padding and concat (no reordering needed - contiguous assignment) + result_chunks = [gathered[r][: all_counts[r]] for r in range(dp_size)] + result = torch.cat(result_chunks, dim=0) + + return result + + @staticmethod + def backward(ctx, grad_output): + dp_size = ctx.dp_size + assert dp_size > 1, ( + f"GatherVisionEmbeddings.backward reached with dp_size={dp_size}. 
" + "Forward should never be called with dp_size<=1." + ) + + all_counts = ctx.all_counts + dp_rank = ctx.dp_rank + dp_group = ctx.dp_group + + # all_reduce(SUM) aggregates partial gradients from all SP ranks: + # each rank only has non-zero grad for vision tokens in its sequence shard. + if not grad_output.is_cuda: + raise RuntimeError( + "GatherVisionEmbeddings.backward requires CUDA tensors (NCCL backend). " + f"Got device={grad_output.device}" + ) + # NCCL all_reduce requires contiguous tensors. In the real training path + # (masked_scatter_backward → view), grad is already contiguous (no-op). + # Kept as defensive guard against upstream autograd changes. + grad = grad_output.contiguous() + dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=dp_group) + + # Extract gradients for this rank's images (contiguous slice) + start = sum(all_counts[:dp_rank]) + end = start + all_counts[dp_rank] + local_grad = grad[start:end] + + return local_grad, None, None + + +def gather_vision_embeddings(local_embeddings, dp_group, all_counts: list[int]): + """All-gather vision embeddings from all DP ranks with gradient support.""" + dp_group = get_ulysses_group() if dp_group is None else dp_group + if dp_group is None or dist.get_world_size(dp_group) == 1: + return local_embeddings + return GatherVisionEmbeddings.apply(local_embeddings, dp_group, all_counts) + + +def create_dp_vision_forward(original_forward): + """Wrap VisionTransformer.forward for Vision DP (Data Parallel across SP ranks). + + Strategy: + 1. Distribute whole images to SP ranks (not patches within images) + 2. Each rank processes its assigned images independently + 3. All-gather embeddings at the end (contiguous assignment, no reordering) + + Gradient correctness: after all-gather in forward, each SP rank's inputs_embeds + contains vision tokens from ALL images. But Ulysses gives each rank only its + sequence shard. In backward, each rank only has non-zero gradient for vision + tokens in its own shard. 
The all_reduce(SUM) in GatherVisionEmbeddings.backward + aggregates partial gradients from all ranks, recovering the complete gradient. + """ + + def dp_vision_forward(self, hidden_states, grid_thw, **kwargs): + dp_size = get_ulysses_size() + if dp_size is None or dp_size <= 1: + raise RuntimeError( + f"sp_size={dp_size}, Vision DP should not be active — " + "monkey-patch is only applied when sp_size > 1" + ) + + dp_group = get_ulysses_group() + dp_rank = dist.get_rank(dp_group) + + # Move grid_thw to CPU once to avoid repeated GPU->CPU syncs in + # metadata helpers (grid_thw is a tiny [num_images, 3] tensor). + grid_thw_cpu = grid_thw.cpu() + + # Step 1: Get image assignment based on patch counts + patch_counts = get_image_patch_counts(grid_thw_cpu) + total_patches = sum(patch_counts) + + assert hidden_states.shape[0] == total_patches, ( + f"[Vision DP] Input patch count mismatch: " + f"hidden_states.shape[0]={hidden_states.shape[0]}, " + f"sum(grid_thw products)={total_patches}, " + f"grid_thw.shape={grid_thw.shape}" + ) + + # Get spatial_merge_size from merger (VLMs like Qwen use merger to reduce embeddings) + spatial_merge_size = 1 + if hasattr(self, "merger") and hasattr(self.merger, "spatial_merge_size"): + spatial_merge_size = self.merger.spatial_merge_size + elif hasattr(self, "spatial_merge_size"): + spatial_merge_size = self.spatial_merge_size + + # Calculate embedding counts (after merger) for gather verification + embedding_counts = get_image_embedding_counts(grid_thw_cpu, spatial_merge_size) + total_embeddings = sum(embedding_counts) + + image_assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size) + + # Step 2: Extract local inputs + local_pixels, local_grid_thw, local_indices = prepare_local_vision_inputs( + hidden_states, grid_thw, image_assignments, dp_rank + ) + + # Detect Qwen3-VL deepstack: model attribute, not return type, + # because empty ranks don't call original_forward and can't inspect the return. 
+ has_deepstack = hasattr(self, "deepstack_merger_list") + + # Step 3: Process local images + if local_pixels.shape[0] > 0: + local_embeddings = original_forward(self, local_pixels, local_grid_thw, **kwargs) + else: + # This rank has no images, create empty tensor with correct hidden size + hidden_size = getattr(getattr(self, "config", None), "out_hidden_size", None) + if hidden_size is None: + raise RuntimeError( + f"Cannot determine hidden_size: self.config.out_hidden_size not found. " + f"Model type: {type(self).__name__}" + ) + + local_embeddings = torch.empty( + (0, hidden_size), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + # Empty rank must participate in autograd for backward all_reduce + local_embeddings.requires_grad_() + + # Unpack Qwen3-VL deepstack: forward returns (embeddings, list[3 × Tensor]) + local_deepstack = None + if has_deepstack: + if isinstance(local_embeddings, tuple): + local_embeddings, local_deepstack = local_embeddings[0], local_embeddings[1] + else: + # Empty rank: create matching empty deepstack tensors + num_deepstack = len(self.deepstack_merger_list) + h = local_embeddings.shape[1] + local_deepstack = [ + torch.empty( + (0, h), dtype=hidden_states.dtype, device=hidden_states.device + ) + for _ in range(num_deepstack) + ] + + # Step 4: All-gather (contiguous assignment, no reordering needed) + # Compute per-rank embedding counts locally (grid_thw is replicated on all ranks) + all_counts = [ + sum(embedding_counts[i] for i in image_assignments[r]) + for r in range(dp_size) + ] + all_embeddings = GatherVisionEmbeddings.apply( + local_embeddings, dp_group, all_counts + ) + + assert all_embeddings.shape[0] == total_embeddings, ( + f"[Vision DP] Output embedding count mismatch: " + f"all_embeddings.shape[0]={all_embeddings.shape[0]}, " + f"expected={total_embeddings}" + ) + + # Step 5: All-gather deepstack embeddings (all ranks must participate) + if local_deepstack is not None: + gathered_deepstack = [ + 
GatherVisionEmbeddings.apply(ds, dp_group, all_counts) + for ds in local_deepstack + ] + return all_embeddings, gathered_deepstack + + return all_embeddings + + return dp_vision_forward diff --git a/tests/utils/test_vision_dp_on_cpu.py b/tests/utils/test_vision_dp_on_cpu.py new file mode 100644 index 00000000..0255013e --- /dev/null +++ b/tests/utils/test_vision_dp_on_cpu.py @@ -0,0 +1,250 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Alibaba Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for Vision Data Parallel utilities (CPU-only, no distributed). + +Ported from verl (https://github.com/verl-project/verl/pull/5230). 
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Alibaba Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for Vision Data Parallel utilities (CPU-only, no distributed).

Ported from verl (https://github.com/verl-project/verl/pull/5230).

Test naming convention: test_<unit>_<scenario>_<expected>()
"""

import pytest
import torch

from roll.utils.context_parallel.vision_dp import (
    assign_images_to_dp_ranks,
    gather_vision_embeddings,
    get_image_embedding_counts,
    get_image_patch_counts,
    prepare_local_vision_inputs,
)


class TestGetImagePatchCounts:
    @pytest.mark.parametrize(
        "grid_thw,expected",
        [
            ([[2, 4, 4], [1, 2, 2], [1, 8, 8]], [32, 4, 64]),
            ([[1, 4, 4]], [16]),
            ([[4, 4, 4]], [64]),
        ],
        ids=["multi-image", "single-image", "video-frames"],
    )
    def test_patch_counts_various_grids_correct_products(self, grid_thw, expected):
        counts = get_image_patch_counts(torch.tensor(grid_thw))
        assert counts == expected

    def test_patch_counts_empty_input_raises_value_error(self):
        with pytest.raises(ValueError, match="grid_thw is empty"):
            get_image_patch_counts(torch.empty((0, 3), dtype=torch.long))


class TestGetImageEmbeddingCounts:
    @pytest.mark.parametrize(
        "grid_thw,merge_size,expected",
        [
            ([[1, 8, 8]], 1, [64]),
            ([[1, 8, 8]], 2, [16]),
            ([[1, 6, 6], [1, 4, 4]], 2, [9, 4]),
        ],
        ids=["no-merge", "merge-2", "multi-image-merge"],
    )
    def test_embedding_counts_with_merge_size_correct(self, grid_thw, merge_size, expected):
        counts = get_image_embedding_counts(torch.tensor(grid_thw), merge_size)
        assert counts == expected

    def test_embedding_counts_empty_input_raises_value_error(self):
        with pytest.raises(ValueError, match="grid_thw is empty"):
            get_image_embedding_counts(torch.empty((0, 3), dtype=torch.long))


class TestAssignImagesToDpRanks:
    @pytest.mark.parametrize(
        "patch_counts,dp_size,expected_all_assigned",
        [
            ([100, 100, 100, 100], 2, True),
            ([100, 200, 300], 1, True),
            ([100, 100, 100, 100, 100, 100], 3, True),
        ],
        ids=["balanced-2ranks", "single-rank", "balanced-3ranks"],
    )
    def test_assign_all_images_distributed_correctly(self, patch_counts, dp_size, expected_all_assigned):
        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size)
        all_assigned = []
        for a in assignments:
            all_assigned.extend(a)
        assert sorted(all_assigned) == list(range(len(patch_counts)))
        assert sum(loads) == sum(patch_counts)

    def test_assign_fewer_images_than_ranks_all_assigned(self):
        assignments, loads = assign_images_to_dp_ranks([100, 200], dp_size=4)
        non_empty = sum(1 for a in assignments if len(a) > 0)
        assert non_empty == 2
        all_assigned = set()
        for a in assignments:
            all_assigned.update(a)
        assert all_assigned == {0, 1}

    def test_assign_empty_input_raises_value_error(self):
        with pytest.raises(ValueError, match="patch_counts is empty"):
            assign_images_to_dp_ranks([], dp_size=4)

    def test_assign_zero_dp_size_raises_value_error(self):
        with pytest.raises(ValueError, match="dp_size must be positive"):
            assign_images_to_dp_ranks([100], dp_size=0)

    def test_assign_negative_dp_size_raises_value_error(self):
        with pytest.raises(ValueError, match="dp_size must be positive"):
            assign_images_to_dp_ranks([100], dp_size=-1)

    def test_assign_image_order_preserved_contiguous(self):
        assignments, _ = assign_images_to_dp_ranks([10, 20, 30, 40, 50], dp_size=2)
        for rank_assignment in assignments:
            assert rank_assignment == sorted(rank_assignment)

    def test_assign_load_balanced_unequal_patches_reduces_imbalance(self):
        """With unequal patch counts, greedy balancing should reduce imbalance."""
        patch_counts = [4096, 256, 256, 256]
        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2)
        all_assigned = []
        for a in assignments:
            all_assigned.extend(a)
        assert sorted(all_assigned) == [0, 1, 2, 3]
        max_load = max(loads)
        min_load = min(load for load in loads if load > 0)
        assert max_load / min_load < 8.0

    def test_assign_contiguous_coverage_all_dp_sizes(self):
        """All images are covered exactly once across ranks for various dp_size."""
        patch_counts = [10, 20, 30, 40, 50, 60, 70]
        for dp_size in [1, 2, 3, 4, 7]:
            assignments, _ = assign_images_to_dp_ranks(patch_counts, dp_size)
            all_indices = []
            for a in assignments:
                all_indices.extend(a)
            assert sorted(all_indices) == list(range(len(patch_counts)))


class TestPrepareLocalVisionInputs:
    def test_prepare_two_images_splits_correctly(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 6, 6], [1, 8, 8]])  # 36 + 64 = 100
        image_assignments = [[0], [1]]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0)
        assert pix.shape[0] == 36
        assert grid.shape[0] == 1
        assert indices == [0]
        assert torch.allclose(pix, pixel_values[:36])

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1)
        assert pix.shape[0] == 64
        assert grid.shape[0] == 1
        assert indices == [1]
        assert torch.allclose(pix, pixel_values[36:100])

    def test_prepare_multiple_contiguous_images_per_rank(self):
        pixel_values = torch.randn(200, 768)
        grid_thw = torch.tensor([[1, 5, 10]] * 4)  # 4 x 50 patches
        image_assignments = [[0, 1], [2, 3]]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0)
        assert pix.shape[0] == 100
        assert grid.shape[0] == 2
        assert indices == [0, 1]
        assert torch.allclose(pix, pixel_values[:100])

    def test_prepare_empty_rank_returns_empty_tensors(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 10, 10]])
        image_assignments = [[0], []]

        pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1)
        assert pix.shape[0] == 0
        assert grid.shape[0] == 0
        assert indices == []

    def test_prepare_local_inputs_grid_thw_values_preserved(self):
        pixel_values = torch.randn(150, 768)
        grid_thw = torch.tensor([[1, 5, 5], [2, 5, 5], [3, 5, 5]])  # 25 + 50 + 75
        image_assignments = [[0, 1], [2]]

        _, local_grid, _ = prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=0)
        assert local_grid.shape == (2, 3)
        assert torch.equal(local_grid[0], grid_thw[0])
        assert torch.equal(local_grid[1], grid_thw[1])

    def test_prepare_out_of_range_dp_rank_raises_value_error(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 10, 10]])
        image_assignments = [[0]]
        with pytest.raises(ValueError, match="dp_rank=1 out of range"):
            prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=1)

    def test_prepare_negative_dp_rank_raises_value_error(self):
        pixel_values = torch.randn(100, 768)
        grid_thw = torch.tensor([[1, 10, 10]])
        image_assignments = [[0]]
        with pytest.raises(ValueError, match="dp_rank=-1 out of range"):
            prepare_local_vision_inputs(pixel_values, grid_thw, image_assignments, dp_rank=-1)


class TestGatherVisionEmbeddings:
    def test_gather_embeddings_none_group_returns_input_unchanged(self):
        embeddings = torch.randn(10, 64)
        result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10])
        assert torch.equal(result, embeddings)

    def test_gather_embeddings_none_group_same_storage(self):
        """Single-rank group should short-circuit and return same tensor (not a copy)."""
        embeddings = torch.randn(10, 64)
        result = gather_vision_embeddings(embeddings, dp_group=None, all_counts=[10])
        assert result.data_ptr() == embeddings.data_ptr()


class TestIntegration:
    def test_full_workflow_all_patches_covered(self):
        grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8], [1, 4, 4], [1, 6, 6], [1, 4, 4]])
        total_patches = 16 + 64 + 16 + 36 + 16  # 148
        pixel_values = torch.randn(total_patches, 768)

        patch_counts = get_image_patch_counts(grid_thw)
        assert patch_counts == [16, 64, 16, 36, 16]

        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=2)
        all_assigned = []
        for a in assignments:
            all_assigned.extend(a)
        assert sorted(all_assigned) == [0, 1, 2, 3, 4]

        total_local_patches = 0
        for rank in range(2):
            pix, grid, indices = prepare_local_vision_inputs(pixel_values, grid_thw, assignments, dp_rank=rank)
            expected = sum(patch_counts[i] for i in indices)
            assert pix.shape[0] == expected
            assert grid.shape[0] == len(indices)
            total_local_patches += pix.shape[0]

        assert total_local_patches == total_patches

    def test_same_size_images_4_ranks_balanced(self):
        num_images = 50
        grid_thw = torch.tensor([[1, 8, 8]] * num_images)
        patch_counts = get_image_patch_counts(grid_thw)
        assignments, loads = assign_images_to_dp_ranks(patch_counts, dp_size=4)

        for rank in range(4):
            assert 12 <= len(assignments[rank]) <= 13
        for load in loads:
            assert load in [768, 832]