4 changes: 4 additions & 0 deletions roll/configs/model_args.py
@@ -119,6 +119,10 @@ class ModelArguments(LoraArguments):
default=1,
metadata={"help": "The group size for Ulysses attention."},
)
vision_dp: bool = field(
default=False,
metadata={"help": "Enable Vision DP: distribute ViT across Ulysses SP ranks."},
)

def __post_init__(self):
def split_arg(arg):
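For context, a minimal sketch of how the new flag pairs with the Ulysses group size, assuming the remaining ModelArguments fields have usable defaults (any other required fields are omitted here; in practice these flags would normally come from the worker config):

from roll.configs.model_args import ModelArguments

# Illustrative only: vision_dp has an effect only when ulysses_size > 1,
# since the ViT work is distributed across the Ulysses SP ranks.
model_args = ModelArguments(
    ulysses_size=4,
    vision_dp=True,
)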
6 changes: 5 additions & 1 deletion roll/distributed/strategy/deepspeed_strategy.py
@@ -23,7 +23,7 @@
from roll.third_party.deepspeed.model_update import DeepSpeedWeightUpdater
from roll.third_party.deepspeed.offload_states_patch import bind_deepspeed_offload_states_func
from roll.utils.collective import collective
from roll.utils.context_parallel import get_ulysses_group, set_upg_manager
from roll.utils.context_parallel import apply_vision_dp_patch, get_ulysses_group, set_upg_manager
from roll.utils.deepspeed_utils import get_optimizer_grouped_parameters
from roll.utils.functionals import append_to_dict, entropy_from_logits, log_probs_from_logits
from roll.utils.constants import IGNORE_INDEX
@@ -69,6 +69,8 @@ def initialize(self, model_provider):
if (cp_size := self.worker_config.model_args.ulysses_size) > 1:
if current_platform.apply_ulysses_patch() is not None:
set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size)
if self.worker_config.model_args.vision_dp:
apply_vision_dp_patch()
else:
cp_size = 1

@@ -332,6 +334,8 @@ def initialize(self, model_provider):
if (cp_size := self.worker_config.model_args.ulysses_size) > 1:
current_platform.apply_ulysses_patch()
set_upg_manager(ulysses_size=cp_size, rank=global_rank, world_size=world_size)
if self.worker_config.model_args.vision_dp:
apply_vision_dp_patch()
Contributor

It seems vision_dp also suits the FSDP strategy in the same way, and apply_vision_dp_patch has to be called manually since it is not included in apply_ulysses_patch. Could you please support it in fsdp_strategy too?

Author

@aoshen524 aoshen524 Mar 5, 2026


sure, done.


self.worker.rank_info.dp_rank = global_rank // cp_size
self.worker.rank_info.dp_size = world_size // cp_size
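A quick sanity check of the rank mapping above, with illustrative values only:

# world_size = 8, ulysses_size = 4 -> cp_size = 4
#   dp_size = 8 // 4 = 2 data-parallel groups, each spanning one SP group
#   global ranks 0-3 -> dp_rank 0, global ranks 4-7 -> dp_rank 1
world_size, cp_size = 8, 4
for global_rank in range(world_size):
    print(global_rank, global_rank // cp_size, world_size // cp_size)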
4 changes: 3 additions & 1 deletion roll/distributed/strategy/fsdp2_strategy.py
@@ -35,7 +35,7 @@
from roll.third_party.fsdp2.model_update import FSDP2WeightUpdater
from roll.utils.checkpoint_manager import CheckpointManager, download_model
from roll.utils.collective import collective
from roll.utils.context_parallel import get_ulysses_group, set_upg_manager
from roll.utils.context_parallel import apply_vision_dp_patch, get_ulysses_group, set_upg_manager
from roll.utils.context_parallel.autograd_gather import ulysses_gather
from roll.utils.context_parallel.rmpad_ulysses import (
gather_outputs_and_unpad,
@@ -570,6 +570,8 @@ def _prepare_fsdp2_model(
rank=global_rank,
world_size=world_size,
)
if self.worker_config.model_args.vision_dp:
apply_vision_dp_patch()
else:
cp_size = 1

17 changes: 15 additions & 2 deletions roll/utils/context_parallel/__init__.py
@@ -1,4 +1,17 @@
from roll.utils.context_parallel.globals import get_ulysses_group, set_upg_manager
from roll.utils.context_parallel.monkey_patch import apply_ulysses_patch, unapply_ulysses_patch
from roll.utils.context_parallel.monkey_patch import (
apply_ulysses_patch,
apply_vision_dp_patch,
unapply_ulysses_patch,
unapply_vision_dp_patch,
)

__all__ = ["set_upg_manager", "get_ulysses_group", "apply_ulysses_patch", "unapply_ulysses_patch"]

__all__ = [
"set_upg_manager",
"get_ulysses_group",
"apply_ulysses_patch",
"apply_vision_dp_patch",
"unapply_ulysses_patch",
"unapply_vision_dp_patch",
]
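A minimal standalone sketch of the newly exported helpers; the strategy classes above wire this up automatically, the names come straight from __all__, and the set_upg_manager arguments mirror the strategy call sites (assumes torch.distributed is already initialized, e.g. via torchrun):

import torch.distributed as dist

from roll.utils.context_parallel import (
    apply_vision_dp_patch,
    set_upg_manager,
    unapply_vision_dp_patch,
)

set_upg_manager(ulysses_size=4, rank=dist.get_rank(), world_size=dist.get_world_size())
apply_vision_dp_patch()    # idempotent: each supported vision class is patched at most once

# ... build a Qwen2-VL / Qwen2.5-VL / Qwen3-VL model and run forward passes ...

unapply_vision_dp_patch()  # restores the original VisionTransformer.forward methods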
88 changes: 88 additions & 0 deletions roll/utils/context_parallel/monkey_patch.py
@@ -13,6 +13,9 @@
else:
old_update_causal_mask = None

# Store original vision forwards for unapply
_original_vision_forwards = {}


def apply_ulysses_patch():
from .ulysses_attention import _flash_attention_forward, _update_causal_mask
@@ -35,6 +38,90 @@ def apply_ulysses_patch():
return patch_info


def _patch_vision_class(cls, key, class_name):
"""Patch a single VisionTransformer class with Vision DP, with idempotency guard."""
from .vision_dp import create_dp_vision_forward

if getattr(cls, "_vision_dp_patched", False):
return
original = cls.forward
_original_vision_forwards[key] = original
cls.forward = create_dp_vision_forward(original)
cls._vision_dp_patched = True
logger.info(f"Monkey patch {class_name}.forward for Vision DP")


def apply_vision_dp_patch():
"""Patch VisionTransformer.forward for Vision Data Parallel.

Distributes whole images across Ulysses SP ranks for parallelized ViT computation.
Each rank processes 1/sp_size of images, then all-gathers embeddings.

This reduces ViT peak memory by ~sp_size x (e.g. SP=4 -> ~4x reduction).
Safe to call multiple times -- each class is only patched once.
"""
# Patch Qwen2-VL VisionTransformer
try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

_patch_vision_class(Qwen2VisionTransformerPretrainedModel, "qwen2_vl", "Qwen2VisionTransformerPretrainedModel")
except ImportError as e:
logger.debug(f"Qwen2-VL not available for Vision DP patch: {e}")

# Patch Qwen2.5-VL VisionTransformer
try:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
)

_patch_vision_class(
Qwen2_5_VisionTransformerPretrainedModel, "qwen2_5_vl", "Qwen2_5_VisionTransformerPretrainedModel"
)
except ImportError as e:
logger.debug(f"Qwen2.5-VL not available for Vision DP patch: {e}")

# Patch Qwen3-VL VisionModel
try:
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

_patch_vision_class(Qwen3VLVisionModel, "qwen3_vl", "Qwen3VLVisionModel")
except ImportError as e:
logger.debug(f"Qwen3-VL not available for Vision DP patch: {e}")


def _unapply_vision_class(cls, key):
"""Restore a single VisionTransformer class, clearing the idempotency flag."""
if key in _original_vision_forwards:
cls.forward = _original_vision_forwards.pop(key)
cls._vision_dp_patched = False


def unapply_vision_dp_patch():
"""Restore original VisionTransformer.forward methods."""
try:
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

_unapply_vision_class(Qwen2VisionTransformerPretrainedModel, "qwen2_vl")
except ImportError:
pass

try:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
)

_unapply_vision_class(Qwen2_5_VisionTransformerPretrainedModel, "qwen2_5_vl")
except ImportError:
pass

try:
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

_unapply_vision_class(Qwen3VLVisionModel, "qwen3_vl")
except ImportError:
pass


def unapply_ulysses_patch():
global old_flash_attention_forward, old_update_causal_mask
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = old_flash_attention_forward
@@ -47,3 +134,4 @@ def unapply_ulysses_patch():
unapply_hf_flash_attention_ulysses_patch()
except Exception:
pass
unapply_vision_dp_patch()
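For readers new to the scheme, a toy sketch of the idea behind create_dp_vision_forward (defined in vision_dp, which is not part of this diff): split whole images across the SP ranks, run the ViT only on the local shard, then all-gather the embeddings. This is illustrative, not the actual implementation; it assumes the image count divides evenly by sp_size, that every image yields the same number of vision tokens, and that get_ulysses_group returns the current SP process group.

import torch
import torch.distributed as dist

from roll.utils.context_parallel import get_ulysses_group


def dp_vision_forward_sketch(vision_model, images):
    """Toy Vision DP: each SP rank encodes a contiguous slice of the images,
    then the embeddings are all-gathered so every rank ends up with all of them."""
    group = get_ulysses_group()
    sp_size = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # Rank r encodes images [r * chunk, (r + 1) * chunk); ViT activation memory
    # per rank drops by roughly a factor of sp_size.
    chunk = len(images) // sp_size
    local = images[rank * chunk:(rank + 1) * chunk]
    local_embeds = torch.cat([vision_model(img) for img in local], dim=0)

    # All-gather preserves rank order, so concatenation restores the original
    # image order under the contiguous split above.
    gathered = [torch.empty_like(local_embeds) for _ in range(sp_size)]
    dist.all_gather(gathered, local_embeds, group=group)
    return torch.cat(gathered, dim=0)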