From 64b4a7cf106432813e222aaca502aea12c0d124b Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Thu, 26 Feb 2026 11:54:47 +0000 Subject: [PATCH 01/10] init support qwen3.5 --- xtuner/v1/model/base.py | 4 +- .../model/compose/qwen3_5/qwen3_5_config.py | 28 +++ xtuner/v1/model/dense/dense.py | 12 +- xtuner/v1/model/moe/moe.py | 18 +- xtuner/v1/model/moe/qwen3_5_text.py | 74 ++++++ xtuner/v1/model/moe/qwen3vl_text.py | 7 +- xtuner/v1/module/attention/__init__.py | 3 + xtuner/v1/module/attention/gate_deltanet.py | 214 ++++++++++++++++++ xtuner/v1/module/attention/mha.py | 29 ++- .../decoder_layer/dense_decoder_layer.py | 5 +- .../module/decoder_layer/moe_decoder_layer.py | 16 +- xtuner/v1/module/rms_norm/rms_norm.py | 17 +- xtuner/v1/module/rope/rope.py | 40 +++- xtuner/v1/ops/__init__.py | 2 +- xtuner/v1/ops/rms_norm/__init__.py | 27 ++- xtuner/v1/ops/rms_norm/protocol.py | 3 +- xtuner/v1/ops/rotary_emb.py | 31 ++- 17 files changed, 499 insertions(+), 31 deletions(-) create mode 100644 xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py create mode 100644 xtuner/v1/model/moe/qwen3_5_text.py create mode 100644 xtuner/v1/module/attention/gate_deltanet.py diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index 9f1e0f995..be58b9519 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -39,7 +39,7 @@ WeightWithDynamicTilewiseFloat8CastTensor, ) from xtuner.v1.loss import BaseLossContext -from xtuner.v1.module.attention import MHAConfig, MLAConfig +from xtuner.v1.module.attention import MHAConfig, MLAConfig, GateDeltaNetConfig from xtuner.v1.module.rope import RopeScalingConfig from xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory @@ -147,9 +147,11 @@ class TransformerConfig(XTunerBaseModelConfig): hidden_size: Annotated[int, Parameter(group="model")] intermediate_size: Annotated[int, Parameter(group="model")] 
rms_norm_eps: Annotated[float, Parameter(group="model")] + rms_norm_type: Annotated[str, Parameter(group="model")] = 'default' # default | zero_centered rope_theta: Annotated[float, Parameter(group="model")] # required by transformers's build rope hidden_act: Annotated[str, Parameter(group="model")] # key defined in `transformers.activations.ACT2CLS` attention: MLAConfig | MHAConfig + linear_attention: Annotated[GateDeltaNetConfig | None, Parameter(group="model")] = None mlp_bias: Annotated[bool, Parameter(group="model")] = False tie_word_embeddings: Annotated[bool, Parameter(group="model")] = False model_type: Annotated[str | None, Parameter(group="model")] = None # TODO: yehaochen maybe should be removed diff --git a/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py new file mode 100644 index 000000000..c982705af --- /dev/null +++ b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py @@ -0,0 +1,28 @@ +from xtuner.v1.model.base import TransformerConfig +from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig +from xtuner.v1.utils import get_logger + +from ..qwen3_vl.qwen3_vl_config import Qwen3VLVisionConfig, Qwen3VLProjectorConfig, Qwen3VLBaseConfig + +logger = get_logger() + +class Qwen3_5_VisionConfig(Qwen3VLVisionConfig): + deepstack_visual_indexes: list[int] = [] + +class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig): + deepstack_visual_indexes: list[int] = [] + +class Qwen3_5_BaseConfig(Qwen3VLBaseConfig): + vision_config: Qwen3_5_VisionConfig + projector_config: Qwen3_5_ProjectorConfig + text_config: TransformerConfig + + image_token_id: int = 248056 + video_token_id: int = 248057 + vision_start_token_id: int = 248053 + vision_end_token_id: int = 248054 + +class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig): + vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig() + projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig() + text_config: TransformerConfig = 
Qwen3_5_VLTextMoE35BA3BConfig() diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index cde398b38..e77b0ecba 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -53,7 +53,7 @@ class Dense(BaseModel): def __init__(self, config: TransformerConfig): super().__init__(config) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, type=config.rms_norm_type) self.lm_head = LMHead(config.hidden_size, config.vocab_size, bias=False) self.layers = self.build_layers(config) self.rotary_emb = self.build_rotary_embedding(config) @@ -117,13 +117,21 @@ def build_layers(self, config: TransformerConfig) -> nn.ModuleDict: # 这样可以保证部分 layer 被切掉后,idx 保持不变 layers = nn.ModuleDict() for layer_idx in range(config.num_hidden_layers): + if config.layers_type[layer_idx] in ["full_attention", "sliding_attention"]: + attention_config = config.attention + elif config.layers_type[layer_idx] == "linear_attention": + attention_config = config.linear_attention + assert attention_config is not None, "linear_attention config must be provided for linear_attention layer" + else: + raise ValueError(f"Unsupported layer type {config.layers_type[layer_idx]} at layer {layer_idx}. 
Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported.") + layers[str(layer_idx)] = DenseDecoderLayer( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, mlp_bias=config.mlp_bias, hidden_act=config.hidden_act, rms_norm_eps=config.rms_norm_eps, - attention_config=config.attention, + attention_config=attention_config, generate_config=config.generate_config, rope_scaling_cfg=config.rope_scaling_cfg, float8_cfg=config.float8_cfg, diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 962129e1d..c60796f50 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -116,6 +116,7 @@ class MoEConfig(TransformerConfig): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") n_routed_experts: Annotated[int, Parameter(group="moe")] n_shared_experts: Annotated[int, Parameter(group="moe")] + with_shared_expert_gate: bool = False # enable when n_shared_experts > 0 num_experts_per_tok: Annotated[int, Parameter(group="moe")] first_k_dense_replace: Annotated[int, Parameter(group="moe")] = 0 hidden_factor: Annotated[float, Parameter(group="moe")] = 1.0 @@ -166,7 +167,7 @@ def __init__(self, config: MoEConfig): else: self.ep_mesh = None - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, type=config.rms_norm_type) self.lm_head = LMHead(config.hidden_size, config.vocab_size, bias=False) self.layers = self.build_layers(config) @@ -593,6 +594,14 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: # 这样可以保证部分 layer 被切掉后,idx 保持不变 layers = nn.ModuleDict() for layer_idx in range(config.num_hidden_layers): + if config.layers_type[layer_idx] in ["full_attention", "sliding_attention"]: + attention_config = config.attention + elif config.layers_type[layer_idx] == "linear_attention": + attention_config = config.linear_attention + assert attention_config is not None, "linear_attention config must be 
provided for linear_attention layer" + else: + raise ValueError(f"Unsupported layer type {config.layers_type[layer_idx]} at layer {layer_idx}. Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported.") + if layer_idx < config.first_k_dense_replace: layers[str(layer_idx)] = DenseDecoderLayer( hidden_size=config.hidden_size, @@ -600,7 +609,8 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: mlp_bias=config.mlp_bias, hidden_act=config.hidden_act, rms_norm_eps=config.rms_norm_eps, - attention_config=config.attention, + rms_norm_type=config.rms_norm_type, + attention_config=attention_config, layer_type=config.layers_type[layer_idx], rope_scaling_cfg=config.rope_scaling_cfg, generate_config=config.generate_config, @@ -617,12 +627,14 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: moe_bias=config.moe_bias, hidden_act=config.hidden_act, rms_norm_eps=config.rms_norm_eps, + rms_norm_type=config.rms_norm_type, num_experts_per_tok=config.num_experts_per_tok, n_routed_experts=config.n_routed_experts, n_shared_experts=config.n_shared_experts, + with_shared_expert_gate=config.with_shared_expert_gate, hidden_factor=config.hidden_factor, layer_type=config.layers_type[layer_idx], - attention_config=config.attention, + attention_config=attention_config, rope_scaling_cfg=config.rope_scaling_cfg, generate_config=config.generate_config, router_config=config.router, diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py new file mode 100644 index 000000000..35c501040 --- /dev/null +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -0,0 +1,74 @@ + +from pydantic import computed_field +from typing import Literal +import re +from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig, ZLossConfig +from xtuner.v1.module.attention import MHAConfig, GateDeltaNetConfig +from xtuner.v1.module.router.greedy import GreedyRouterConfig +from xtuner.v1.module.rope import RopeScalingConfig + +from 
xtuner.v1.model.moe.moe import MoEConfig +from .qwen3vl_text import Qwen3VLTextMoE + + +class Qwen3_5_VLTextMoEConfig(MoEConfig): + with_shared_expert_gate: bool = True + rms_norm_type: Literal["default", "zero_centered"] = "zero_centered" + + @computed_field + def layers_type(self) -> list[Literal["full_attention", "linear_attention"]]: + return ["full_attention" if bool((i + 1) % 4) else "linear_attention" for i in range(self.num_hidden_layers)] + + def build(self) -> Qwen3VLTextMoE: + return Qwen3VLTextMoE(self) + + +class Qwen3_5_VLTextMoE35BA3BConfig(Qwen3_5_VLTextMoEConfig): + vocab_size: int = 248320 + max_position_embeddings: int = 262144 + # Qwen3 Model(dense and moe)'s pad_token_id is not set, so we need to set it to None. + # If this pad_token_id is not set, the embedding module will not act specially for pad token. + # Note: Qwen3 Model's pad_token_id may be different from Qwen tokenizer's pad_token_id. + pad_token_id: int | None = None + eos_token_id: int = 248044 + num_hidden_layers: int = 40 + max_window_layers: int = 40 + hidden_size: int = 2048 + intermediate_size: int = 0 # not used + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000000.0 + hidden_act: str = "silu" + attention: MHAConfig = MHAConfig( + with_gate=True, + num_attention_heads=16, + num_key_value_heads=2, + head_dim=256, + qk_norm=True, + rms_norm_eps=1e-6, + rms_norm_type="zero_centered", + sliding_window=1024 + ) + linear_attention: GateDeltaNetConfig = GateDeltaNetConfig( + num_value_heads=32, + num_key_heads=16, + key_head_dim=128, + value_head_dim=128, + conv_kernel_dim=4, + hidden_act='silu', + rms_norm_eps=1e-6, + ) + tie_word_embeddings: bool = False + n_routed_experts: int = 256 + n_shared_experts: int = 1 + num_experts_per_tok: int = 8 + first_k_dense_replace: int = 0 + hidden_factor: float = 1.0 + moe_intermediate_size: int = 512 + router: GreedyRouterConfig = GreedyRouterConfig( + scoring_func="softmax", + norm_topk_prob=True, + router_scaling_factor=1.0, + ) + 
rope_scaling_cfg: RopeScalingConfig | None = RopeScalingConfig(type="qwen3_vl", mrope_section=[11, 11, 10], partial_rotary_factor=0.25) + balancing_loss_cfg: BalancingLossConfig | None = BalancingLossConfig() + z_loss_cfg: ZLossConfig | None = None diff --git a/xtuner/v1/model/moe/qwen3vl_text.py b/xtuner/v1/model/moe/qwen3vl_text.py index 1996edcf7..9838312ea 100644 --- a/xtuner/v1/model/moe/qwen3vl_text.py +++ b/xtuner/v1/model/moe/qwen3vl_text.py @@ -17,7 +17,12 @@ def to_hf_key_list(self, key: str) -> list[str]: key = "model.language_model." + key if "layers" in key: - key = re.sub(r"layers\.(\d+)\.(experts|gate)", r"layers.\1.mlp.\2", key) + key = re.sub(r"layers\.(\d+)\.(experts|gate|shared_experts)", r"layers.\1.mlp.\2", key) + key = key.replace("shared_experts", "shared_expert") + + layer_idx = int(re.findall(r"layers\.(\d+)\.", key)[0]) + if self.config.layers_type[layer_idx] == "linear_attention": + key = key.replace("self_attn", "linear_attn") if "fused_w1w3.weight" in key: key = key.replace("fused_w1w3.weight", "gate_up_proj") diff --git a/xtuner/v1/module/attention/__init__.py b/xtuner/v1/module/attention/__init__.py index 10fcb26df..f444703d8 100644 --- a/xtuner/v1/module/attention/__init__.py +++ b/xtuner/v1/module/attention/__init__.py @@ -2,6 +2,7 @@ from .attn_outputs import AttnOutputs from .mha import MHAConfig, MultiHeadAttention from .mla import MLAConfig, MultiLatentAttention +from .gate_deltanet import GateDeltaNetConfig, GateDeltaNet __all__ = [ @@ -10,4 +11,6 @@ "MHAConfig", "MLAConfig", "AttnOutputs", + "GateDeltaNet", + "GateDeltaNetConfig", ] diff --git a/xtuner/v1/module/attention/gate_deltanet.py b/xtuner/v1/module/attention/gate_deltanet.py new file mode 100644 index 000000000..597cfe3df --- /dev/null +++ b/xtuner/v1/module/attention/gate_deltanet.py @@ -0,0 +1,214 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from typing import Annotated, Callable, Literal, cast + +import torch +from cyclopts import Parameter +from mmengine import is_installed +from pydantic import BaseModel, ConfigDict +from torch import nn +from torch.distributed.tensor import DTensor +from typing_extensions import overload +import torch.nn.functional as F + +from xtuner.v1.data_proto import SequenceContext +from xtuner.v1.float8.config import Float8Config +from xtuner.v1.ops.comm.all_to_all import ulysses_all_to_all +from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_device, get_logger + +from ..linear import build_linear +from .attn_outputs import AttnOutputs + +from fla.modules import FusedRMSNormGated +from fla.ops.gated_delta_rule import chunk_gated_delta_rule +from causal_conv1d import causal_conv1d_fn + +logger = get_logger() + + +class GateDeltaNetConfig(BaseModel): + model_config = ConfigDict(title="Base attention config for xtuner", extra="forbid") + num_value_heads: Annotated[int, Parameter(group="attention")] + num_key_heads: Annotated[int, Parameter(group="attention")] + key_head_dim: Annotated[int, Parameter(group="attention")] + value_head_dim: Annotated[int, Parameter(group="attention")] + conv_kernel_dim: Annotated[int, Parameter(group="attention")] + hidden_act: Annotated[str, Parameter(group="model")] # key defined in `transformers.activations.ACT2CLS` + rms_norm_eps: Annotated[float, Parameter(group="attention")] + + def build( + self, + hidden_size: int, + float8_cfg: Float8Config | None = None, + **kwargs, + ) -> "GateDeltaNet": + return GateDeltaNet( + **self.model_dump(), + hidden_size=hidden_size, + float8_cfg=float8_cfg, + ) + + +class GateDeltaNet(nn.Module): + def __init__(self, + hidden_size: int, + num_value_heads: int, + num_key_heads: int, + key_head_dim: int, + value_head_dim: int, + conv_kernel_dim: int, + hidden_act: str, + rms_norm_eps: float, + layer_idx: int = 0, + float8_cfg: Float8Config | None = None) -> None: + super().__init__() + self.name = 
f"layers.{layer_idx}.gate_deltanet" + self.float8_cfg = float8_cfg + + self.hidden_size = hidden_size + self.num_v_heads = num_value_heads + self.num_k_heads = num_key_heads + self.head_k_dim = key_head_dim + self.head_v_dim = value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = conv_kernel_dim + self.layer_idx = layer_idx + self.activation = hidden_act + self.rms_norm_eps = rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + self.causal_conv1d_fn = causal_conv1d_fn + self.chunk_gated_delta_rule = chunk_gated_delta_rule + + self.norm = FusedRMSNormGated( + self.head_v_dim, + eps=self.rms_norm_eps, + activation=self.activation + ) + + self.out_proj = build_linear( + self.value_dim, + self.hidden_size, + bias=False, + float8_cfg=self.float8_cfg, + ) + + self.in_proj_qkv = build_linear( + self.hidden_size, + self.key_dim * 2 + self.value_dim, + bias=False, + float8_cfg=self.float8_cfg, + ) + self.in_proj_z = build_linear( + self.hidden_size, + self.value_dim, + bias=False, + float8_cfg=self.float8_cfg, + ) + self.in_proj_b = build_linear( + self.hidden_size, + self.num_v_heads, + bias=False, + float8_cfg=self.float8_cfg, + ) + self.in_proj_a = build_linear( + self.hidden_size, + self.num_v_heads, + bias=False, + float8_cfg=self.float8_cfg, + ) + + def forward( + self, + hidden_states: torch.Tensor, + seq_ctx: SequenceContext, + position_embeddings: tuple[torch.Tensor, torch.Tensor] 
| None = None, # not used + ) -> AttnOutputs: + batch_size, seq_len, _ = hidden_states.shape + assert batch_size==1, "Only batch size of 1 is supported for now in GateDeltaNet" + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, # TODO: packed sequence support + ) + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + core_attn_out, _ = self.chunk_gated_delta_rule( # TODO: packed sequence support + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + ) + + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + return output + + @overload # type: ignore + def __call__( # type: ignore + self, + hidden_states: torch.Tensor, + seq_ctx: 
SequenceContext, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> AttnOutputs: ... + + __call__ = nn.Module.__call__ + diff --git a/xtuner/v1/module/attention/mha.py b/xtuner/v1/module/attention/mha.py index 9dcd1f942..0ee245b97 100644 --- a/xtuner/v1/module/attention/mha.py +++ b/xtuner/v1/module/attention/mha.py @@ -38,9 +38,11 @@ class MHAConfig(BaseModel): qkv_bias: Annotated[bool, Parameter(group="attention")] = False qk_norm: bool = False rms_norm_eps: float = 1e-06 + rms_norm_type: Literal['default', 'zero_centered'] = 'default' o_bias: Annotated[bool, Parameter(group="attention")] = False sliding_window: Annotated[int | None, Parameter(group="attention")] = -1 with_sink: Annotated[bool, Parameter(group="attention")] = False + with_gate: Annotated[bool, Parameter(group="attention")] = False attn_impl: Literal["flash_attention", "flex_attention", "eager_attention"] = "flash_attention" def model_post_init(self, _): @@ -122,8 +124,10 @@ def __init__( qkv_bias: bool = False, qk_norm: bool = False, rms_norm_eps: float = 1e-6, + rms_norm_type: Literal['default', 'zero_centered'] = 'default', o_bias: bool = False, with_sink: bool = False, + with_gate: bool = False, attn_impl: Literal["flash_attention", "flex_attention", "eager_attention"] = "flash_attention", rope_scaling_cfg: RopeScalingConfig | None = None, float8_cfg: Float8Config | None = None, @@ -145,14 +149,16 @@ def __init__( self.qkv_bias = qkv_bias self.qk_norm = qk_norm self.rms_norm_eps = rms_norm_eps + self.rms_norm_type = rms_norm_type self.o_bias = o_bias self.generate_config = generate_config self.float8_cfg = float8_cfg self.layer_idx = layer_idx + self.with_gate = with_gate self.q_proj = build_linear( self.hidden_size, - self.num_attention_heads * self.head_dim, + self.num_attention_heads * self.head_dim if not with_gate else self.num_attention_heads * self.head_dim * 2, bias=self.qkv_bias, float8_cfg=self.float8_cfg, ) @@ -176,8 +182,8 @@ def __init__( ) if self.qk_norm: - 
self.q_norm = RMSNorm(self.head_dim, eps=self.rms_norm_eps) - self.k_norm = RMSNorm(self.head_dim, eps=self.rms_norm_eps) + self.q_norm = RMSNorm(self.head_dim, eps=self.rms_norm_eps, type=self.rms_norm_type) + self.k_norm = RMSNorm(self.head_dim, eps=self.rms_norm_eps, type=self.rms_norm_type) self.with_sink = with_sink if self.with_sink: @@ -188,7 +194,8 @@ def __init__( self.window_size = (sliding_window, sliding_window) fope_sep_head = rope_scaling_cfg.fope_sep_head if rope_scaling_cfg is not None else None - self.apply_rotary_emb = get_apply_rotary_emb(fope_sep_head) # type: ignore + enable_partial_rotary = rope_scaling_cfg.partial_rotary_factor != 1.0 if rope_scaling_cfg is not None else False + self.apply_rotary_emb = get_apply_rotary_emb(fope_sep_head, enable_partial_rotary=enable_partial_rotary) # type: ignore self.attn_impl_func: Callable[..., AttnOpOutputs] = attn_impl_mapping[attn_impl] # type: ignore[assignment] @@ -327,8 +334,15 @@ def forward( """ input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape) # [b, seq, n_head, head_dim] + + if self.with_gate: + query_states, gate = torch.chunk( + self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 + ) + gate = gate.reshape(*input_shape, -1) + else: + gate = None + query_states = self.q_proj(hidden_states).view(hidden_shape) # [b, seq, n_head, head_dim] key_states = self.k_proj(hidden_states).view(hidden_shape) value_states = self.v_proj(hidden_states).view(hidden_shape) @@ -409,6 +423,9 @@ def forward( ) raw_output = raw_output.reshape(*input_shape, -1).contiguous() + if self.with_gate: + raw_output = raw_output * torch.sigmoid(gate) + projected_output = self.o_proj(raw_output) attn_outputs: AttnOutputs = { "projected_output": projected_output, diff --git a/xtuner/v1/module/decoder_layer/dense_decoder_layer.py b/xtuner/v1/module/decoder_layer/dense_decoder_layer.py index 
53a3f90b3..c63f0eadd 100644 --- a/xtuner/v1/module/decoder_layer/dense_decoder_layer.py +++ b/xtuner/v1/module/decoder_layer/dense_decoder_layer.py @@ -44,6 +44,7 @@ def __init__( mlp_bias: bool = False, hidden_act: str, rms_norm_eps: float = 1e-6, + rms_norm_type: Literal['default', 'zero_centered'] = 'default', attention_config: MLAConfig | MHAConfig, rope_scaling_cfg: RopeScalingConfig | None = None, generate_config: GenerateConfig | None = None, @@ -68,8 +69,8 @@ def __init__( hidden_act=hidden_act, float8_cfg=float8_cfg, ) - self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps,type=rms_norm_type) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) def forward( self, diff --git a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py index 9303d66d4..a43d0e0f1 100644 --- a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py +++ b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py @@ -190,9 +190,11 @@ def __init__( moe_bias: bool = False, hidden_act: str, rms_norm_eps: float = 1e-6, + rms_norm_type: Literal['default', 'zero_centered'] = 'default', num_experts_per_tok: int, n_routed_experts: int, n_shared_experts: int, + with_shared_expert_gate: bool = False, hidden_factor: float = 1.0, attention_config: MHAConfig | MLAConfig, rope_scaling_cfg: RopeScalingConfig | None = None, @@ -220,10 +222,11 @@ def __init__( layer_type=layer_type, float8_cfg=float8_cfg, ) - self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) self.shared_experts: MoEMLP | None self.layer_idx = layer_idx - + + self.with_shared_expert_gate = with_shared_expert_gate if n_shared_experts > 0: self.shared_experts = MoEMLP( hidden_size=hidden_size, @@ -233,10 +236,13 @@ def 
__init__( mlp_bias=mlp_bias, float8_cfg=float8_cfg, ) + if with_shared_expert_gate: + self.shared_expert_gate = build_linear(hidden_size, 1, bias=False, float8_cfg=float8_cfg) else: self.shared_experts = None + self.shared_expert_gate = None - self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) self.gate = MoEGate( hidden_size=hidden_size, @@ -585,6 +591,10 @@ def _shared_experts_forward( ) -> torch.Tensor: assert self.shared_experts is not None, "Shared experts should be initialized when n_shared_experts > 0" shared_experts_out = self.shared_experts(hidden_states) + + if self.with_shared_expert_gate: + shared_experts_out = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_experts_out + return shared_experts_out def _post_moe_forward( diff --git a/xtuner/v1/module/rms_norm/rms_norm.py b/xtuner/v1/module/rms_norm/rms_norm.py index 08594382d..407688a92 100644 --- a/xtuner/v1/module/rms_norm/rms_norm.py +++ b/xtuner/v1/module/rms_norm/rms_norm.py @@ -1,19 +1,28 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import torch from torch import nn +from typing import Literal from torch.distributed.tensor import DTensor -from xtuner.v1.ops import rms_norm +from xtuner.v1.ops import rms_norm, zero_centered_rms_norm class RMSNorm(nn.Module): weight: torch.Tensor - def __init__(self, hidden_size: int, eps: float = 1e-6): + def __init__(self, hidden_size: int, eps: float = 1e-6, type: Literal['default', 'zero_centered'] = 'default'): """RMSNorm is equivalent to T5LayerNorm.""" super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + self._type = type + + if type == 'default': + self.rms_norm_fn = rms_norm + elif type == 'zero_centered': + self.rms_norm_fn = zero_centered_rms_norm + else: + raise ValueError(f'Unsupported RMSNorm type: {type}') def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if isinstance(self.weight, DTensor): @@ -28,10 +37,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # return (weight * hidden_states).to(input_dtype) # gpt_oss # return weight * hidden_states.to(input_dtype) # Llama - return rms_norm(hidden_states, weight, epsilon=self.variance_epsilon) # type: ignore[operator] + return self.rms_norm_fn(hidden_states, weight, epsilon=self.variance_epsilon) # type: ignore[operator] def init_weights(self): self.weight.data.fill_(1.0) def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + return f"{tuple(self.weight.shape)}, type={self._type}, eps={self.variance_epsilon}" diff --git a/xtuner/v1/module/rope/rope.py b/xtuner/v1/module/rope/rope.py index a8594d437..9997b7faf 100644 --- a/xtuner/v1/module/rope/rope.py +++ b/xtuner/v1/module/rope/rope.py @@ -1,4 +1,5 @@ -from typing import Literal, Protocol, cast +from typing import Literal, Protocol, cast, Optional, Callable + import torch import torch.nn as nn @@ -25,6 +26,7 @@ class RopeScalingConfig(BaseModel): # For 
Qwen3VL mrope_section: list[int] | None = None # e.g. [24, 20, 20] + partial_rotary_factor: float = 1.0 factor: float | None = None beta_fast: float | None = None @@ -60,6 +62,27 @@ def __call__(self, x: torch.Tensor, position_ids: torch.LongTensor) -> tuple[tor def to(self, device: torch.device) -> Self: ... +def compute_default_rope_parameters( + config, + device: Optional["torch.device"] = None, + ) -> tuple["torch.Tensor", float]: + base = config.rope_theta + if config.rope_scaling_cfg is not None: + rope_scaling_cfg: RopeScalingConfig = config.rope_scaling_cfg + partial_rotary_factor = getattr(rope_scaling_cfg, "partial_rotary_factor", 1.0) + else: + partial_rotary_factor = 1.0 + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + class RotaryEmbedding(nn.Module): inv_freq: torch.Tensor @@ -81,7 +104,11 @@ def __init__(self, config, device=None): f"Unsupported rope_type: {self.rope_type}. Supported types are: 'default', 'linear', 'yarn', 'llama3'." ) - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + # The implementation of RoPE has been refactored in Transformers V5, and + # the following approach is used for compatibility. 
+ self.rope_init_fn: Callable = compute_default_rope_parameters + if self.rope_type != "default": + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq: torch.Tensor inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) @@ -283,7 +310,12 @@ def __init__(self, config, device=None): self.original_max_seq_len = config.max_position_embeddings self.rope_type = "default" self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + # The implementation of RoPE has been refactored in Transformers V5, and + # the following approach is used for compatibility. + self.rope_init_fn: Callable = compute_default_rope_parameters + if self.rope_type != "default": + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq: torch.Tensor inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) @@ -293,7 +325,7 @@ def __init__(self, config, device=None): self.mrope_section = config.rope_scaling_cfg.mrope_section assert self.mrope_section is not None - + def apply_interleaved_mrope(self, freqs, mrope_section): """Apply interleaved MRoPE to 3D rotary embeddings. 
diff --git a/xtuner/v1/ops/__init__.py b/xtuner/v1/ops/__init__.py index d7f043ae8..4a8d6907e 100644 --- a/xtuner/v1/ops/__init__.py +++ b/xtuner/v1/ops/__init__.py @@ -7,7 +7,7 @@ from .attn_imp import AttnOpOutputs, attn_impl_mapping from .flash_attn import flash_attn_varlen_func from .moe import group_gemm, permute, unpermute -from .rms_norm import rms_norm +from .rms_norm import rms_norm, zero_centered_rms_norm from .rotary_emb import get_apply_rotary_emb from .tensor_parallel import attn_column_parallel, attn_row_parallel diff --git a/xtuner/v1/ops/rms_norm/__init__.py b/xtuner/v1/ops/rms_norm/__init__.py index 5c8431b96..95cfa41ca 100644 --- a/xtuner/v1/ops/rms_norm/__init__.py +++ b/xtuner/v1/ops/rms_norm/__init__.py @@ -1,5 +1,4 @@ import os -from functools import partial import torch @@ -23,6 +22,16 @@ def _triton_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> t return rms_norm_fn(x, weight, bias=None, eps=epsilon) +def native_zero_centered_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: + # TODO: is native_rms_norm ? 
+ def _norm(x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + epsilon) + output = _norm(x.float()) + # Llama does x.to(float16) * w whilst Qwen3_5Moe is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + weight.float()) + return output.type_as(x) + def get_rms_norm_fn() -> RMSNormProtocol: from xtuner.v1.utils import get_device @@ -39,5 +48,21 @@ def get_rms_norm_fn() -> RMSNormProtocol: else: raise NotImplementedError(f"RMSNorm is not implemented on {device}") +def get_zero_centered_rms_norm_fn() -> RMSNormProtocol: + from xtuner.v1.utils import get_device + + device = get_device() + if device in ["cpu", "cuda"]: + # TODO: control triton rmsnorm by model config rather than env var + if os.getenv("XTUNER_USE_NATIVE_RMSNORM", "1") == "0" and device == "cuda": + raise NotImplementedError("Zero-centered RMSNorm is not implemented in triton") + else: + return native_zero_centered_rms_norm + elif device == "npu": + raise NotImplementedError("Zero-centered RMSNorm is not implemented on NPU") + else: + raise NotImplementedError(f"RMSNorm is not implemented on {device}") + rms_norm = get_rms_norm_fn() +zero_centered_rms_norm = get_zero_centered_rms_norm_fn() \ No newline at end of file diff --git a/xtuner/v1/ops/rms_norm/protocol.py b/xtuner/v1/ops/rms_norm/protocol.py index 1ce7b5ea3..426f7d683 100644 --- a/xtuner/v1/ops/rms_norm/protocol.py +++ b/xtuner/v1/ops/rms_norm/protocol.py @@ -1,7 +1,8 @@ from typing import Protocol +from typing_extensions import Literal import torch class RMSNormProtocol: - def __call__(self, x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: ... + def __call__(self, x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: ... 
diff --git a/xtuner/v1/ops/rotary_emb.py b/xtuner/v1/ops/rotary_emb.py index e4508ca06..917e9dc2a 100644 --- a/xtuner/v1/ops/rotary_emb.py +++ b/xtuner/v1/ops/rotary_emb.py @@ -49,6 +49,27 @@ def apply_rotary_pos_emb_cuda( return q_embed, k_embed +# Note: Although this function is compatible with apply_rotary_pos_emb_cuda, +# it is still recommended to separate them into two for efficiency considerations. +def apply_rotary_pos_emb_cuda_for_partial_rotary(q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + def apply_rotary_pos_emb_sep_cuda( q: torch.Tensor, k: torch.Tensor, @@ -120,16 +141,22 @@ def __call__( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -def get_apply_rotary_emb(fope_sep_head: bool | None = False) -> ApplyRotaryEmbProtocol: +def get_apply_rotary_emb(fope_sep_head: bool | None = False, + enable_partial_rotary: bool= False) -> ApplyRotaryEmbProtocol: from xtuner.v1.utils.device import get_device device = get_device() if device == "npu": assert fope_sep_head is None or not fope_sep_head, "FoPE with sep head is not supported on NPU yet." + assert not enable_partial_rotary, "Partial rotary is not supported on NPU yet." return apply_rotary_pos_emb_npu else: if fope_sep_head: logger.debug("Using FoPE with fope_sep_head") + assert not enable_partial_rotary, "Partial rotary is not supported when using FoPE with sep head." 
return apply_rotary_pos_emb_sep_cuda else: - return apply_rotary_pos_emb_cuda + if enable_partial_rotary: + return apply_rotary_pos_emb_cuda_for_partial_rotary + else: + return apply_rotary_pos_emb_cuda From 5c328d74fd2ffd14924109c32dbb27203e598cf1 Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Fri, 27 Feb 2026 04:44:34 +0000 Subject: [PATCH 02/10] fix load_weight and forward --- xtuner/v1/model/compose/qwen3_5/__init__.py | 6 ++ xtuner/v1/model/moe/qwen3_5_text.py | 100 +++++++++++++++++++- xtuner/v1/model/moe/qwen3vl_text.py | 7 +- xtuner/v1/module/attention/gate_deltanet.py | 5 +- 4 files changed, 107 insertions(+), 11 deletions(-) create mode 100644 xtuner/v1/model/compose/qwen3_5/__init__.py diff --git a/xtuner/v1/model/compose/qwen3_5/__init__.py b/xtuner/v1/model/compose/qwen3_5/__init__.py new file mode 100644 index 000000000..e9b332168 --- /dev/null +++ b/xtuner/v1/model/compose/qwen3_5/__init__.py @@ -0,0 +1,6 @@ +from .qwen3_5_config import Qwen3_5_VLMoE35BA3Config + + +__all__ = [ + "Qwen3_5_VLMoE35BA3Config", +] diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index 35c501040..681addcbf 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -2,6 +2,7 @@ from pydantic import computed_field from typing import Literal import re +import torch from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig, ZLossConfig from xtuner.v1.module.attention import MHAConfig, GateDeltaNetConfig from xtuner.v1.module.router.greedy import GreedyRouterConfig @@ -10,6 +11,97 @@ from xtuner.v1.model.moe.moe import MoEConfig from .qwen3vl_text import Qwen3VLTextMoE +class Qwen3_5_VLTextMoE(Qwen3VLTextMoE): + def to_hf_key_list(self, key: str) -> list[str]: + if "layers" in key or "embed_tokens" in key: + key = "model.language_model." 
+ key + + if "layers" in key: + key = re.sub(r"layers\.(\d+)\.(experts|gate|shared_experts|shared_expert_gate)", r"layers.\1.mlp.\2", key) + key = key.replace("shared_experts", "shared_expert") + + layer_idx = int(re.findall(r"layers\.(\d+)\.", key)[0]) + if self.config.layers_type[layer_idx] == "linear_attention": + key = key.replace("self_attn", "linear_attn") + + if "fused_w1w3.weight" in key: + key = key.replace("fused_w1w3.weight", "gate_up_proj") + elif "fused_w2.weight" in key: + key = key.replace("fused_w2.weight", "down_proj") + if "fused_w1w3.bias" in key: + key = key.replace("fused_w1w3.bias", "gate_up_proj_bias") + elif "fused_w2.bias" in key: + key = key.replace("fused_w2.bias", "down_proj_bias") + + if key.startswith("norm."): + return [key.replace("norm.", "model.language_model.norm.")] + elif key.startswith("rotary_emb."): + # FoPE has model.rotary_emb.sin_coef and model.rotary_emb.cos_coef in the safetensors + return [key.replace("rotary_emb.", "model.language_model.rotary_emb.")] + else: + return [key] + + def safetensors_to_params( + self, + safetensors: list[torch.Tensor], + local_tensor: torch.Tensor, + param_name: str, + start: int | None, + end: int | None, + dim: int | None, + ): + if len(safetensors) > 1: + assert dim is not None, "Internal Error dim must not be None when len(safetensors) > 1" + loaded_tensor = torch.cat(safetensors, dim=dim) + else: + loaded_tensor = safetensors[0] + + if "fused_w1w3.weight" in param_name: + # hf: num_experts, 2 * expert_dim, hidden_size + # xtuner: num_experts * 2 * expert_dim, hidden_size + # num_experts * 2 * expert_dim, hidden_size + loaded_tensor = loaded_tensor.flatten(0, 1) + + elif "fused_w2.weight" in param_name: + # hf: num_experts, hidden_size, expert_dim + # xtuner: num_experts * hidden_size, expert_dim + loaded_tensor = loaded_tensor.flatten(0, 1) + + if start is not None and end is not None: + start = min(start, loaded_tensor.shape[self.FSDP_SHARD_DIM]) + end = min(end, 
loaded_tensor.shape[self.FSDP_SHARD_DIM]) + loaded_tensor_slice = loaded_tensor.index_select( + dim=self.FSDP_SHARD_DIM, index=torch.arange(start, end, dtype=torch.int64, device=loaded_tensor.device) + ) + non_pad_len = end - start + local_tensor[:non_pad_len].copy_(loaded_tensor_slice) + + if non_pad_len < local_tensor.shape[self.FSDP_SHARD_DIM]: + assert self.config.float8_cfg is not None + local_tensor[non_pad_len:].copy_(0.0) # type: ignore # padded part must be set to 0 + else: + local_tensor.copy_(loaded_tensor) + + def param_to_safetensor( + self, + safetensor: torch.Tensor, + hf_param_name: str, + ): + assert isinstance(hf_param_name, str) + if "gate_up_proj" in hf_param_name: + # xtuner: num_experts * 2 * expert_dim, hidden_size + # hf: num_experts, 2 * expert_dim, hidden_size + num_experts = self.config.n_routed_experts + hidden_size = safetensor.size(1) + safetensor = safetensor.reshape(num_experts, -1, hidden_size).contiguous() # num_experts, 2 * expert_dim, hidden_size + elif "down_proj" in hf_param_name: + # xtuner: num_experts * hidden_size, expert_dim + # hf: num_experts, hidden_size, expert_dim + num_experts = self.config.n_routed_experts + expert_dim = safetensor.size(1) + safetensor = safetensor.reshape(num_experts, -1, expert_dim).contiguous() + return safetensor + class Qwen3_5_VLTextMoEConfig(MoEConfig): with_shared_expert_gate: bool = True @@ -17,10 +109,10 @@ class Qwen3_5_VLTextMoEConfig(MoEConfig): @computed_field def layers_type(self) -> list[Literal["full_attention", "linear_attention"]]: - return ["full_attention" if bool((i + 1) % 4) else "linear_attention" for i in range(self.num_hidden_layers)] + return ["linear_attention" if bool((i + 1) % 4) else "full_attention" for i in range(self.num_hidden_layers)] - def build(self) -> Qwen3VLTextMoE: - return Qwen3VLTextMoE(self) + def build(self) -> Qwen3_5_VLTextMoE: + return Qwen3_5_VLTextMoE(self) class Qwen3_5_VLTextMoE35BA3BConfig(Qwen3_5_VLTextMoEConfig): @@ -69,6 +161,6 @@ class 
Qwen3_5_VLTextMoE35BA3BConfig(Qwen3_5_VLTextMoEConfig): norm_topk_prob=True, router_scaling_factor=1.0, ) - rope_scaling_cfg = RopeScalingConfig(type="qwen3_vl", mrope_section=[11, 11, 10], partial_rotary_factor=0.25) + rope_scaling_cfg: RopeScalingConfig = RopeScalingConfig(type="qwen3_vl", mrope_section=[11, 11, 10], partial_rotary_factor=0.25) balancing_loss_cfg: BalancingLossConfig | None = BalancingLossConfig() z_loss_cfg: ZLossConfig | None = None diff --git a/xtuner/v1/model/moe/qwen3vl_text.py b/xtuner/v1/model/moe/qwen3vl_text.py index 9838312ea..1996edcf7 100644 --- a/xtuner/v1/model/moe/qwen3vl_text.py +++ b/xtuner/v1/model/moe/qwen3vl_text.py @@ -17,12 +17,7 @@ def to_hf_key_list(self, key: str) -> list[str]: key = "model.language_model." + key if "layers" in key: - key = re.sub(r"layers\.(\d+)\.(experts|gate|shared_experts)", r"layers.\1.mlp.\2", key) - key = key.replace("shared_experts", "shared_expert") - - layer_idx = int(re.findall(r"layers\.(\d+)\.", key)[0]) - if self.config.layers_type[layer_idx] == "linear_attention": - key = key.replace("self_attn", "linear_attn") + key = re.sub(r"layers\.(\d+)\.(experts|gate)", r"layers.\1.mlp.\2", key) if "fused_w1w3.weight" in key: key = key.replace("fused_w1w3.weight", "gate_up_proj") diff --git a/xtuner/v1/module/attention/gate_deltanet.py b/xtuner/v1/module/attention/gate_deltanet.py index 597cfe3df..b0f9abd21 100644 --- a/xtuner/v1/module/attention/gate_deltanet.py +++ b/xtuner/v1/module/attention/gate_deltanet.py @@ -200,7 +200,10 @@ def forward( core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) output = self.out_proj(core_attn_out) - return output + attn_outputs: AttnOutputs = { + "projected_output": output, + } + return attn_outputs @overload # type: ignore def __call__( # type: ignore From ce0cad8677fea1e13c595d12483536c67aed2283 Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Fri, 27 Feb 2026 12:13:24 +0000 Subject: [PATCH 03/10] fix loss mismatch with the HF implementation --- 
.../module/decoder_layer/moe_decoder_layer.py | 45 +++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py index a43d0e0f1..5a44cb4b5 100644 --- a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py +++ b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py @@ -127,9 +127,13 @@ def forward( bias = self.bias.to_local() if isinstance(self.bias, DTensor) else self.bias bias = bias.float() - logits = F.linear(hidden_states.float(), weight.float(), bias) - - return self.router(logits, rollout_routed_experts) + # logits = F.linear(hidden_states.float(), weight.float(), bias) + # return self.router(logits, rollout_routed_experts) + + logits = F.linear(hidden_states, weight, bias) + gate = self.router(logits, rollout_routed_experts) + gate['topk_weights'] = gate['topk_weights'].float() + return gate class MoEBlock(nn.Module): @@ -387,6 +391,41 @@ def _forward( ) combined_hidden_states = post_combined["hidden_states"] combined_hidden_states = combined_hidden_states.view(*origin_shape) + + # # 对齐 hf 实现用 + # # xtuner: num_experts * 2 * expert_dim, hidden_size + # # hf: num_experts, 2 * expert_dim, hidden_size + # origin_gate_up_proj = self.experts.fused_w1w3.weight + # gate_up_proj = origin_gate_up_proj.view(self.n_routed_experts, 2 * self.experts.intermediate_size, self.hidden_size) + + # # xtuner: num_experts * hidden_size, expert_dim + # # hf: num_experts, hidden_size, expert_dim + # origin_down_proj = self.experts.fused_w2.weight + # down_proj = origin_down_proj.view(self.n_routed_experts, self.hidden_size, self.experts.intermediate_size) + + # from transformers.activations import ACT2FN + # act_fn = ACT2FN['silu'] + + # hidden_states_reshaped = hidden_states.view(-1, hidden_states.size(-1)) + # combined_hidden_states = torch.zeros_like(hidden_states_reshaped) + # with torch.no_grad(): + # expert_mask = 
torch.nn.functional.one_hot(router_results["topk_ids"], num_classes=self.n_routed_experts) + # expert_mask = expert_mask.permute(2, 1, 0) + # expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + # for expert_idx in expert_hit: + # expert_idx = expert_idx[0] + # if expert_idx == self.n_routed_experts: + # continue + # top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + # current_state = hidden_states_reshaped[token_idx] + # gate, up = nn.functional.linear(current_state, gate_up_proj[expert_idx]).chunk(2, dim=-1) + # current_hidden_states = act_fn(gate) * up + # current_hidden_states = nn.functional.linear(current_hidden_states, down_proj[expert_idx]) + # current_hidden_states = current_hidden_states * router_results["topk_weights"][token_idx, top_k_pos, None] + # combined_hidden_states.index_add_(0, token_idx, current_hidden_states.to(combined_hidden_states.dtype)) + + # combined_hidden_states = combined_hidden_states.view(*origin_shape) + # ProberList.after_combine(self.layer_idx, combined_hidden_states) if self.n_shared_experts > 0: From bd8bc8e0f983fdd7336ed515d7482fe055aaf63f Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Sat, 28 Feb 2026 04:38:45 +0000 Subject: [PATCH 04/10] support pack --- xtuner/v1/module/attention/gate_deltanet.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xtuner/v1/module/attention/gate_deltanet.py b/xtuner/v1/module/attention/gate_deltanet.py index b0f9abd21..aa9a7e2cd 100644 --- a/xtuner/v1/module/attention/gate_deltanet.py +++ b/xtuner/v1/module/attention/gate_deltanet.py @@ -154,12 +154,15 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) + num_tokens = seq_ctx.seq_lens_q + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, + device=hidden_states.device) for i, s in enumerate(num_tokens)], dim=0)[None] mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, weight=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias, 
activation=self.activation, - seq_idx=None, # TODO: packed sequence support + seq_idx=seq_idx, ) mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -191,6 +194,7 @@ def forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, + cu_seqlens=seq_ctx.cu_seq_lens_q, ) # reshape input data into 2D tensor From 1f0060e0e06c697971cf6f48056fe737faeba89e Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Sat, 28 Feb 2026 08:51:25 +0000 Subject: [PATCH 05/10] fix saving weights in HF format --- xtuner/v1/data_proto/messages/chat.py | 2 +- xtuner/v1/data_proto/sequence_context.py | 11 +++++++++-- xtuner/v1/datasets/collator.py | 2 ++ xtuner/v1/model/__init__.py | 2 ++ xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py | 5 +++-- xtuner/v1/model/moe/qwen3_5_text.py | 2 +- xtuner/v1/module/attention/gate_deltanet.py | 7 ++----- 7 files changed, 20 insertions(+), 11 deletions(-) diff --git a/xtuner/v1/data_proto/messages/chat.py b/xtuner/v1/data_proto/messages/chat.py index 69a098d25..ebf3019b8 100644 --- a/xtuner/v1/data_proto/messages/chat.py +++ b/xtuner/v1/data_proto/messages/chat.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from transformers import PreTrainedTokenizer -from xtuner.utils import IGNORE_INDEX +from xtuner.v1.utils import IGNORE_INDEX from xtuner.v1.data_proto.messages.base import BaseMessages from xtuner.v1.data_proto.templates import ChatTemplate, HybridChatTemplate from xtuner.v1.utils import get_logger diff --git a/xtuner/v1/data_proto/sequence_context.py b/xtuner/v1/data_proto/sequence_context.py index fc91e4cb4..06215d0ab 100644 --- a/xtuner/v1/data_proto/sequence_context.py +++ b/xtuner/v1/data_proto/sequence_context.py @@ -101,7 +101,12 @@ def __init__( seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] seq_lens_q = self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1] - + + # Used for causal_conv1d varlen. 
It cannot be calculated in the compile function of linear attention, + # and needs to be calculated in advance. + self.seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, + device=device) for i, s in enumerate(seq_lens_q)], dim=0)[None] + if position_ids is None: _position_ids = [torch.arange(k - q, k) for q, k in zip(seq_lens_q, seq_lens_k)] position_ids = torch.cat(_position_ids).unsqueeze(0).to(self.cu_seq_lens_k.device) # type: ignore[assignment] @@ -377,10 +382,12 @@ def to(self, device: torch.device | str): if device == "npu" or isinstance(device, torch.device) and device.type == "npu": self.cu_seq_lens_q = self.cu_seq_lens_q.cpu() # type: ignore self.cu_seq_lens_k = self.cu_seq_lens_k.cpu() # type: ignore + self.seq_idx = self.seq_idx.cpu() # type: ignore else: self.cu_seq_lens_q = self.cu_seq_lens_q.to(device) # type: ignore self.cu_seq_lens_k = self.cu_seq_lens_k.to(device) # type: ignore - + self.seq_idx = self.seq_idx.to(device) # type: ignore + if self.position_ids is not None and hasattr(self.position_ids, "to"): self.position_ids = self.position_ids.to(device) # type: ignore diff --git a/xtuner/v1/datasets/collator.py b/xtuner/v1/datasets/collator.py index b19d67f34..62b2916cb 100644 --- a/xtuner/v1/datasets/collator.py +++ b/xtuner/v1/datasets/collator.py @@ -247,6 +247,8 @@ def qwen3_vl_sft_collator( if len(position_ids_list) > 0: position_ids = torch.cat(position_ids_list, dim=-1) position_ids = position_ids[:, :, :-1] + if pack_to_max_length and pack_max_length - position_ids.shape[-1] > 0: + position_ids = pad_to_max_length(position_ids, 0, max_length=pack_max_length, dim=-1) num_img_tokens: list[int] = [] for data in instance: diff --git a/xtuner/v1/model/__init__.py b/xtuner/v1/model/__init__.py index 3f77cd9b8..a6888dbcc 100644 --- a/xtuner/v1/model/__init__.py +++ b/xtuner/v1/model/__init__.py @@ -17,6 +17,7 @@ Qwen3VLMoE30BA3Config, Qwen3VLMoE235BA22Config, ) +from .compose.qwen3_5 import Qwen3_5_VLMoE35BA3Config from .dense.dense 
import Dense from .dense.qwen2 import Qwen2Dense7BConfig, Qwen2DenseConfig from .dense.qwen3 import Qwen3Dense0P6BConfig, Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig @@ -98,4 +99,5 @@ def get_model_config_from_hf(model_path: Path): "TorchCompileOption", "DEFAULT_FLOAT8_CFG", "XTunerBaseModelConfig", + "Qwen3_5_VLMoE35BA3Config", ] diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py index 865cbcfb1..3b6b0d8f7 100644 --- a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py @@ -182,8 +182,9 @@ def forward( if deepstack_visual_embeds is not None and len(deepstack_visual_embeds) == 0: assert seq_ctx.position_ids is not None - assert seq_ctx.position_ids.ndim == 2, f"position_ids must be 2-dim when deepstack_visual_embeds is None," \ - f" but got {seq_ctx.position_ids.ndim}" + # qwen3.5 does not satisfy this condition + # assert seq_ctx.position_ids.ndim == 2, f"position_ids must be 2-dim when deepstack_visual_embeds is None," \ + # f" but got {seq_ctx.position_ids.ndim}" deepstack_visual_embeds = None visual_pos_masks = None diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index 681addcbf..a5505bf26 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -94,7 +94,7 @@ def param_to_safetensor( num_experts = self.config.n_routed_experts hidden_size = safetensor.size(1) safetensor = safetensor.reshape(num_experts, -1, hidden_size).contiguous() # num_experts, 2 * expert_dim, hidden_size - elif "down_proj" in hf_param_name: + elif "down_proj" in hf_param_name and "shared_expert" not in hf_param_name: # xtuner: num_experts * hidden_size, expert_dim # hf: num_experts, hidden_size, expert_dim num_experts = self.config.n_routed_experts diff --git a/xtuner/v1/module/attention/gate_deltanet.py b/xtuner/v1/module/attention/gate_deltanet.py index 
aa9a7e2cd..7d2d3372f 100644 --- a/xtuner/v1/module/attention/gate_deltanet.py +++ b/xtuner/v1/module/attention/gate_deltanet.py @@ -154,15 +154,12 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - num_tokens = seq_ctx.seq_lens_q - seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, - device=hidden_states.device) for i, s in enumerate(num_tokens)], dim=0)[None] mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, weight=self.conv1d.weight.squeeze(1), bias=self.conv1d.bias, activation=self.activation, - seq_idx=seq_idx, + seq_idx=seq_ctx.seq_idx, ) mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -185,7 +182,7 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - core_attn_out, _ = self.chunk_gated_delta_rule( # TODO: packed sequence support + core_attn_out, _ = self.chunk_gated_delta_rule( query, key, value, From 6ab0141cfd6d3bd77933228dcf8517f3fd221d8b Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Sat, 28 Feb 2026 10:04:06 +0000 Subject: [PATCH 06/10] rename --- xtuner/v1/model/base.py | 4 ++-- xtuner/v1/model/moe/qwen3_5_text.py | 4 ++-- xtuner/v1/module/attention/__init__.py | 6 +++--- .../attention/{gate_deltanet.py => gated_deltanet.py} | 11 ++++++----- 4 files changed, 13 insertions(+), 12 deletions(-) rename xtuner/v1/module/attention/{gate_deltanet.py => gated_deltanet.py} (97%) diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index be58b9519..4ccbda643 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -39,7 +39,7 @@ WeightWithDynamicTilewiseFloat8CastTensor, ) from xtuner.v1.loss import BaseLossContext -from xtuner.v1.module.attention import MHAConfig, MLAConfig, GateDeltaNetConfig +from xtuner.v1.module.attention import MHAConfig, MLAConfig, GatedDeltaNetConfig from xtuner.v1.module.rope import RopeScalingConfig from 
xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory @@ -151,7 +151,7 @@ class TransformerConfig(XTunerBaseModelConfig): rope_theta: Annotated[float, Parameter(group="model")] # required by transformers's build rope hidden_act: Annotated[str, Parameter(group="model")] # key defined in `transformers.activations.ACT2CLS` attention: MLAConfig | MHAConfig - linear_attention: Annotated[GateDeltaNetConfig | None, Parameter(group="model")] = None + linear_attention: Annotated[GatedDeltaNetConfig | None, Parameter(group="model")] = None mlp_bias: Annotated[bool, Parameter(group="model")] = False tie_word_embeddings: Annotated[bool, Parameter(group="model")] = False model_type: Annotated[str | None, Parameter(group="model")] = None # TODO: yehaochen maybe should be removed diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index a5505bf26..692cdaf0a 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -4,7 +4,7 @@ import re import torch from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig, ZLossConfig -from xtuner.v1.module.attention import MHAConfig, GateDeltaNetConfig +from xtuner.v1.module.attention import MHAConfig, GatedDeltaNetConfig from xtuner.v1.module.router.greedy import GreedyRouterConfig from xtuner.v1.module.rope import RopeScalingConfig @@ -140,7 +140,7 @@ class Qwen3_5_VLTextMoE35BA3BConfig(Qwen3_5_VLTextMoEConfig): rms_norm_type="zero_centered", sliding_window=1024 ) - linear_attention: GateDeltaNetConfig = GateDeltaNetConfig( + linear_attention: GatedDeltaNetConfig = GatedDeltaNetConfig( num_value_heads=32, num_key_heads=16, key_head_dim=128, diff --git a/xtuner/v1/module/attention/__init__.py b/xtuner/v1/module/attention/__init__.py index f444703d8..593cc4b22 100644 --- a/xtuner/v1/module/attention/__init__.py +++ b/xtuner/v1/module/attention/__init__.py @@ 
-2,7 +2,7 @@ from .attn_outputs import AttnOutputs from .mha import MHAConfig, MultiHeadAttention from .mla import MLAConfig, MultiLatentAttention -from .gate_deltanet import GateDeltaNetConfig, GateDeltaNet +from .gated_deltanet import GatedDeltaNetConfig, GatedDeltaNet __all__ = [ @@ -11,6 +11,6 @@ "MHAConfig", "MLAConfig", "AttnOutputs", - "GateDeltaNet", - "GateDeltaNetConfig", + "GatedDeltaNet", + "GatedDeltaNetConfig", ] diff --git a/xtuner/v1/module/attention/gate_deltanet.py b/xtuner/v1/module/attention/gated_deltanet.py similarity index 97% rename from xtuner/v1/module/attention/gate_deltanet.py rename to xtuner/v1/module/attention/gated_deltanet.py index 7d2d3372f..0c0066f9f 100644 --- a/xtuner/v1/module/attention/gate_deltanet.py +++ b/xtuner/v1/module/attention/gated_deltanet.py @@ -15,6 +15,7 @@ from xtuner.v1.float8.config import Float8Config from xtuner.v1.ops.comm.all_to_all import ulysses_all_to_all from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_device, get_logger +from xtuner.v1.ops.comm.all_to_all import ulysses_all_to_all from ..linear import build_linear from .attn_outputs import AttnOutputs @@ -26,7 +27,7 @@ logger = get_logger() -class GateDeltaNetConfig(BaseModel): +class GatedDeltaNetConfig(BaseModel): model_config = ConfigDict(title="Base attention config for xtuner", extra="forbid") num_value_heads: Annotated[int, Parameter(group="attention")] num_key_heads: Annotated[int, Parameter(group="attention")] @@ -41,15 +42,15 @@ def build( hidden_size: int, float8_cfg: Float8Config | None = None, **kwargs, - ) -> "GateDeltaNet": - return GateDeltaNet( + ) -> "GatedDeltaNet": + return GatedDeltaNet( **self.model_dump(), hidden_size=hidden_size, float8_cfg=float8_cfg, ) -class GateDeltaNet(nn.Module): +class GatedDeltaNet(nn.Module): def __init__(self, hidden_size: int, num_value_heads: int, @@ -153,7 +154,7 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - + mixed_qkv = self.causal_conv1d_fn( 
x=mixed_qkv, weight=self.conv1d.weight.squeeze(1), From 3d2b3691dd8d4e8b2a0986da1e530d927ca64b71 Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Mon, 2 Mar 2026 04:16:55 +0000 Subject: [PATCH 07/10] fix lint --- xtuner/v1/data_proto/messages/chat.py | 3 +- xtuner/v1/data_proto/sequence_context.py | 13 +- xtuner/v1/model/__init__.py | 2 +- xtuner/v1/model/base.py | 4 +- .../model/compose/qwen3_5/qwen3_5_config.py | 7 +- .../compose/qwen3_vl/modeling_qwen3_vl.py | 7 +- xtuner/v1/model/dense/dense.py | 21 +++- xtuner/v1/model/moe/moe.py | 16 ++- xtuner/v1/model/moe/qwen3_5_text.py | 41 +++--- xtuner/v1/module/__init__.py | 4 + xtuner/v1/module/attention/__init__.py | 2 +- xtuner/v1/module/attention/gated_deltanet.py | 98 ++++++++------- xtuner/v1/module/attention/mha.py | 17 ++- .../decoder_layer/dense_decoder_layer.py | 12 +- .../module/decoder_layer/moe_decoder_layer.py | 119 ++++++++++-------- xtuner/v1/module/rms_norm/rms_norm.py | 11 +- xtuner/v1/module/rope/rope.py | 44 ++++--- xtuner/v1/ops/rms_norm/__init__.py | 5 +- xtuner/v1/ops/rms_norm/protocol.py | 3 +- xtuner/v1/ops/rotary_emb.py | 16 ++- 20 files changed, 261 insertions(+), 184 deletions(-) diff --git a/xtuner/v1/data_proto/messages/chat.py b/xtuner/v1/data_proto/messages/chat.py index ebf3019b8..fcc4adc64 100644 --- a/xtuner/v1/data_proto/messages/chat.py +++ b/xtuner/v1/data_proto/messages/chat.py @@ -6,10 +6,9 @@ from pydantic import BaseModel, ConfigDict from transformers import PreTrainedTokenizer -from xtuner.v1.utils import IGNORE_INDEX from xtuner.v1.data_proto.messages.base import BaseMessages from xtuner.v1.data_proto.templates import ChatTemplate, HybridChatTemplate -from xtuner.v1.utils import get_logger +from xtuner.v1.utils import IGNORE_INDEX, get_logger logger = get_logger() diff --git a/xtuner/v1/data_proto/sequence_context.py b/xtuner/v1/data_proto/sequence_context.py index 06215d0ab..8c217400b 100644 --- a/xtuner/v1/data_proto/sequence_context.py +++ 
b/xtuner/v1/data_proto/sequence_context.py @@ -101,12 +101,13 @@ def __init__( seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] seq_lens_q = self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1] - - # Used for causal_conv1d varlen. It cannot be calculated in the compile function of linear attention, + + # Used for causal_conv1d varlen. It cannot be calculated in the compile function of linear attention, # and needs to be calculated in advance. - self.seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, - device=device) for i, s in enumerate(seq_lens_q)], dim=0)[None] - + self.seq_idx = torch.cat( + [torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(seq_lens_q)], dim=0 + )[None] + if position_ids is None: _position_ids = [torch.arange(k - q, k) for q, k in zip(seq_lens_q, seq_lens_k)] position_ids = torch.cat(_position_ids).unsqueeze(0).to(self.cu_seq_lens_k.device) # type: ignore[assignment] @@ -387,7 +388,7 @@ def to(self, device: torch.device | str): self.cu_seq_lens_q = self.cu_seq_lens_q.to(device) # type: ignore self.cu_seq_lens_k = self.cu_seq_lens_k.to(device) # type: ignore self.seq_idx = self.seq_idx.to(device) # type: ignore - + if self.position_ids is not None and hasattr(self.position_ids, "to"): self.position_ids = self.position_ids.to(device) # type: ignore diff --git a/xtuner/v1/model/__init__.py b/xtuner/v1/model/__init__.py index a6888dbcc..b0744864f 100644 --- a/xtuner/v1/model/__init__.py +++ b/xtuner/v1/model/__init__.py @@ -11,13 +11,13 @@ InternVL3P5MoE30BA3Config, InternVLBaseConfig, ) +from .compose.qwen3_5 import Qwen3_5_VLMoE35BA3Config from .compose.qwen3_vl import ( Qwen3VLDense4BConfig, Qwen3VLDense8BConfig, Qwen3VLMoE30BA3Config, Qwen3VLMoE235BA22Config, ) -from .compose.qwen3_5 import Qwen3_5_VLMoE35BA3Config from .dense.dense import Dense from .dense.qwen2 import Qwen2Dense7BConfig, Qwen2DenseConfig from .dense.qwen3 import Qwen3Dense0P6BConfig, Qwen3Dense4BConfig, Qwen3Dense8BConfig, 
Qwen3DenseConfig diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index 4ccbda643..c4a61fa35 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -39,7 +39,7 @@ WeightWithDynamicTilewiseFloat8CastTensor, ) from xtuner.v1.loss import BaseLossContext -from xtuner.v1.module.attention import MHAConfig, MLAConfig, GatedDeltaNetConfig +from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig, MLAConfig from xtuner.v1.module.rope import RopeScalingConfig from xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory @@ -147,7 +147,7 @@ class TransformerConfig(XTunerBaseModelConfig): hidden_size: Annotated[int, Parameter(group="model")] intermediate_size: Annotated[int, Parameter(group="model")] rms_norm_eps: Annotated[float, Parameter(group="model")] - rms_norm_type: Annotated[str, Parameter(group="model")] = 'default' # default | zero_centered + rms_norm_type: Annotated[Literal["default", "zero_centered"], Parameter(group="model")] = "default" rope_theta: Annotated[float, Parameter(group="model")] # required by transformers's build rope hidden_act: Annotated[str, Parameter(group="model")] # key defined in `transformers.activations.ACT2CLS` attention: MLAConfig | MHAConfig diff --git a/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py index c982705af..884cd48b1 100644 --- a/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py +++ b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py @@ -2,16 +2,20 @@ from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig from xtuner.v1.utils import get_logger -from ..qwen3_vl.qwen3_vl_config import Qwen3VLVisionConfig, Qwen3VLProjectorConfig, Qwen3VLBaseConfig +from ..qwen3_vl.qwen3_vl_config import Qwen3VLBaseConfig, Qwen3VLProjectorConfig, Qwen3VLVisionConfig + logger = get_logger() + class 
Qwen3_5_VisionConfig(Qwen3VLVisionConfig): deepstack_visual_indexes: list[int] = [] + class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig): deepstack_visual_indexes: list[int] = [] + class Qwen3_5_BaseConfig(Qwen3VLBaseConfig): vision_config: Qwen3_5_VisionConfig projector_config: Qwen3_5_ProjectorConfig @@ -22,6 +26,7 @@ class Qwen3_5_BaseConfig(Qwen3VLBaseConfig): vision_start_token_id: int = 248053 vision_end_token_id: int = 248054 + class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig): vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig() projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig() diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py index 3b6b0d8f7..06706800a 100644 --- a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py @@ -182,9 +182,10 @@ def forward( if deepstack_visual_embeds is not None and len(deepstack_visual_embeds) == 0: assert seq_ctx.position_ids is not None - # qwen3.5 does not satisfy this condition - # assert seq_ctx.position_ids.ndim == 2, f"position_ids must be 2-dim when deepstack_visual_embeds is None," \ - # f" but got {seq_ctx.position_ids.ndim}" + assert seq_ctx.position_ids.ndim in (2, 3), ( + f"position_ids must be 2-dim or 3-dim when deepstack_visual_embeds is None," + f" but got {seq_ctx.position_ids.ndim}" + ) deepstack_visual_embeds = None visual_pos_masks = None diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index e77b0ecba..2cca9da86 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -29,7 +29,15 @@ TransformerConfig, ) from xtuner.v1.model.utils import checkpoint_wrapper -from xtuner.v1.module import LMHead, RMSNorm, RotaryEmbeddingProtocol, get_rope_embedding +from xtuner.v1.module import ( + GatedDeltaNetConfig, + LMHead, + MHAConfig, + MLAConfig, + RMSNorm, + RotaryEmbeddingProtocol, + 
get_rope_embedding, +) from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer from xtuner.v1.utils import ( get_device, @@ -116,15 +124,20 @@ def build_layers(self, config: TransformerConfig) -> nn.ModuleDict: # 让 layers 是一个 nn.ModuleDict 方便做 pipeline parallel 的参数切分, # 这样可以保证部分 layer 被切掉后,idx 保持不变 layers = nn.ModuleDict() + attention_config: GatedDeltaNetConfig | MLAConfig | MHAConfig | None = None for layer_idx in range(config.num_hidden_layers): if config.layers_type[layer_idx] in ["full_attention", "sliding_attention"]: attention_config = config.attention elif config.layers_type[layer_idx] == "linear_attention": attention_config = config.linear_attention - assert attention_config is not None, "linear_attention config must be provided for linear_attention layer" + assert attention_config is not None, ( + "linear_attention config must be provided for linear_attention layer" + ) else: - raise ValueError(f"Unsupported layer type {config.layers_type[layer_idx]} at layer {layer_idx}. Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported.") - + raise ValueError( + f"Unsupported layer type {config.layers_type[layer_idx]} at layer {layer_idx}. Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported." 
+ ) + layers[str(layer_idx)] = DenseDecoderLayer( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index c60796f50..4b96c8f6d 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -37,8 +37,11 @@ ) from xtuner.v1.model.utils import ModelForwardExtraLogInfo, checkpoint_wrapper, module_dict_repr from xtuner.v1.module import ( + GatedDeltaNetConfig, GreedyRouterConfig, LMHead, + MHAConfig, + MLAConfig, NoAuxRouter, NoAuxRouterConfig, RMSNorm, @@ -116,7 +119,7 @@ class MoEConfig(TransformerConfig): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") n_routed_experts: Annotated[int, Parameter(group="moe")] n_shared_experts: Annotated[int, Parameter(group="moe")] - with_shared_expert_gate: bool = False # enable when n_shared_experts > 0 + with_shared_expert_gate: bool = False # enable when n_shared_experts > 0 num_experts_per_tok: Annotated[int, Parameter(group="moe")] first_k_dense_replace: Annotated[int, Parameter(group="moe")] = 0 hidden_factor: Annotated[float, Parameter(group="moe")] = 1.0 @@ -593,15 +596,20 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: # 让 layers 是一个 nn.ModuleDict 方便做 pipeline parallel 的参数切分, # 这样可以保证部分 layer 被切掉后,idx 保持不变 layers = nn.ModuleDict() + attention_config: GatedDeltaNetConfig | MLAConfig | MHAConfig | None = None for layer_idx in range(config.num_hidden_layers): if config.layers_type[layer_idx] in ["full_attention", "sliding_attention"]: attention_config = config.attention elif config.layers_type[layer_idx] == "linear_attention": attention_config = config.linear_attention - assert attention_config is not None, "linear_attention config must be provided for linear_attention layer" + assert attention_config is not None, ( + "linear_attention config must be provided for linear_attention layer" + ) else: - raise ValueError(f"Unsupported layer type {config.layers_type[layer_idx]} at layer 
{layer_idx}. Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported.") - + raise ValueError( + f"Unsupported layer type {config.layers_type[layer_idx]} at layer {layer_idx}. Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported." + ) + if layer_idx < config.first_k_dense_replace: layers[str(layer_idx)] = DenseDecoderLayer( hidden_size=config.hidden_size, diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index 692cdaf0a..a7309771c 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -1,16 +1,17 @@ - -from pydantic import computed_field -from typing import Literal import re +from typing import Literal + import torch +from pydantic import computed_field + from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig, ZLossConfig -from xtuner.v1.module.attention import MHAConfig, GatedDeltaNetConfig -from xtuner.v1.module.router.greedy import GreedyRouterConfig +from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig from xtuner.v1.module.rope import RopeScalingConfig +from xtuner.v1.module.router.greedy import GreedyRouterConfig -from xtuner.v1.model.moe.moe import MoEConfig from .qwen3vl_text import Qwen3VLTextMoE + class Qwen3_5_VLTextMoE(Qwen3VLTextMoE): def to_hf_key_list(self, key: str) -> list[str]: if "layers" in key or "embed_tokens" in key: @@ -22,7 +23,7 @@ def to_hf_key_list(self, key: str) -> list[str]: layer_idx = int(re.findall(r"layers\.(\d+)\.", key)[0]) if self.config.layers_type[layer_idx] == "linear_attention": - key = key.replace("self_attn", "linear_attn") + key = key.replace("self_attn", "linear_attn") if "fused_w1w3.weight" in key: key = key.replace("fused_w1w3.weight", "gate_up_proj") @@ -93,19 +94,21 @@ def param_to_safetensor( # hf: num_experts, 2 * expert_dim, hidden_size num_experts = self.config.n_routed_experts hidden_size = safetensor.size(1) - safetensor = safetensor.reshape(num_experts, -1, 
hidden_size).contiguous() # num_experts, 2 * expert_dim, hidden_size + safetensor = safetensor.reshape( + num_experts, -1, hidden_size + ).contiguous() # num_experts, 2 * expert_dim, hidden_size elif "down_proj" in hf_param_name and "shared_expert" not in hf_param_name: # xtuner: num_experts * hidden_size, expert_dim # hf: num_experts, hidden_size, expert_dim num_experts = self.config.n_routed_experts expert_dim = safetensor.size(1) safetensor = safetensor.reshape(num_experts, -1, expert_dim).contiguous() - return safetensor + return safetensor class Qwen3_5_VLTextMoEConfig(MoEConfig): with_shared_expert_gate: bool = True - rms_norm_type: Literal["defalut", "zero_centered"] = "zero_centered" + rms_norm_type: Literal["default", "zero_centered"] = "zero_centered" @computed_field def layers_type(self) -> list[Literal["full_attention", "linear_attention"]]: @@ -126,19 +129,19 @@ class Qwen3_5_VLTextMoE35BA3BConfig(Qwen3_5_VLTextMoEConfig): num_hidden_layers: int = 40 max_window_layers: int = 40 hidden_size: int = 2048 - intermediate_size: int = 0 # not used + intermediate_size: int = 0 # not used rms_norm_eps: float = 1e-6 rope_theta: float = 10000000.0 hidden_act: str = "silu" attention: MHAConfig = MHAConfig( with_gate=True, - num_attention_heads=16, - num_key_value_heads=2, - head_dim=256, - qk_norm=True, + num_attention_heads=16, + num_key_value_heads=2, + head_dim=256, + qk_norm=True, rms_norm_eps=1e-6, rms_norm_type="zero_centered", - sliding_window=1024 + sliding_window=1024, ) linear_attention: GatedDeltaNetConfig = GatedDeltaNetConfig( num_value_heads=32, @@ -146,7 +149,7 @@ class Qwen3_5_VLTextMoE35BA3BConfig(Qwen3_5_VLTextMoEConfig): key_head_dim=128, value_head_dim=128, conv_kernel_dim=4, - hidden_act='silu', + hidden_act="silu", rms_norm_eps=1e-6, ) tie_word_embeddings: bool = False @@ -161,6 +164,8 @@ class Qwen3_5_VLTextMoE35BA3BConfig(Qwen3_5_VLTextMoEConfig): norm_topk_prob=True, router_scaling_factor=1.0, ) - rope_scaling_cfg: RopeScalingConfig = 
RopeScalingConfig(type="qwen3_vl", mrope_section=[11, 11, 10], partial_rotary_factor=0.25) + rope_scaling_cfg: RopeScalingConfig = RopeScalingConfig( + type="qwen3_vl", mrope_section=[11, 11, 10], partial_rotary_factor=0.25 + ) balancing_loss_cfg: BalancingLossConfig | None = BalancingLossConfig() z_loss_cfg: ZLossConfig | None = None diff --git a/xtuner/v1/module/__init__.py b/xtuner/v1/module/__init__.py index c84f7d500..2f1d59e67 100644 --- a/xtuner/v1/module/__init__.py +++ b/xtuner/v1/module/__init__.py @@ -1,5 +1,7 @@ from .attention import ( AttnOutputs, + GatedDeltaNet, + GatedDeltaNetConfig, MHAConfig, MLAConfig, MultiHeadAttention, @@ -26,6 +28,8 @@ "MultiLatentAttention", "MHAConfig", "MLAConfig", + "GatedDeltaNetConfig", + "GatedDeltaNet", "AttnOutputs", "RopeScalingConfig", "RotaryEmbedding", diff --git a/xtuner/v1/module/attention/__init__.py b/xtuner/v1/module/attention/__init__.py index 593cc4b22..d4594014a 100644 --- a/xtuner/v1/module/attention/__init__.py +++ b/xtuner/v1/module/attention/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .attn_outputs import AttnOutputs +from .gated_deltanet import GatedDeltaNet, GatedDeltaNetConfig from .mha import MHAConfig, MultiHeadAttention from .mla import MLAConfig, MultiLatentAttention -from .gated_deltanet import GatedDeltaNetConfig, GatedDeltaNet __all__ = [ diff --git a/xtuner/v1/module/attention/gated_deltanet.py b/xtuner/v1/module/attention/gated_deltanet.py index 0c0066f9f..a885d89e8 100644 --- a/xtuner/v1/module/attention/gated_deltanet.py +++ b/xtuner/v1/module/attention/gated_deltanet.py @@ -1,28 +1,33 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Annotated, Callable, Literal, cast +from typing import Annotated import torch +import torch.nn.functional as F from cyclopts import Parameter -from mmengine import is_installed from pydantic import BaseModel, ConfigDict from torch import nn -from torch.distributed.tensor import DTensor from typing_extensions import overload -import torch.nn.functional as F from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.config import Float8Config -from xtuner.v1.ops.comm.all_to_all import ulysses_all_to_all -from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_device, get_logger -from xtuner.v1.ops.comm.all_to_all import ulysses_all_to_all +from xtuner.v1.utils import get_logger from ..linear import build_linear from .attn_outputs import AttnOutputs -from fla.modules import FusedRMSNormGated -from fla.ops.gated_delta_rule import chunk_gated_delta_rule -from causal_conv1d import causal_conv1d_fn + +try: + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import chunk_gated_delta_rule +except ImportError: + FusedRMSNormGated = None + chunk_gated_delta_rule = None + +try: + from causal_conv1d import causal_conv1d_fn +except ImportError: + causal_conv1d_fn = None logger = get_logger() @@ -51,17 +56,19 @@ def build( class GatedDeltaNet(nn.Module): - def __init__(self, - hidden_size: int, - num_value_heads: int, - num_key_heads: int, - key_head_dim: int, - value_head_dim: int, - conv_kernel_dim: int, - hidden_act: str, - rms_norm_eps: float, - layer_idx: int = 0, - float8_cfg: Float8Config | None = None) -> None: + def __init__( + self, + hidden_size: int, + num_value_heads: int, + num_key_heads: int, + key_head_dim: int, + value_head_dim: int, + conv_kernel_dim: int, + hidden_act: str, + rms_norm_eps: float, + layer_idx: int = 0, + float8_cfg: Float8Config | None = None, + ) -> None: super().__init__() self.name = f"layers.{layer_idx}.gate_deltanet" self.float8_cfg = float8_cfg @@ -97,14 +104,18 @@ def __init__(self, A 
= torch.empty(self.num_v_heads).uniform_(0, 16) self.A_log = nn.Parameter(torch.log(A)) + assert causal_conv1d_fn is not None, ( + "causal_conv1d_fn is not available. Please install causal-conv1d to use GatedDeltaNet by `https://github.com/Dao-AILab/causal-conv1d`." + ) self.causal_conv1d_fn = causal_conv1d_fn + assert chunk_gated_delta_rule is not None, ( + "chunk_gated_delta_rule is not available. Please install fla to use GatedDeltaNet by `pip install flash-linear-attention`." + ) self.chunk_gated_delta_rule = chunk_gated_delta_rule - - self.norm = FusedRMSNormGated( - self.head_v_dim, - eps=self.rms_norm_eps, - activation=self.activation + assert FusedRMSNormGated is not None, ( + "FusedRMSNormGated is not available. Please install fla to use GatedDeltaNet by `pip install flash-linear-attention`." ) + self.norm = FusedRMSNormGated(self.head_v_dim, eps=self.rms_norm_eps, activation=self.activation) self.out_proj = build_linear( self.value_dim, @@ -137,15 +148,15 @@ def __init__(self, bias=False, float8_cfg=self.float8_cfg, ) - + def forward( self, hidden_states: torch.Tensor, seq_ctx: SequenceContext, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # not used + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # not used ) -> AttnOutputs: batch_size, seq_len, _ = hidden_states.shape - assert batch_size==1, "Only batch size of 1 is supported for now in GateDeltaNet" + assert batch_size == 1, "Only batch size of 1 is supported for now in GateDeltaNet" mixed_qkv = self.in_proj_qkv(hidden_states) mixed_qkv = mixed_qkv.transpose(1, 2) @@ -154,7 +165,7 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - + mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, weight=self.conv1d.weight.squeeze(1), @@ -182,19 +193,19 @@ def forward( if self.num_v_heads // self.num_k_heads > 1: query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // 
self.num_k_heads, dim=2) - + core_attn_out, _ = self.chunk_gated_delta_rule( - query, - key, - value, - g=g, - beta=beta, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=True, - cu_seqlens=seq_ctx.cu_seq_lens_q, - ) - + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=seq_ctx.cu_seq_lens_q, + ) + # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) z = z.reshape(-1, self.head_v_dim) @@ -206,7 +217,7 @@ def forward( "projected_output": output, } return attn_outputs - + @overload # type: ignore def __call__( # type: ignore self, @@ -216,4 +227,3 @@ def __call__( # type: ignore ) -> AttnOutputs: ... __call__ = nn.Module.__call__ - diff --git a/xtuner/v1/module/attention/mha.py b/xtuner/v1/module/attention/mha.py index 0ee245b97..8e6121019 100644 --- a/xtuner/v1/module/attention/mha.py +++ b/xtuner/v1/module/attention/mha.py @@ -38,7 +38,7 @@ class MHAConfig(BaseModel): qkv_bias: Annotated[bool, Parameter(group="attention")] = False qk_norm: bool = False rms_norm_eps: float = 1e-06 - rms_norm_type: Literal['default', 'zero_centered'] = 'default' + rms_norm_type: Literal["default", "zero_centered"] = "default" o_bias: Annotated[bool, Parameter(group="attention")] = False sliding_window: Annotated[int | None, Parameter(group="attention")] = -1 with_sink: Annotated[bool, Parameter(group="attention")] = False @@ -124,7 +124,7 @@ def __init__( qkv_bias: bool = False, qk_norm: bool = False, rms_norm_eps: float = 1e-6, - rms_norm_type: Literal['default', 'zero_centered'] = 'default', + rms_norm_type: Literal["default", "zero_centered"] = "default", o_bias: bool = False, with_sink: bool = False, with_gate: bool = False, @@ -158,7 +158,9 @@ def __init__( self.q_proj = build_linear( self.hidden_size, - self.num_attention_heads * self.head_dim if not with_gate else self.num_attention_heads * self.head_dim * 2, + 
self.num_attention_heads * self.head_dim + if not with_gate + else self.num_attention_heads * self.head_dim * 2, bias=self.qkv_bias, float8_cfg=self.float8_cfg, ) @@ -194,7 +196,9 @@ def __init__( self.window_size = (sliding_window, sliding_window) fope_sep_head = rope_scaling_cfg.fope_sep_head if rope_scaling_cfg is not None else None - enable_partial_rotary = rope_scaling_cfg.partial_rotary_factor != 1.0 if rope_scaling_cfg is not None else False + enable_partial_rotary = ( + rope_scaling_cfg.partial_rotary_factor != 1.0 if rope_scaling_cfg is not None else False + ) self.apply_rotary_emb = get_apply_rotary_emb(fope_sep_head, enable_partial_rotary=enable_partial_rotary) # type: ignore self.attn_impl_func: Callable[..., AttnOpOutputs] = attn_impl_mapping[attn_impl] # type: ignore[assignment] @@ -334,7 +338,7 @@ def forward( """ input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - + if self.with_gate: query_states, gate = torch.chunk( self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 @@ -424,8 +428,9 @@ def forward( raw_output = raw_output.reshape(*input_shape, -1).contiguous() if self.with_gate: + assert gate is not None raw_output = raw_output * torch.sigmoid(gate) - + projected_output = self.o_proj(raw_output) attn_outputs: AttnOutputs = { "projected_output": projected_output, diff --git a/xtuner/v1/module/decoder_layer/dense_decoder_layer.py b/xtuner/v1/module/decoder_layer/dense_decoder_layer.py index c63f0eadd..e01f89ada 100644 --- a/xtuner/v1/module/decoder_layer/dense_decoder_layer.py +++ b/xtuner/v1/module/decoder_layer/dense_decoder_layer.py @@ -6,7 +6,7 @@ from xtuner.v1.config import GenerateConfig from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.config import Float8Config -from xtuner.v1.module import AttnOutputs, MHAConfig, MLAConfig, RMSNorm +from xtuner.v1.module import AttnOutputs, GatedDeltaNetConfig, MHAConfig, MLAConfig, RMSNorm from xtuner.v1.module.rope 
import RopeScalingConfig from xtuner.v1.ops.act_fn import get_act_fn from xtuner.v1.utils import ForwardState @@ -44,8 +44,8 @@ def __init__( mlp_bias: bool = False, hidden_act: str, rms_norm_eps: float = 1e-6, - rms_norm_type: Literal['default', 'zero_centered'] = 'default', - attention_config: MLAConfig | MHAConfig, + rms_norm_type: Literal["default", "zero_centered"] = "default", + attention_config: MLAConfig | MHAConfig | GatedDeltaNetConfig, rope_scaling_cfg: RopeScalingConfig | None = None, generate_config: GenerateConfig | None = None, float8_cfg: Float8Config | None = None, @@ -69,7 +69,7 @@ def __init__( hidden_act=hidden_act, float8_cfg=float8_cfg, ) - self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps,type=rms_norm_type) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) def forward( @@ -111,7 +111,7 @@ def prefilling( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states = self.self_attn.prefilling( + hidden_states = self.self_attn.prefilling( # type: ignore hidden_states=hidden_states, position_embeddings=position_embeddings, seq_ctx=seq_ctx, @@ -157,7 +157,7 @@ def decoding( def build_kv_cache( self, max_batch_size: int | None = None, max_length: int | None = None, block_size: int | None = None ) -> tuple[torch.Tensor, torch.Tensor]: - return self.self_attn.build_kv_cache( + return self.self_attn.build_kv_cache( # type: ignore max_batch_size=max_batch_size, max_length=max_length, block_size=block_size, diff --git a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py index 5a44cb4b5..8f865e84f 100644 --- a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py +++ b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py @@ -14,9 +14,13 @@ from xtuner.v1.float8 import Float8Config from xtuner.v1.module import ( AttnOutputs, + GatedDeltaNet, + 
GatedDeltaNetConfig, GreedyRouterConfig, MHAConfig, MLAConfig, + MultiHeadAttention, + MultiLatentAttention, NoAuxRouterConfig, RMSNorm, RouterResults, @@ -127,13 +131,14 @@ def forward( bias = self.bias.to_local() if isinstance(self.bias, DTensor) else self.bias bias = bias.float() - # logits = F.linear(hidden_states.float(), weight.float(), bias) - # return self.router(logits, rollout_routed_experts) - - logits = F.linear(hidden_states, weight, bias) - gate = self.router(logits, rollout_routed_experts) - gate['topk_weights'] = gate['topk_weights'].float() - return gate + logits = F.linear(hidden_states.float(), weight.float(), bias) + return self.router(logits, rollout_routed_experts) + + # Debug for aligning with hf implementation. + # logits = F.linear(hidden_states, weight, bias) + # gate = self.router(logits, rollout_routed_experts) + # gate['topk_weights'] = gate['topk_weights'].float() + # return gate class MoEBlock(nn.Module): @@ -194,13 +199,13 @@ def __init__( moe_bias: bool = False, hidden_act: str, rms_norm_eps: float = 1e-6, - rms_norm_type: Literal['default', 'zero_centered'] = 'default', + rms_norm_type: Literal["default", "zero_centered"] = "default", num_experts_per_tok: int, n_routed_experts: int, n_shared_experts: int, with_shared_expert_gate: bool = False, hidden_factor: float = 1.0, - attention_config: MHAConfig | MLAConfig, + attention_config: MHAConfig | MLAConfig | GatedDeltaNetConfig, rope_scaling_cfg: RopeScalingConfig | None = None, layer_type: Literal["full_attention", "sliding_attention"] | None = None, generate_config: GenerateConfig | None = None, @@ -218,7 +223,7 @@ def __init__( self.n_shared_experts = n_shared_experts self.hidden_factor = hidden_factor - self.self_attn = attention_config.build( + self.self_attn: MultiHeadAttention | MultiLatentAttention | GatedDeltaNet = attention_config.build( hidden_size=hidden_size, layer_idx=layer_idx, generate_config=generate_config, @@ -227,10 +232,12 @@ def __init__( float8_cfg=float8_cfg, 
) self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) - self.shared_experts: MoEMLP | None self.layer_idx = layer_idx - + self.with_shared_expert_gate = with_shared_expert_gate + self.shared_expert_gate: nn.Module | None + self.shared_experts: MoEMLP | None + if n_shared_experts > 0: self.shared_experts = MoEMLP( hidden_size=hidden_size, @@ -317,6 +324,44 @@ def forward( position_embeddings_list=position_embeddings, ) + def _hf_expert_forward_for_debug(self, hidden_states: torch.Tensor, router_results: RouterResults, origin_shape): + # xtuner: num_experts * 2 * expert_dim, hidden_size + # hf: num_experts, 2 * expert_dim, hidden_size + origin_gate_up_proj = self.experts.fused_w1w3.weight + gate_up_proj = origin_gate_up_proj.view( + self.n_routed_experts, 2 * self.experts.intermediate_size, self.hidden_size + ) + + # xtuner: num_experts * hidden_size, expert_dim + # hf: num_experts, hidden_size, expert_dim + origin_down_proj = self.experts.fused_w2.weight + down_proj = origin_down_proj.view(self.n_routed_experts, self.hidden_size, self.experts.intermediate_size) + + from transformers.activations import ACT2FN + + act_fn = ACT2FN["silu"] + + hidden_states_reshaped = hidden_states.view(-1, hidden_states.size(-1)) + combined_hidden_states = torch.zeros_like(hidden_states_reshaped) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_results["topk_ids"], num_classes=self.n_routed_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.n_routed_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states_reshaped[token_idx] + gate, up = nn.functional.linear(current_state, gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = act_fn(gate) * up + current_hidden_states = 
nn.functional.linear(current_hidden_states, down_proj[expert_idx]) + current_hidden_states = current_hidden_states * router_results["topk_weights"][token_idx, top_k_pos, None] + combined_hidden_states.index_add_(0, token_idx, current_hidden_states.to(combined_hidden_states.dtype)) + + combined_hidden_states = combined_hidden_states.view(*origin_shape) + return combined_hidden_states + def _forward( self, hidden_states: torch.Tensor, @@ -391,40 +436,9 @@ def _forward( ) combined_hidden_states = post_combined["hidden_states"] combined_hidden_states = combined_hidden_states.view(*origin_shape) - - # # 对齐 hf 实现用 - # # xtuner: num_experts * 2 * expert_dim, hidden_size - # # hf: num_experts, 2 * expert_dim, hidden_size - # origin_gate_up_proj = self.experts.fused_w1w3.weight - # gate_up_proj = origin_gate_up_proj.view(self.n_routed_experts, 2 * self.experts.intermediate_size, self.hidden_size) - - # # xtuner: num_experts * hidden_size, expert_dim - # # hf: num_experts, hidden_size, expert_dim - # origin_down_proj = self.experts.fused_w2.weight - # down_proj = origin_down_proj.view(self.n_routed_experts, self.hidden_size, self.experts.intermediate_size) - - # from transformers.activations import ACT2FN - # act_fn = ACT2FN['silu'] - - # hidden_states_reshaped = hidden_states.view(-1, hidden_states.size(-1)) - # combined_hidden_states = torch.zeros_like(hidden_states_reshaped) - # with torch.no_grad(): - # expert_mask = torch.nn.functional.one_hot(router_results["topk_ids"], num_classes=self.n_routed_experts) - # expert_mask = expert_mask.permute(2, 1, 0) - # expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - # for expert_idx in expert_hit: - # expert_idx = expert_idx[0] - # if expert_idx == self.n_routed_experts: - # continue - # top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - # current_state = hidden_states_reshaped[token_idx] - # gate, up = nn.functional.linear(current_state, gate_up_proj[expert_idx]).chunk(2, dim=-1) - # 
current_hidden_states = act_fn(gate) * up - # current_hidden_states = nn.functional.linear(current_hidden_states, down_proj[expert_idx]) - # current_hidden_states = current_hidden_states * router_results["topk_weights"][token_idx, top_k_pos, None] - # combined_hidden_states.index_add_(0, token_idx, current_hidden_states.to(combined_hidden_states.dtype)) - - # combined_hidden_states = combined_hidden_states.view(*origin_shape) + + # debug for aligning with hf implementation. + # combined_hidden_states = self._hf_expert_forward_for_debug(hidden_states, router_results, origin_shape) # ProberList.after_combine(self.layer_idx, combined_hidden_states) @@ -597,7 +611,7 @@ def _pre_moe_forward( hidden_states = attn_outputs["projected_output"] elif state == ForwardState.PREFILLING: assert past_key_values is not None, "past_key_values should be provided in pre-filling state" - hidden_states = self.self_attn.prefilling( + hidden_states = self.self_attn.prefilling( # type: ignore hidden_states=hidden_states, position_embeddings=position_embeddings, seq_ctx=seq_ctx, @@ -605,7 +619,7 @@ def _pre_moe_forward( ) elif state == ForwardState.DECODING: assert past_key_values is not None, "past_key_values should be provided in decoding state" - hidden_states = self.self_attn.decoding( + hidden_states = self.self_attn.decoding( # type: ignore hidden_states=hidden_states, position_embeddings=position_embeddings, seq_ctx=seq_ctx, @@ -632,8 +646,11 @@ def _shared_experts_forward( shared_experts_out = self.shared_experts(hidden_states) if self.with_shared_expert_gate: - shared_experts_out = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_experts_out - + assert self.shared_expert_gate is not None, ( + "Shared expert gate should be initialized when with_shared_expert_gate is True" + ) + shared_experts_out = torch.sigmoid(self.shared_expert_gate(hidden_states)) * shared_experts_out + return shared_experts_out def _post_moe_forward( @@ -650,7 +667,7 @@ def _post_moe_forward( def 
build_kv_cache( self, max_batch_size: int | None = None, max_length: int | None = None, block_size: int | None = None ) -> tuple[torch.Tensor, torch.Tensor]: - return self.self_attn.build_kv_cache( + return self.self_attn.build_kv_cache( # type: ignore max_batch_size=max_batch_size, max_length=max_length, block_size=block_size, diff --git a/xtuner/v1/module/rms_norm/rms_norm.py b/xtuner/v1/module/rms_norm/rms_norm.py index 407688a92..11ccfb1d9 100644 --- a/xtuner/v1/module/rms_norm/rms_norm.py +++ b/xtuner/v1/module/rms_norm/rms_norm.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Literal + import torch from torch import nn -from typing import Literal from torch.distributed.tensor import DTensor from xtuner.v1.ops import rms_norm, zero_centered_rms_norm @@ -10,19 +11,19 @@ class RMSNorm(nn.Module): weight: torch.Tensor - def __init__(self, hidden_size: int, eps: float = 1e-6, type: Literal['default', 'zero_centered'] = 'default'): + def __init__(self, hidden_size: int, eps: float = 1e-6, type: Literal["default", "zero_centered"] = "default"): """RMSNorm is equivalent to T5LayerNorm.""" super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps self._type = type - if type == 'default': + if type == "default": self.rms_norm_fn = rms_norm - elif type == 'zero_centered': + elif type == "zero_centered": self.rms_norm_fn = zero_centered_rms_norm else: - raise ValueError(f'Unsupported RMSNorm type: {type}') + raise ValueError(f"Unsupported RMSNorm type: {type}") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if isinstance(self.weight, DTensor): diff --git a/xtuner/v1/module/rope/rope.py b/xtuner/v1/module/rope/rope.py index 9997b7faf..0b952b949 100644 --- a/xtuner/v1/module/rope/rope.py +++ b/xtuner/v1/module/rope/rope.py @@ -1,5 +1,4 @@ -from typing import Literal, Protocol, cast, Optional, Callable - +from typing import Callable, Literal, Optional, Protocol, cast import torch 
import torch.nn as nn @@ -63,25 +62,24 @@ def to(self, device: torch.device) -> Self: ... def compute_default_rope_parameters( - config, - device: Optional["torch.device"] = None, - ) -> tuple["torch.Tensor", float]: - base = config.rope_theta - if config.rope_scaling_cfg is not None: - rope_scaling_cfg: RopeScalingConfig = config.rope_scaling_cfg - partial_rotary_factor = getattr(rope_scaling_cfg, "partial_rotary_factor", 1.0) - else: - partial_rotary_factor = 1.0 - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) + config, + device: Optional["torch.device"] = None, +) -> tuple["torch.Tensor", float]: + base = config.rope_theta + if config.rope_scaling_cfg is not None: + rope_scaling_cfg: RopeScalingConfig = config.rope_scaling_cfg + partial_rotary_factor = getattr(rope_scaling_cfg, "partial_rotary_factor", 1.0) + else: + partial_rotary_factor = 1.0 + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) - attention_factor = 1.0 # Unused in this type of RoPE + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + return inv_freq, attention_factor - # Compute the inverse frequencies - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) - ) - return inv_freq, attention_factor class RotaryEmbedding(nn.Module): inv_freq: torch.Tensor @@ -104,7 +102,7 @@ def __init__(self, config, device=None): f"Unsupported rope_type: {self.rope_type}. Supported types are: 'default', 'linear', 'yarn', 'llama3'." 
) - # The implementation of RoPE has been refactored in Transformers V5, and + # The implementation of RoPE has been refactored in Transformers V5, and # the following approach is used for compatibility. self.rope_init_fn: Callable = compute_default_rope_parameters if self.rope_type != "default": @@ -310,8 +308,8 @@ def __init__(self, config, device=None): self.original_max_seq_len = config.max_position_embeddings self.rope_type = "default" self.config = config - - # The implementation of RoPE has been refactored in Transformers V5, and + + # The implementation of RoPE has been refactored in Transformers V5, and # the following approach is used for compatibility. self.rope_init_fn: Callable = compute_default_rope_parameters if self.rope_type != "default": @@ -325,7 +323,7 @@ def __init__(self, config, device=None): self.mrope_section = config.rope_scaling_cfg.mrope_section assert self.mrope_section is not None - + def apply_interleaved_mrope(self, freqs, mrope_section): """Apply interleaved MRoPE to 3D rotary embeddings. diff --git a/xtuner/v1/ops/rms_norm/__init__.py b/xtuner/v1/ops/rms_norm/__init__.py index 95cfa41ca..7da199d1d 100644 --- a/xtuner/v1/ops/rms_norm/__init__.py +++ b/xtuner/v1/ops/rms_norm/__init__.py @@ -22,10 +22,12 @@ def _triton_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> t return rms_norm_fn(x, weight, bias=None, eps=epsilon) + def native_zero_centered_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: # TODO: is native_rms_norm ? 
def _norm(x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + epsilon) + output = _norm(x.float()) # Llama does x.to(float16) * w whilst Qwen3_5Moe is (x * w).to(float16) # See https://github.com/huggingface/transformers/pull/29402 @@ -48,6 +50,7 @@ def get_rms_norm_fn() -> RMSNormProtocol: else: raise NotImplementedError(f"RMSNorm is not implemented on {device}") + def get_zero_centered_rms_norm_fn() -> RMSNormProtocol: from xtuner.v1.utils import get_device @@ -65,4 +68,4 @@ def get_zero_centered_rms_norm_fn() -> RMSNormProtocol: rms_norm = get_rms_norm_fn() -zero_centered_rms_norm = get_zero_centered_rms_norm_fn() \ No newline at end of file +zero_centered_rms_norm = get_zero_centered_rms_norm_fn() diff --git a/xtuner/v1/ops/rms_norm/protocol.py b/xtuner/v1/ops/rms_norm/protocol.py index 426f7d683..1ce7b5ea3 100644 --- a/xtuner/v1/ops/rms_norm/protocol.py +++ b/xtuner/v1/ops/rms_norm/protocol.py @@ -1,8 +1,7 @@ from typing import Protocol -from typing_extensions import Literal import torch class RMSNormProtocol(Protocol): - def __call__(self, x: torch.Tensor, weight: torch.Tensor, epsilon: float, type: Literal['default', 'zero_centered']) -> torch.Tensor: ... + def __call__(self, x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: ... diff --git a/xtuner/v1/ops/rotary_emb.py b/xtuner/v1/ops/rotary_emb.py index 917e9dc2a..ebca8c184 100644 --- a/xtuner/v1/ops/rotary_emb.py +++ b/xtuner/v1/ops/rotary_emb.py @@ -49,9 +49,16 @@ def apply_rotary_pos_emb_cuda( return q_embed, k_embed -# Note: Although this function is compatible with apply_rotary_pos_emb_cuda, +# Note: Although this function is compatible with apply_rotary_pos_emb_cuda, # it is still recommended to separate them into two for efficiency considerations. 
-def apply_rotary_pos_emb_cuda_for_partial_rotary(q, k, cos, sin, unsqueeze_dim=1): +def apply_rotary_pos_emb_cuda_for_partial_rotary( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor | None = None, + unsqueeze_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) @@ -141,8 +148,9 @@ def __call__( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -def get_apply_rotary_emb(fope_sep_head: bool | None = False, - enable_partial_rotary: bool= False) -> ApplyRotaryEmbProtocol: +def get_apply_rotary_emb( + fope_sep_head: bool | None = False, enable_partial_rotary: bool = False +) -> ApplyRotaryEmbProtocol: from xtuner.v1.utils.device import get_device device = get_device() From cebb742c45235b106de4909927dab1b9622b237e Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Mon, 2 Mar 2026 09:03:33 +0000 Subject: [PATCH 08/10] support compile and fp8 --- xtuner/v1/model/moe/qwen3_5_text.py | 36 ++++++++++ xtuner/v1/module/attention/gated_deltanet.py | 68 ++++++++++++++----- .../module/decoder_layer/moe_decoder_layer.py | 2 +- 3 files changed, 88 insertions(+), 18 deletions(-) diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py index a7309771c..3afc4e5dd 100644 --- a/xtuner/v1/model/moe/qwen3_5_text.py +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -3,7 +3,12 @@ import torch from pydantic import computed_field +from typing_extensions import override +from xtuner.v1.model.base import ( + DEFAULT_FLOAT8_CFG, + TorchCompileOption, +) from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig, ZLossConfig from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig from xtuner.v1.module.rope import RopeScalingConfig @@ -12,6 +17,29 @@ from .qwen3vl_text import Qwen3VLTextMoE +MOE_NON_EP_COMPILE_CFG: dict[str, TorchCompileOption] = { + 
"xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEBlock.forward": TorchCompileOption(fullgraph=True), + "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer.forward": TorchCompileOption(fullgraph=False), + "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer._pre_moe_forward": TorchCompileOption( + fullgraph=False + ), + "xtuner.v1.module.attention.mha.MultiHeadAttention.forward": TorchCompileOption(fullgraph=True), + # TODO: GatedDeltaNet does not currently support torch.compile(fullgraph=True); support will be added in the future. + "xtuner.v1.module.attention.gated_deltanet.GatedDeltaNet.forward": TorchCompileOption(fullgraph=False), + "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer._shared_experts_forward": TorchCompileOption( + fullgraph=True + ), + "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer._post_moe_forward": TorchCompileOption( + fullgraph=True + ), + "xtuner.v1.module.decoder_layer.dense_decoder_layer.DenseDecoderLayer.forward": TorchCompileOption(fullgraph=True), + **DEFAULT_FLOAT8_CFG, +} + +MOE_EP_COMPILE_CFG = MOE_NON_EP_COMPILE_CFG.copy() +MOE_EP_COMPILE_CFG.pop("xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer.forward") + + class Qwen3_5_VLTextMoE(Qwen3VLTextMoE): def to_hf_key_list(self, key: str) -> list[str]: if "layers" in key or "embed_tokens" in key: @@ -105,6 +133,14 @@ def param_to_safetensor( safetensor = safetensor.reshape(num_experts, -1, expert_dim).contiguous() return safetensor + @property + @override + def default_compile_cfg(self) -> dict[str, TorchCompileOption]: + if self.config.ep_size > 1: + return MOE_EP_COMPILE_CFG + else: + return MOE_NON_EP_COMPILE_CFG + class Qwen3_5_VLTextMoEConfig(MoEConfig): with_shared_expert_gate: bool = True diff --git a/xtuner/v1/module/attention/gated_deltanet.py b/xtuner/v1/module/attention/gated_deltanet.py index a885d89e8..0f74a9481 100644 --- a/xtuner/v1/module/attention/gated_deltanet.py +++ 
b/xtuner/v1/module/attention/gated_deltanet.py @@ -7,6 +7,7 @@ from cyclopts import Parameter from pydantic import BaseModel, ConfigDict from torch import nn +from torch.distributed.tensor import DTensor from typing_extensions import overload from xtuner.v1.data_proto import SequenceContext @@ -18,10 +19,37 @@ try: - from fla.modules import FusedRMSNormGated + from fla.modules import FusedRMSNormGated as FLA_FusedRMSNormGated + from fla.modules.fused_norm_gate import rms_norm_gated from fla.ops.gated_delta_rule import chunk_gated_delta_rule + + class FusedRMSNormGated(FLA_FusedRMSNormGated): + def forward( + self, + x: torch.Tensor, + g: torch.Tensor, + residual: torch.Tensor | None = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + ) -> torch.Tensor: + weight = self.weight + if isinstance(weight, DTensor): + weight = weight.to_local() + + return rms_norm_gated( + x, + g, + weight, + self.bias, + self.activation, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + except ImportError: - FusedRMSNormGated = None + FusedRMSNormGated = None # type: ignore chunk_gated_delta_rule = None try: @@ -136,18 +164,8 @@ def __init__( bias=False, float8_cfg=self.float8_cfg, ) - self.in_proj_b = build_linear( - self.hidden_size, - self.num_v_heads, - bias=False, - float8_cfg=self.float8_cfg, - ) - self.in_proj_a = build_linear( - self.hidden_size, - self.num_v_heads, - bias=False, - float8_cfg=self.float8_cfg, - ) + self.in_proj_b = build_linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = build_linear(self.hidden_size, self.num_v_heads, bias=False) def forward( self, @@ -166,10 +184,17 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) + weight = self.conv1d.weight.squeeze(1) + bias = self.conv1d.bias + if isinstance(weight, DTensor): + weight = weight.to_local() + if bias and isinstance(bias, DTensor): + bias = bias.to_local() + mixed_qkv = 
self.causal_conv1d_fn( x=mixed_qkv, - weight=self.conv1d.weight.squeeze(1), - bias=self.conv1d.bias, + weight=weight, + bias=bias, activation=self.activation, seq_idx=seq_ctx.seq_idx, ) @@ -189,7 +214,15 @@ def forward( beta = b.sigmoid() # If the model is loaded in fp16, without the .float() here, A might be -inf - g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + A_log = self.A_log + dt_bias = self.dt_bias + if isinstance(A_log, DTensor): + A_log = A_log.to_local() + if isinstance(dt_bias, DTensor): + dt_bias = dt_bias.to_local() + + g = -A_log.float().exp() * F.softplus(a.float() + dt_bias) + if self.num_v_heads // self.num_k_heads > 1: query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) @@ -209,6 +242,7 @@ def forward( # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) diff --git a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py index 8f865e84f..4e72751da 100644 --- a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py +++ b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py @@ -248,7 +248,7 @@ def __init__( float8_cfg=float8_cfg, ) if with_shared_expert_gate: - self.shared_expert_gate = build_linear(hidden_size, 1, bias=False, float8_cfg=float8_cfg) + self.shared_expert_gate = build_linear(hidden_size, 1, bias=False) else: self.shared_experts = None self.shared_expert_gate = None From 815c9b3c1bec7921aa96e72ee5eb0598ab453bba Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Mon, 2 Mar 2026 09:22:02 +0000 Subject: [PATCH 09/10] add flash-linear-attention dep --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 
6fa81c006..c24d0c6d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,8 @@ all = [ "liger-kernel", "parametrize", "mathruler", - "pylatexenc" + "pylatexenc", + "flash-linear-attention" ] [tool.mypy] From 9e595f944089ff43124c3db3b42b32298eb194c7 Mon Sep 17 00:00:00 2001 From: "huanghaian@pjlab.org.cn" Date: Mon, 2 Mar 2026 12:53:29 +0000 Subject: [PATCH 10/10] update seq_idx --- xtuner/v1/data_proto/sequence_context.py | 10 ++-------- xtuner/v1/module/attention/gated_deltanet.py | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/xtuner/v1/data_proto/sequence_context.py b/xtuner/v1/data_proto/sequence_context.py index 8c217400b..4c4aad06c 100644 --- a/xtuner/v1/data_proto/sequence_context.py +++ b/xtuner/v1/data_proto/sequence_context.py @@ -36,6 +36,7 @@ class SequenceContext: block_table: torch.Tensor | None device: str | torch.device # TODO: 这个地方有点乱,到处是 device position_ids: torch.LongTensor | None + seq_idx: torch.IntTensor | None # Qwen3VL image_grid_thw: torch.Tensor | None @@ -98,16 +99,11 @@ def __init__( self.inputs_embeds = inputs_embeds self.num_img_tokens = num_img_tokens self.rollout_routed_experts = rollout_routed_experts + self.seq_idx = None seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] seq_lens_q = self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1] - # Used for causal_conv1d varlen. It cannot be calculated in the compile function of linear attention, - # and needs to be calculated in advance. 
- self.seq_idx = torch.cat( - [torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(seq_lens_q)], dim=0 - )[None] - if position_ids is None: _position_ids = [torch.arange(k - q, k) for q, k in zip(seq_lens_q, seq_lens_k)] position_ids = torch.cat(_position_ids).unsqueeze(0).to(self.cu_seq_lens_k.device) # type: ignore[assignment] @@ -383,11 +379,9 @@ def to(self, device: torch.device | str): if device == "npu" or isinstance(device, torch.device) and device.type == "npu": self.cu_seq_lens_q = self.cu_seq_lens_q.cpu() # type: ignore self.cu_seq_lens_k = self.cu_seq_lens_k.cpu() # type: ignore - self.seq_idx = self.seq_idx.cpu() # type: ignore else: self.cu_seq_lens_q = self.cu_seq_lens_q.to(device) # type: ignore self.cu_seq_lens_k = self.cu_seq_lens_k.to(device) # type: ignore - self.seq_idx = self.seq_idx.to(device) # type: ignore if self.position_ids is not None and hasattr(self.position_ids, "to"): self.position_ids = self.position_ids.to(device) # type: ignore diff --git a/xtuner/v1/module/attention/gated_deltanet.py b/xtuner/v1/module/attention/gated_deltanet.py index 0f74a9481..c88dcc9b8 100644 --- a/xtuner/v1/module/attention/gated_deltanet.py +++ b/xtuner/v1/module/attention/gated_deltanet.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Annotated +from typing import Annotated, cast import torch import torch.nn.functional as F @@ -191,12 +191,25 @@ if bias and isinstance(bias, DTensor): bias = bias.to_local() + # TODO: Once torch.compile(fullgraph=True) is supported, this data-dependent seq_idx computation must be wrapped in a custom op. + if seq_ctx.seq_idx is None: + seq_idx = torch.cat( + [ + torch.full((s,), i, dtype=torch.int32, device=mixed_qkv.device) + for i, s in enumerate(seq_ctx.seq_lens_q) + ], + dim=0, + )[None] + seq_ctx.seq_idx = cast(torch.IntTensor, seq_idx) + else: + seq_idx = seq_ctx.seq_idx + mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, weight=weight, bias=bias, activation=self.activation, - seq_idx=seq_ctx.seq_idx, + seq_idx=seq_idx, ) mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split(