diff --git a/pyproject.toml b/pyproject.toml index 6fa81c006..c24d0c6d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,8 @@ all = [ "liger-kernel", "parametrize", "mathruler", - "pylatexenc" + "pylatexenc", + "flash-linear-attention" ] [tool.mypy] diff --git a/xtuner/v1/data_proto/messages/chat.py b/xtuner/v1/data_proto/messages/chat.py index 69a098d25..fcc4adc64 100644 --- a/xtuner/v1/data_proto/messages/chat.py +++ b/xtuner/v1/data_proto/messages/chat.py @@ -6,10 +6,9 @@ from pydantic import BaseModel, ConfigDict from transformers import PreTrainedTokenizer -from xtuner.utils import IGNORE_INDEX from xtuner.v1.data_proto.messages.base import BaseMessages from xtuner.v1.data_proto.templates import ChatTemplate, HybridChatTemplate -from xtuner.v1.utils import get_logger +from xtuner.v1.utils import IGNORE_INDEX, get_logger logger = get_logger() diff --git a/xtuner/v1/data_proto/sequence_context.py b/xtuner/v1/data_proto/sequence_context.py index fc91e4cb4..4c4aad06c 100644 --- a/xtuner/v1/data_proto/sequence_context.py +++ b/xtuner/v1/data_proto/sequence_context.py @@ -36,6 +36,7 @@ class SequenceContext: block_table: torch.Tensor | None device: str | torch.device # TODO: 这个地方有点乱,到处是 device position_ids: torch.LongTensor | None + seq_idx: torch.IntTensor | None # Qwen3VL image_grid_thw: torch.Tensor | None @@ -98,6 +99,7 @@ def __init__( self.inputs_embeds = inputs_embeds self.num_img_tokens = num_img_tokens self.rollout_routed_experts = rollout_routed_experts + self.seq_idx = None seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1] seq_lens_q = self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1] diff --git a/xtuner/v1/datasets/collator.py b/xtuner/v1/datasets/collator.py index b19d67f34..62b2916cb 100644 --- a/xtuner/v1/datasets/collator.py +++ b/xtuner/v1/datasets/collator.py @@ -247,6 +247,8 @@ def qwen3_vl_sft_collator( if len(position_ids_list) > 0: position_ids = torch.cat(position_ids_list, dim=-1) position_ids = position_ids[:, :, :-1] + if pack_to_max_length and pack_max_length - position_ids.shape[-1] > 0: + position_ids = pad_to_max_length(position_ids, 0, max_length=pack_max_length, dim=-1) num_img_tokens: list[int] = [] for data in instance: diff --git a/xtuner/v1/model/__init__.py b/xtuner/v1/model/__init__.py index 3f77cd9b8..b0744864f 100644 --- a/xtuner/v1/model/__init__.py +++ b/xtuner/v1/model/__init__.py @@ -11,6 +11,7 @@ InternVL3P5MoE30BA3Config, InternVLBaseConfig, ) +from .compose.qwen3_5 import Qwen3_5_VLMoE35BA3Config from .compose.qwen3_vl import ( Qwen3VLDense4BConfig, Qwen3VLDense8BConfig, @@ -98,4 +99,5 @@ def get_model_config_from_hf(model_path: Path): "TorchCompileOption", "DEFAULT_FLOAT8_CFG", "XTunerBaseModelConfig", + "Qwen3_5_VLMoE35BA3Config", ] diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index 9f1e0f995..c4a61fa35 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -39,7 +39,7 @@ WeightWithDynamicTilewiseFloat8CastTensor, ) from xtuner.v1.loss import BaseLossContext -from xtuner.v1.module.attention import MHAConfig, MLAConfig +from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig, MLAConfig from xtuner.v1.module.rope import RopeScalingConfig from xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory @@ -147,9 +147,11 @@ class TransformerConfig(XTunerBaseModelConfig): hidden_size: Annotated[int, Parameter(group="model")] intermediate_size: Annotated[int, 
Parameter(group="model")] rms_norm_eps: Annotated[float, Parameter(group="model")] + rms_norm_type: Annotated[Literal["default", "zero_centered"], Parameter(group="model")] = "default" rope_theta: Annotated[float, Parameter(group="model")] # required by transformers's build rope hidden_act: Annotated[str, Parameter(group="model")] # key defined in `transformers.activations.ACT2CLS` attention: MLAConfig | MHAConfig + linear_attention: Annotated[GatedDeltaNetConfig | None, Parameter(group="model")] = None mlp_bias: Annotated[bool, Parameter(group="model")] = False tie_word_embeddings: Annotated[bool, Parameter(group="model")] = False model_type: Annotated[str | None, Parameter(group="model")] = None # TODO: yehaochen maybe should be removed diff --git a/xtuner/v1/model/compose/qwen3_5/__init__.py b/xtuner/v1/model/compose/qwen3_5/__init__.py new file mode 100644 index 000000000..e9b332168 --- /dev/null +++ b/xtuner/v1/model/compose/qwen3_5/__init__.py @@ -0,0 +1,6 @@ +from .qwen3_5_config import Qwen3_5_VLMoE35BA3Config + + +__all__ = [ + "Qwen3_5_VLMoE35BA3Config", +] diff --git a/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py new file mode 100644 index 000000000..884cd48b1 --- /dev/null +++ b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py @@ -0,0 +1,33 @@ +from xtuner.v1.model.base import TransformerConfig +from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig +from xtuner.v1.utils import get_logger + +from ..qwen3_vl.qwen3_vl_config import Qwen3VLBaseConfig, Qwen3VLProjectorConfig, Qwen3VLVisionConfig + + +logger = get_logger() + + +class Qwen3_5_VisionConfig(Qwen3VLVisionConfig): + deepstack_visual_indexes: list[int] = [] + + +class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig): + deepstack_visual_indexes: list[int] = [] + + +class Qwen3_5_BaseConfig(Qwen3VLBaseConfig): + vision_config: Qwen3_5_VisionConfig + projector_config: Qwen3_5_ProjectorConfig + text_config: TransformerConfig + + image_token_id: int = 248056 + video_token_id: int = 248057 + vision_start_token_id: int = 248053 + vision_end_token_id: int = 248054 + + +class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig): + vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig() + projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig() + text_config: TransformerConfig = Qwen3_5_VLTextMoE35BA3BConfig() diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py index 865cbcfb1..06706800a 100644 --- a/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py @@ -182,8 +182,10 @@ def forward( if deepstack_visual_embeds is not None and len(deepstack_visual_embeds) == 0: assert seq_ctx.position_ids is not None - assert seq_ctx.position_ids.ndim == 2, f"position_ids must be 2-dim when deepstack_visual_embeds is None," \ - f" but got {seq_ctx.position_ids.ndim}" + assert seq_ctx.position_ids.ndim in (2, 3), ( + f"position_ids must be 2-dim or 3-dim when deepstack_visual_embeds is None," + f" but got {seq_ctx.position_ids.ndim}" + ) deepstack_visual_embeds = None visual_pos_masks = None diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index cde398b38..2cca9da86 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -29,7 +29,15 @@ TransformerConfig, ) from xtuner.v1.model.utils import checkpoint_wrapper -from xtuner.v1.module import LMHead, RMSNorm, 
RotaryEmbeddingProtocol, get_rope_embedding +from xtuner.v1.module import ( + GatedDeltaNetConfig, + LMHead, + MHAConfig, + MLAConfig, + RMSNorm, + RotaryEmbeddingProtocol, + get_rope_embedding, +) from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer from xtuner.v1.utils import ( get_device, @@ -53,7 +61,7 @@ class Dense(BaseModel): def __init__(self, config: TransformerConfig): super().__init__(config) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, type=config.rms_norm_type) self.lm_head = LMHead(config.hidden_size, config.vocab_size, bias=False) self.layers = self.build_layers(config) self.rotary_emb = self.build_rotary_embedding(config) @@ -116,14 +124,27 @@ def build_layers(self, config: TransformerConfig) -> nn.ModuleDict: # 让 layers 是一个 nn.ModuleDict 方便做 pipeline parallel 的参数切分, # 这样可以保证部分 layer 被切掉后,idx 保持不变 layers = nn.ModuleDict() + attention_config: GatedDeltaNetConfig | MLAConfig | MHAConfig | None = None for layer_idx in range(config.num_hidden_layers): + if config.layers_type[layer_idx] in ["full_attention", "sliding_attention"]: + attention_config = config.attention + elif config.layers_type[layer_idx] == "linear_attention": + attention_config = config.linear_attention + assert attention_config is not None, ( + "linear_attention config must be provided for linear_attention layer" + ) + else: + raise ValueError( + f"Unsupported layer type {config.layers_type[layer_idx]} at layer {layer_idx}. Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported." + ) + layers[str(layer_idx)] = DenseDecoderLayer( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, mlp_bias=config.mlp_bias, hidden_act=config.hidden_act, rms_norm_eps=config.rms_norm_eps, - attention_config=config.attention, + attention_config=attention_config, generate_config=config.generate_config, rope_scaling_cfg=config.rope_scaling_cfg, float8_cfg=config.float8_cfg, diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 962129e1d..4b96c8f6d 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -37,8 +37,11 @@ ) from xtuner.v1.model.utils import ModelForwardExtraLogInfo, checkpoint_wrapper, module_dict_repr from xtuner.v1.module import ( + GatedDeltaNetConfig, GreedyRouterConfig, LMHead, + MHAConfig, + MLAConfig, NoAuxRouter, NoAuxRouterConfig, RMSNorm, @@ -116,6 +119,7 @@ class MoEConfig(TransformerConfig): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") n_routed_experts: Annotated[int, Parameter(group="moe")] n_shared_experts: Annotated[int, Parameter(group="moe")] + with_shared_expert_gate: bool = False # enable when n_shared_experts > 0 num_experts_per_tok: Annotated[int, Parameter(group="moe")] first_k_dense_replace: Annotated[int, Parameter(group="moe")] = 0 hidden_factor: Annotated[float, Parameter(group="moe")] = 1.0 @@ -166,7 +170,7 @@ def __init__(self, config: MoEConfig): else: self.ep_mesh = None - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, type=config.rms_norm_type) self.lm_head = LMHead(config.hidden_size, config.vocab_size, bias=False) self.layers = self.build_layers(config) @@ -592,7 +596,20 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: # 让 layers 是一个 nn.ModuleDict 方便做 pipeline parallel 的参数切分, # 这样可以保证部分 layer 被切掉后,idx 保持不变 layers = nn.ModuleDict() + attention_config: 
GatedDeltaNetConfig | MLAConfig | MHAConfig | None = None for layer_idx in range(config.num_hidden_layers): + if config.layers_type[layer_idx] in ["full_attention", "sliding_attention"]: + attention_config = config.attention + elif config.layers_type[layer_idx] == "linear_attention": + attention_config = config.linear_attention + assert attention_config is not None, ( + "linear_attention config must be provided for linear_attention layer" + ) + else: + raise ValueError( + f"Unsupported layer type {config.layers_type[layer_idx]} at layer {layer_idx}. Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported." + ) + if layer_idx < config.first_k_dense_replace: layers[str(layer_idx)] = DenseDecoderLayer( hidden_size=config.hidden_size, @@ -600,7 +617,8 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: mlp_bias=config.mlp_bias, hidden_act=config.hidden_act, rms_norm_eps=config.rms_norm_eps, - attention_config=config.attention, + rms_norm_type=config.rms_norm_type, + attention_config=attention_config, layer_type=config.layers_type[layer_idx], rope_scaling_cfg=config.rope_scaling_cfg, generate_config=config.generate_config, @@ -617,12 +635,14 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict: moe_bias=config.moe_bias, hidden_act=config.hidden_act, rms_norm_eps=config.rms_norm_eps, + rms_norm_type=config.rms_norm_type, num_experts_per_tok=config.num_experts_per_tok, n_routed_experts=config.n_routed_experts, n_shared_experts=config.n_shared_experts, + with_shared_expert_gate=config.with_shared_expert_gate, hidden_factor=config.hidden_factor, layer_type=config.layers_type[layer_idx], - attention_config=config.attention, + attention_config=attention_config, rope_scaling_cfg=config.rope_scaling_cfg, generate_config=config.generate_config, router_config=config.router, diff --git a/xtuner/v1/model/moe/qwen3_5_text.py b/xtuner/v1/model/moe/qwen3_5_text.py new file mode 100644 index 000000000..3afc4e5dd --- /dev/null +++ b/xtuner/v1/model/moe/qwen3_5_text.py @@ -0,0 +1,207 @@ +import re +from typing import Literal + +import torch +from pydantic import computed_field +from typing_extensions import override + +from xtuner.v1.model.base import ( + DEFAULT_FLOAT8_CFG, + TorchCompileOption, +) +from xtuner.v1.model.moe.moe import BalancingLossConfig, MoEConfig, ZLossConfig +from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig +from xtuner.v1.module.rope import RopeScalingConfig +from xtuner.v1.module.router.greedy import GreedyRouterConfig + +from .qwen3vl_text import Qwen3VLTextMoE + + +MOE_NON_EP_COMPILE_CFG: dict[str, TorchCompileOption] = { + "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEBlock.forward": TorchCompileOption(fullgraph=True), + "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer.forward": TorchCompileOption(fullgraph=False), + "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer._pre_moe_forward": TorchCompileOption( + fullgraph=False + ), + "xtuner.v1.module.attention.mha.MultiHeadAttention.forward": TorchCompileOption(fullgraph=True), + # TODO: GatedDeltaNet does not currently support torch.compile(full_graph=True); support will be added in the future. 
+ "xtuner.v1.module.attention.gated_deltanet.GatedDeltaNet.forward": TorchCompileOption(fullgraph=False), + "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer._shared_experts_forward": TorchCompileOption( + fullgraph=True + ), + "xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer._post_moe_forward": TorchCompileOption( + fullgraph=True + ), + "xtuner.v1.module.decoder_layer.dense_decoder_layer.DenseDecoderLayer.forward": TorchCompileOption(fullgraph=True), + **DEFAULT_FLOAT8_CFG, +} + +MOE_EP_COMPILE_CFG = MOE_NON_EP_COMPILE_CFG.copy() +MOE_EP_COMPILE_CFG.pop("xtuner.v1.module.decoder_layer.moe_decoder_layer.MoEDecoderLayer.forward") + + +class Qwen3_5_VLTextMoE(Qwen3VLTextMoE): + def to_hf_key_list(self, key: str) -> list[str]: + if "layers" in key or "embed_tokens" in key: + key = "model.language_model." + key + + if "layers" in key: + key = re.sub(r"layers\.(\d+)\.(experts|gate|shared_experts|shared_expert_gate)", r"layers.\1.mlp.\2", key) + key = key.replace("shared_experts", "shared_expert") + + layer_idx = int(re.findall(r"layers\.(\d+)\.", key)[0]) + if self.config.layers_type[layer_idx] == "linear_attention": + key = key.replace("self_attn", "linear_attn") + + if "fused_w1w3.weight" in key: + key = key.replace("fused_w1w3.weight", "gate_up_proj") + elif "fused_w2.weight" in key: + key = key.replace("fused_w2.weight", "down_proj") + if "fused_w1w3.bias" in key: + key = key.replace("fused_w1w3.bias", "gate_up_proj_bias") + elif "fused_w2.bias" in key: + key = key.replace("fused_w2.bias", "down_proj_bias") + + if key.startswith("norm."): + return [key.replace("norm.", "model.language_model.norm.")] + elif key.startswith("rotary_emb."): + # FoPE has model.rotary_emb.sin_coef and model.rotary_emb.cos_coef in the safetensors + return [key.replace("rotary_emb.", "model.language_model.rotary_emb.")] + else: + return [key] + + def safetensors_to_params( + self, + safetensors: list[torch.Tensor], + local_tensor: torch.Tensor, + param_name: str, + start: int | None, + end: int | None, + dim: int | None, + ): + if len(safetensors) > 1: + assert dim is not None, "Internal Error dim must not be None when len(safetensors) > 1" + loaded_tensor = torch.cat(safetensors, dim=dim) + else: + loaded_tensor = safetensors[0] + + if "fused_w1w3.weight" in param_name: + # hf: num_experts, 2 * expert_dim, hidden_size + # xtuner: num_experts * 2 * expert_dim, hidden_size + # num_experts * 2 * expert_dim, hidden_size + loaded_tensor = loaded_tensor.flatten(0, 1) + + elif "fused_w2.weight" in param_name: + # hf: num_experts, hidden_size, expert_dim + # xtuner: num_experts * hidden_size, expert_dim + loaded_tensor = loaded_tensor.flatten(0, 1) + + if start is not None and end is not None: + start = min(start, loaded_tensor.shape[self.FSDP_SHARD_DIM]) + end = min(end, loaded_tensor.shape[self.FSDP_SHARD_DIM]) + loaded_tensor_slice = loaded_tensor.index_select( + dim=self.FSDP_SHARD_DIM, index=torch.arange(start, end, dtype=torch.int64, device=loaded_tensor.device) + ) + non_pad_len = end - start + local_tensor[:non_pad_len].copy_(loaded_tensor_slice) + + if non_pad_len < local_tensor.shape[self.FSDP_SHARD_DIM]: + assert self.config.float8_cfg is not None + local_tensor[non_pad_len:].copy_(0.0) # type: ignore # padded part must be set to 0 + else: + local_tensor.copy_(loaded_tensor) + + def param_to_safetensor( + self, + safetensor: torch.Tensor, + hf_param_name: str, + ): + assert isinstance(hf_param_name, str) + if "gate_up_proj" in hf_param_name: + # xtuner: num_experts * 2 * 
expert_dim, hidden_size + # hf: num_experts, 2 * expert_dim, hidden_size + num_experts = self.config.n_routed_experts + hidden_size = safetensor.size(1) + safetensor = safetensor.reshape( + num_experts, -1, hidden_size + ).contiguous() # num_experts, 2 * expert_dim, hidden_size + elif "down_proj" in hf_param_name and "shared_expert" not in hf_param_name: + # xtuner: num_experts * hidden_size, expert_dim + # hf: num_experts, hidden_size, expert_dim + num_experts = self.config.n_routed_experts + expert_dim = safetensor.size(1) + safetensor = safetensor.reshape(num_experts, -1, expert_dim).contiguous() + return safetensor + + @property + @override + def default_compile_cfg(self) -> dict[str, TorchCompileOption]: + if self.config.ep_size > 1: + return MOE_EP_COMPILE_CFG + else: + return MOE_NON_EP_COMPILE_CFG + + +class Qwen3_5_VLTextMoEConfig(MoEConfig): + with_shared_expert_gate: bool = True + rms_norm_type: Literal["default", "zero_centered"] = "zero_centered" + + @computed_field + def layers_type(self) -> list[Literal["full_attention", "linear_attention"]]: + return ["linear_attention" if bool((i + 1) % 4) else "full_attention" for i in range(self.num_hidden_layers)] + + def build(self) -> Qwen3_5_VLTextMoE: + return Qwen3_5_VLTextMoE(self) + + +class Qwen3_5_VLTextMoE35BA3BConfig(Qwen3_5_VLTextMoEConfig): + vocab_size: int = 248320 + max_position_embeddings: int = 262144 + # Qwen3 Model(dense and moe)'s pad_token_id is not set, so we need to set it to None. + # If this pad_token_id is not set, the embedding module will not act specially for pad token. + # Note: Qwen3 Model's pad_token_id may be different from Qwen tokenizer's pad_token_id. + pad_token_id: int | None = None + eos_token_id: int = 248044 + num_hidden_layers: int = 40 + max_window_layers: int = 40 + hidden_size: int = 2048 + intermediate_size: int = 0 # not used + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000000.0 + hidden_act: str = "silu" + attention: MHAConfig = MHAConfig( + with_gate=True, + num_attention_heads=16, + num_key_value_heads=2, + head_dim=256, + qk_norm=True, + rms_norm_eps=1e-6, + rms_norm_type="zero_centered", + sliding_window=1024, + ) + linear_attention: GatedDeltaNetConfig = GatedDeltaNetConfig( + num_value_heads=32, + num_key_heads=16, + key_head_dim=128, + value_head_dim=128, + conv_kernel_dim=4, + hidden_act="silu", + rms_norm_eps=1e-6, + ) + tie_word_embeddings: bool = False + n_routed_experts: int = 256 + n_shared_experts: int = 1 + num_experts_per_tok: int = 8 + first_k_dense_replace: int = 0 + hidden_factor: float = 1.0 + moe_intermediate_size: int = 512 + router: GreedyRouterConfig = GreedyRouterConfig( + scoring_func="softmax", + norm_topk_prob=True, + router_scaling_factor=1.0, + ) + rope_scaling_cfg: RopeScalingConfig = RopeScalingConfig( + type="qwen3_vl", mrope_section=[11, 11, 10], partial_rotary_factor=0.25 + ) + balancing_loss_cfg: BalancingLossConfig | None = BalancingLossConfig() + z_loss_cfg: ZLossConfig | None = None diff --git a/xtuner/v1/module/__init__.py b/xtuner/v1/module/__init__.py index c84f7d500..2f1d59e67 100644 --- a/xtuner/v1/module/__init__.py +++ b/xtuner/v1/module/__init__.py @@ -1,5 +1,7 @@ from .attention import ( AttnOutputs, + GatedDeltaNet, + GatedDeltaNetConfig, MHAConfig, MLAConfig, MultiHeadAttention, @@ -26,6 +28,8 @@ "MultiLatentAttention", "MHAConfig", "MLAConfig", + "GatedDeltaNetConfig", + "GatedDeltaNet", "AttnOutputs", "RopeScalingConfig", "RotaryEmbedding", diff --git a/xtuner/v1/module/attention/__init__.py 
b/xtuner/v1/module/attention/__init__.py index 10fcb26df..d4594014a 100644 --- a/xtuner/v1/module/attention/__init__.py +++ b/xtuner/v1/module/attention/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .attn_outputs import AttnOutputs +from .gated_deltanet import GatedDeltaNet, GatedDeltaNetConfig from .mha import MHAConfig, MultiHeadAttention from .mla import MLAConfig, MultiLatentAttention @@ -10,4 +11,6 @@ "MHAConfig", "MLAConfig", "AttnOutputs", + "GatedDeltaNet", + "GatedDeltaNetConfig", ] diff --git a/xtuner/v1/module/attention/gated_deltanet.py b/xtuner/v1/module/attention/gated_deltanet.py new file mode 100644 index 000000000..c88dcc9b8 --- /dev/null +++ b/xtuner/v1/module/attention/gated_deltanet.py @@ -0,0 +1,276 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Annotated, cast + +import torch +import torch.nn.functional as F +from cyclopts import Parameter +from pydantic import BaseModel, ConfigDict +from torch import nn +from torch.distributed.tensor import DTensor +from typing_extensions import overload + +from xtuner.v1.data_proto import SequenceContext +from xtuner.v1.float8.config import Float8Config +from xtuner.v1.utils import get_logger + +from ..linear import build_linear +from .attn_outputs import AttnOutputs + + +try: + from fla.modules import FusedRMSNormGated as FLA_FusedRMSNormGated + from fla.modules.fused_norm_gate import rms_norm_gated + from fla.ops.gated_delta_rule import chunk_gated_delta_rule + + class FusedRMSNormGated(FLA_FusedRMSNormGated): + def forward( + self, + x: torch.Tensor, + g: torch.Tensor, + residual: torch.Tensor | None = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + ) -> torch.Tensor: + weight = self.weight + if isinstance(weight, DTensor): + weight = weight.to_local() + + return rms_norm_gated( + x, + g, + weight, + self.bias, + self.activation, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + +except ImportError: + FusedRMSNormGated = None # type: ignore + chunk_gated_delta_rule = None + +try: + from causal_conv1d import causal_conv1d_fn +except ImportError: + causal_conv1d_fn = None + +logger = get_logger() + + +class GatedDeltaNetConfig(BaseModel): + model_config = ConfigDict(title="Base attention config for xtuner", extra="forbid") + num_value_heads: Annotated[int, Parameter(group="attention")] + num_key_heads: Annotated[int, Parameter(group="attention")] + key_head_dim: Annotated[int, Parameter(group="attention")] + value_head_dim: Annotated[int, Parameter(group="attention")] + conv_kernel_dim: Annotated[int, Parameter(group="attention")] + hidden_act: Annotated[str, Parameter(group="model")] # key defined in `transformers.activations.ACT2CLS` + rms_norm_eps: Annotated[float, Parameter(group="attention")] + + def build( + self, + hidden_size: int, + float8_cfg: Float8Config | None = None, + **kwargs, + ) -> "GatedDeltaNet": + return GatedDeltaNet( + **self.model_dump(), + hidden_size=hidden_size, + float8_cfg=float8_cfg, + ) + + +class GatedDeltaNet(nn.Module): + def __init__( + self, + hidden_size: int, + num_value_heads: int, + num_key_heads: int, + key_head_dim: int, + value_head_dim: int, + conv_kernel_dim: int, + hidden_act: str, + rms_norm_eps: float, + layer_idx: int = 0, + float8_cfg: Float8Config | None = None, + ) -> None: + super().__init__() + self.name = f"layers.{layer_idx}.gate_deltanet" + self.float8_cfg = float8_cfg + + self.hidden_size = hidden_size + self.num_v_heads = num_value_heads + 
self.num_k_heads = num_key_heads + self.head_k_dim = key_head_dim + self.head_v_dim = value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = conv_kernel_dim + self.layer_idx = layer_idx + self.activation = hidden_act + self.rms_norm_eps = rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + assert causal_conv1d_fn is not None, ( + "causal_conv1d_fn is not available. Please install causal-conv1d to use GatedDeltaNet from `https://github.com/Dao-AILab/causal-conv1d`." + ) + self.causal_conv1d_fn = causal_conv1d_fn + assert chunk_gated_delta_rule is not None, ( + "chunk_gated_delta_rule is not available. Please install fla to use GatedDeltaNet via `pip install flash-linear-attention`." + ) + self.chunk_gated_delta_rule = chunk_gated_delta_rule + assert FusedRMSNormGated is not None, ( + "FusedRMSNormGated is not available. Please install fla to use GatedDeltaNet via `pip install flash-linear-attention`." + ) + self.norm = FusedRMSNormGated(self.head_v_dim, eps=self.rms_norm_eps, activation=self.activation) + + self.out_proj = build_linear( + self.value_dim, + self.hidden_size, + bias=False, + float8_cfg=self.float8_cfg, + ) + + self.in_proj_qkv = build_linear( + self.hidden_size, + self.key_dim * 2 + self.value_dim, + bias=False, + float8_cfg=self.float8_cfg, + ) + self.in_proj_z = build_linear( + self.hidden_size, + self.value_dim, + bias=False, + float8_cfg=self.float8_cfg, + ) + self.in_proj_b = build_linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = build_linear(self.hidden_size, self.num_v_heads, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + seq_ctx: SequenceContext, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # not used + ) -> AttnOutputs: + batch_size, seq_len, _ = hidden_states.shape + assert batch_size == 1, "Only batch size of 1 is supported for now in GatedDeltaNet" + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + weight = self.conv1d.weight.squeeze(1) + bias = self.conv1d.bias + if isinstance(weight, DTensor): + weight = weight.to_local() + if bias is not None and isinstance(bias, DTensor): + bias = bias.to_local() + + # TODO: If fullgraph mode is supported in the future, this needs to be converted to a custom_op + if seq_ctx.seq_idx is None: + seq_idx = torch.cat( + [ + torch.full((s,), i, dtype=torch.int32, device=mixed_qkv.device) + for i, s in enumerate(seq_ctx.seq_lens_q) + ], + dim=0, + )[None] + seq_ctx.seq_idx = cast(torch.IntTensor, seq_idx) + else: + seq_idx = seq_ctx.seq_idx + + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=weight, + bias=bias, + activation=self.activation, + seq_idx=seq_idx, + ) + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim,
self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + A_log = self.A_log + dt_bias = self.dt_bias + if isinstance(A_log, DTensor): + A_log = A_log.to_local() + if isinstance(dt_bias, DTensor): + dt_bias = dt_bias.to_local() + + g = -A_log.float().exp() * F.softplus(a.float() + dt_bias) + + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + core_attn_out, _ = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + cu_seqlens=seq_ctx.cu_seq_lens_q, + ) + + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + attn_outputs: AttnOutputs = { + "projected_output": output, + } + return attn_outputs + + @overload # type: ignore + def __call__( # type: ignore + self, + hidden_states: torch.Tensor, + seq_ctx: SequenceContext, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> AttnOutputs: ... + + __call__ = nn.Module.__call__ diff --git a/xtuner/v1/module/attention/mha.py b/xtuner/v1/module/attention/mha.py index 9dcd1f942..8e6121019 100644 --- a/xtuner/v1/module/attention/mha.py +++ b/xtuner/v1/module/attention/mha.py @@ -38,9 +38,11 @@ class MHAConfig(BaseModel): qkv_bias: Annotated[bool, Parameter(group="attention")] = False qk_norm: bool = False rms_norm_eps: float = 1e-06 + rms_norm_type: Literal["default", "zero_centered"] = "default" o_bias: Annotated[bool, Parameter(group="attention")] = False sliding_window: Annotated[int | None, Parameter(group="attention")] = -1 with_sink: Annotated[bool, Parameter(group="attention")] = False + with_gate: Annotated[bool, Parameter(group="attention")] = False attn_impl: Literal["flash_attention", "flex_attention", "eager_attention"] = "flash_attention" def model_post_init(self, _): @@ -122,8 +124,10 @@ def __init__( qkv_bias: bool = False, qk_norm: bool = False, rms_norm_eps: float = 1e-6, + rms_norm_type: Literal["default", "zero_centered"] = "default", o_bias: bool = False, with_sink: bool = False, + with_gate: bool = False, attn_impl: Literal["flash_attention", "flex_attention", "eager_attention"] = "flash_attention", rope_scaling_cfg: RopeScalingConfig | None = None, float8_cfg: Float8Config | None = None, @@ -145,14 +149,18 @@ def __init__( self.qkv_bias = qkv_bias self.qk_norm = qk_norm self.rms_norm_eps = rms_norm_eps + self.rms_norm_type = rms_norm_type self.o_bias = o_bias self.generate_config = generate_config self.float8_cfg = float8_cfg self.layer_idx = layer_idx + self.with_gate = with_gate self.q_proj = build_linear( self.hidden_size, - self.num_attention_heads * self.head_dim, + self.num_attention_heads * self.head_dim + if not with_gate + else self.num_attention_heads * self.head_dim * 2, bias=self.qkv_bias, float8_cfg=self.float8_cfg, ) @@ -176,8 +184,8 @@ def __init__( ) if self.qk_norm: - self.q_norm = RMSNorm(self.head_dim, eps=self.rms_norm_eps) - 
self.k_norm = RMSNorm(self.head_dim, eps=self.rms_norm_eps) + self.q_norm = RMSNorm(self.head_dim, eps=self.rms_norm_eps, type=self.rms_norm_type) + self.k_norm = RMSNorm(self.head_dim, eps=self.rms_norm_eps, type=self.rms_norm_type) self.with_sink = with_sink if self.with_sink: @@ -188,7 +196,10 @@ def __init__( self.window_size = (sliding_window, sliding_window) fope_sep_head = rope_scaling_cfg.fope_sep_head if rope_scaling_cfg is not None else None - self.apply_rotary_emb = get_apply_rotary_emb(fope_sep_head) # type: ignore + enable_partial_rotary = ( + rope_scaling_cfg.partial_rotary_factor != 1.0 if rope_scaling_cfg is not None else False + ) + self.apply_rotary_emb = get_apply_rotary_emb(fope_sep_head, enable_partial_rotary=enable_partial_rotary) # type: ignore self.attn_impl_func: Callable[..., AttnOpOutputs] = attn_impl_mapping[attn_impl] # type: ignore[assignment] @@ -328,7 +339,14 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states).view(hidden_shape) # [b, seq, n_head, head_dim] + if self.with_gate: + query_states, gate = torch.chunk( + self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 + ) + gate = gate.reshape(*input_shape, -1) + else: + gate = None + query_states = self.q_proj(hidden_states).view(hidden_shape) # [b, seq, n_head, head_dim] key_states = self.k_proj(hidden_states).view(hidden_shape) value_states = self.v_proj(hidden_states).view(hidden_shape) @@ -409,6 +427,10 @@ def forward( ) raw_output = raw_output.reshape(*input_shape, -1).contiguous() + if self.with_gate: + assert gate is not None + raw_output = raw_output * torch.sigmoid(gate) + projected_output = self.o_proj(raw_output) attn_outputs: AttnOutputs = { "projected_output": projected_output, diff --git a/xtuner/v1/module/decoder_layer/dense_decoder_layer.py b/xtuner/v1/module/decoder_layer/dense_decoder_layer.py index 53a3f90b3..e01f89ada 100644 --- a/xtuner/v1/module/decoder_layer/dense_decoder_layer.py +++ b/xtuner/v1/module/decoder_layer/dense_decoder_layer.py @@ -6,7 +6,7 @@ from xtuner.v1.config import GenerateConfig from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.config import Float8Config -from xtuner.v1.module import AttnOutputs, MHAConfig, MLAConfig, RMSNorm +from xtuner.v1.module import AttnOutputs, GatedDeltaNetConfig, MHAConfig, MLAConfig, RMSNorm from xtuner.v1.module.rope import RopeScalingConfig from xtuner.v1.ops.act_fn import get_act_fn from xtuner.v1.utils import ForwardState @@ -44,7 +44,8 @@ def __init__( mlp_bias: bool = False, hidden_act: str, rms_norm_eps: float = 1e-6, - attention_config: MLAConfig | MHAConfig, + rms_norm_type: Literal["default", "zero_centered"] = "default", + attention_config: MLAConfig | MHAConfig | GatedDeltaNetConfig, rope_scaling_cfg: RopeScalingConfig | None = None, generate_config: GenerateConfig | None = None, float8_cfg: Float8Config | None = None, @@ -68,8 +69,8 @@ def __init__( hidden_act=hidden_act, float8_cfg=float8_cfg, ) - self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) def forward( self, @@ -110,7 +111,7 @@ def prefilling( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states = self.self_attn.prefilling( + hidden_states 
= self.self_attn.prefilling( # type: ignore hidden_states=hidden_states, position_embeddings=position_embeddings, seq_ctx=seq_ctx, @@ -156,7 +157,7 @@ def decoding( def build_kv_cache( self, max_batch_size: int | None = None, max_length: int | None = None, block_size: int | None = None ) -> tuple[torch.Tensor, torch.Tensor]: - return self.self_attn.build_kv_cache( + return self.self_attn.build_kv_cache( # type: ignore max_batch_size=max_batch_size, max_length=max_length, block_size=block_size, diff --git a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py index 9303d66d4..4e72751da 100644 --- a/xtuner/v1/module/decoder_layer/moe_decoder_layer.py +++ b/xtuner/v1/module/decoder_layer/moe_decoder_layer.py @@ -14,9 +14,13 @@ from xtuner.v1.float8 import Float8Config from xtuner.v1.module import ( AttnOutputs, + GatedDeltaNet, + GatedDeltaNetConfig, GreedyRouterConfig, MHAConfig, MLAConfig, + MultiHeadAttention, + MultiLatentAttention, NoAuxRouterConfig, RMSNorm, RouterResults, @@ -128,9 +132,15 @@ def forward( bias = bias.float() logits = F.linear(hidden_states.float(), weight.float(), bias) - return self.router(logits, rollout_routed_experts) + # Debug for aligning with hf implementation. + # logits = F.linear(hidden_states, weight, bias) + # gate = self.router(logits, rollout_routed_experts) + # gate['topk_weights'] = gate['topk_weights'].float() + # return gate + return self.router(logits, rollout_routed_experts) + class MoEBlock(nn.Module): def __init__( @@ -190,11 +199,13 @@ def __init__( moe_bias: bool = False, hidden_act: str, rms_norm_eps: float = 1e-6, + rms_norm_type: Literal["default", "zero_centered"] = "default", num_experts_per_tok: int, n_routed_experts: int, n_shared_experts: int, + with_shared_expert_gate: bool = False, hidden_factor: float = 1.0, - attention_config: MHAConfig | MLAConfig, + attention_config: MHAConfig | MLAConfig | GatedDeltaNetConfig, rope_scaling_cfg: RopeScalingConfig | None = None, layer_type: Literal["full_attention", "sliding_attention"] | None = None, generate_config: GenerateConfig | None = None, @@ -212,7 +223,7 @@ def __init__( self.n_shared_experts = n_shared_experts self.hidden_factor = hidden_factor - self.self_attn = attention_config.build( + self.self_attn: MultiHeadAttention | MultiLatentAttention | GatedDeltaNet = attention_config.build( hidden_size=hidden_size, layer_idx=layer_idx, generate_config=generate_config, @@ -220,10 +231,13 @@ def __init__( layer_type=layer_type, float8_cfg=float8_cfg, ) - self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.shared_experts: MoEMLP | None + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) self.layer_idx = layer_idx + self.with_shared_expert_gate = with_shared_expert_gate + self.shared_expert_gate: nn.Module | None + self.shared_experts: MoEMLP | None + if n_shared_experts > 0: self.shared_experts = MoEMLP( hidden_size=hidden_size, @@ -233,10 +247,13 @@ def __init__( mlp_bias=mlp_bias, float8_cfg=float8_cfg, ) + if with_shared_expert_gate: + self.shared_expert_gate = build_linear(hidden_size, 1, bias=False) else: self.shared_experts = None + self.shared_expert_gate = None - self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps, type=rms_norm_type) self.gate = MoEGate( hidden_size=hidden_size, @@ -307,6 +324,44 @@ def forward( position_embeddings_list=position_embeddings, ) + def _hf_expert_forward_for_debug(self, hidden_states: torch.Tensor, router_results:
RouterResults, origin_shape): + # xtuner: num_experts * 2 * expert_dim, hidden_size + # hf: num_experts, 2 * expert_dim, hidden_size + origin_gate_up_proj = self.experts.fused_w1w3.weight + gate_up_proj = origin_gate_up_proj.view( + self.n_routed_experts, 2 * self.experts.intermediate_size, self.hidden_size + ) + + # xtuner: num_experts * hidden_size, expert_dim + # hf: num_experts, hidden_size, expert_dim + origin_down_proj = self.experts.fused_w2.weight + down_proj = origin_down_proj.view(self.n_routed_experts, self.hidden_size, self.experts.intermediate_size) + + from transformers.activations import ACT2FN + + act_fn = ACT2FN["silu"] + + hidden_states_reshaped = hidden_states.view(-1, hidden_states.size(-1)) + combined_hidden_states = torch.zeros_like(hidden_states_reshaped) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_results["topk_ids"], num_classes=self.n_routed_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.n_routed_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states_reshaped[token_idx] + gate, up = nn.functional.linear(current_state, gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, down_proj[expert_idx]) + current_hidden_states = current_hidden_states * router_results["topk_weights"][token_idx, top_k_pos, None] + combined_hidden_states.index_add_(0, token_idx, current_hidden_states.to(combined_hidden_states.dtype)) + + combined_hidden_states = combined_hidden_states.view(*origin_shape) + return combined_hidden_states + def _forward( self, hidden_states: torch.Tensor, @@ -381,6 +436,10 @@ def _forward( ) combined_hidden_states = post_combined["hidden_states"] combined_hidden_states = combined_hidden_states.view(*origin_shape) + + # debug for aligning with hf implementation. 
+ # combined_hidden_states = self._hf_expert_forward_for_debug(hidden_states, router_results, origin_shape) + # ProberList.after_combine(self.layer_idx, combined_hidden_states) if self.n_shared_experts > 0: @@ -552,7 +611,7 @@ def _pre_moe_forward( hidden_states = attn_outputs["projected_output"] elif state == ForwardState.PREFILLING: assert past_key_values is not None, "past_key_values should be provided in pre-filling state" - hidden_states = self.self_attn.prefilling( + hidden_states = self.self_attn.prefilling( # type: ignore hidden_states=hidden_states, position_embeddings=position_embeddings, seq_ctx=seq_ctx, @@ -560,7 +619,7 @@ def _pre_moe_forward( ) elif state == ForwardState.DECODING: assert past_key_values is not None, "past_key_values should be provided in decoding state" - hidden_states = self.self_attn.decoding( + hidden_states = self.self_attn.decoding( # type: ignore hidden_states=hidden_states, position_embeddings=position_embeddings, seq_ctx=seq_ctx, @@ -585,6 +644,13 @@ def _shared_experts_forward( ) -> torch.Tensor: assert self.shared_experts is not None, "Shared experts should be initialized when n_shared_experts > 0" shared_experts_out = self.shared_experts(hidden_states) + + if self.with_shared_expert_gate: + assert self.shared_expert_gate is not None, ( + "Shared expert gate should be initialized when with_shared_expert_gate is True" + ) + shared_experts_out = torch.sigmoid(self.shared_expert_gate(hidden_states)) * shared_experts_out + return shared_experts_out def _post_moe_forward( @@ -601,7 +667,7 @@ def _post_moe_forward( def build_kv_cache( self, max_batch_size: int | None = None, max_length: int | None = None, block_size: int | None = None ) -> tuple[torch.Tensor, torch.Tensor]: - return self.self_attn.build_kv_cache( + return self.self_attn.build_kv_cache( # type: ignore max_batch_size=max_batch_size, max_length=max_length, block_size=block_size, diff --git a/xtuner/v1/module/rms_norm/rms_norm.py b/xtuner/v1/module/rms_norm/rms_norm.py index 08594382d..11ccfb1d9 100644 --- a/xtuner/v1/module/rms_norm/rms_norm.py +++ b/xtuner/v1/module/rms_norm/rms_norm.py @@ -1,19 +1,29 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Literal + import torch from torch import nn from torch.distributed.tensor import DTensor -from xtuner.v1.ops import rms_norm +from xtuner.v1.ops import rms_norm, zero_centered_rms_norm class RMSNorm(nn.Module): weight: torch.Tensor - def __init__(self, hidden_size: int, eps: float = 1e-6): + def __init__(self, hidden_size: int, eps: float = 1e-6, type: Literal["default", "zero_centered"] = "default"): """RMSNorm is equivalent to T5LayerNorm.""" super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + self._type = type + + if type == "default": + self.rms_norm_fn = rms_norm + elif type == "zero_centered": + self.rms_norm_fn = zero_centered_rms_norm + else: + raise ValueError(f"Unsupported RMSNorm type: {type}") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if isinstance(self.weight, DTensor): @@ -28,10 +38,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # return (weight * hidden_states).to(input_dtype) # gpt_oss # return weight * hidden_states.to(input_dtype) # Llama - return rms_norm(hidden_states, weight, epsilon=self.variance_epsilon) # type: ignore[operator] + return self.rms_norm_fn(hidden_states, weight, epsilon=self.variance_epsilon) # type: ignore[operator] def init_weights(self): self.weight.data.fill_(1.0) def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + return f"{tuple(self.weight.shape)}, type={self._type}, eps={self.variance_epsilon}" diff --git a/xtuner/v1/module/rope/rope.py b/xtuner/v1/module/rope/rope.py index a8594d437..0b952b949 100644 --- a/xtuner/v1/module/rope/rope.py +++ b/xtuner/v1/module/rope/rope.py @@ -1,4 +1,4 @@ -from typing import Literal, Protocol, cast +from typing import Callable, Literal, Optional, Protocol, cast import torch import torch.nn as nn @@ -25,6 +25,7 @@ class RopeScalingConfig(BaseModel): # For Qwen3VL mrope_section: list[int] | None = None # e.g. [24, 20, 20] + partial_rotary_factor: float = 1.0 factor: float | None = None beta_fast: float | None = None @@ -60,6 +61,26 @@ def __call__(self, x: torch.Tensor, position_ids: torch.LongTensor) -> tuple[tor def to(self, device: torch.device) -> Self: ... +def compute_default_rope_parameters( + config, + device: Optional["torch.device"] = None, +) -> tuple["torch.Tensor", float]: + base = config.rope_theta + if config.rope_scaling_cfg is not None: + rope_scaling_cfg: RopeScalingConfig = config.rope_scaling_cfg + partial_rotary_factor = getattr(rope_scaling_cfg, "partial_rotary_factor", 1.0) + else: + partial_rotary_factor = 1.0 + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + return inv_freq, attention_factor + + class RotaryEmbedding(nn.Module): inv_freq: torch.Tensor @@ -81,7 +102,11 @@ def __init__(self, config, device=None): f"Unsupported rope_type: {self.rope_type}. Supported types are: 'default', 'linear', 'yarn', 'llama3'." ) - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + # The implementation of RoPE has been refactored in Transformers V5, and + # the following approach is used for compatibility. 
+ self.rope_init_fn: Callable = compute_default_rope_parameters + if self.rope_type != "default": + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq: torch.Tensor inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) @@ -283,7 +308,12 @@ def __init__(self, config, device=None): self.original_max_seq_len = config.max_position_embeddings self.rope_type = "default" self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + # The implementation of RoPE has been refactored in Transformers V5, and + # the following approach is used for compatibility. + self.rope_init_fn: Callable = compute_default_rope_parameters + if self.rope_type != "default": + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq: torch.Tensor inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) diff --git a/xtuner/v1/ops/__init__.py b/xtuner/v1/ops/__init__.py index d7f043ae8..4a8d6907e 100644 --- a/xtuner/v1/ops/__init__.py +++ b/xtuner/v1/ops/__init__.py @@ -7,7 +7,7 @@ from .attn_imp import AttnOpOutputs, attn_impl_mapping from .flash_attn import flash_attn_varlen_func from .moe import group_gemm, permute, unpermute -from .rms_norm import rms_norm +from .rms_norm import rms_norm, zero_centered_rms_norm from .rotary_emb import get_apply_rotary_emb from .tensor_parallel import attn_column_parallel, attn_row_parallel diff --git a/xtuner/v1/ops/rms_norm/__init__.py b/xtuner/v1/ops/rms_norm/__init__.py index 5c8431b96..7da199d1d 100644 --- a/xtuner/v1/ops/rms_norm/__init__.py +++ b/xtuner/v1/ops/rms_norm/__init__.py @@ -1,5 +1,4 @@ import os -from functools import partial import torch @@ -24,6 +23,18 @@ def _triton_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> t return rms_norm_fn(x, weight, bias=None, eps=epsilon) +def native_zero_centered_rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: + # TODO: is native_rms_norm ? 
+ def _norm(x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + epsilon) + + output = _norm(x.float()) + # Llama does x.to(float16) * w whilst Qwen3_5Moe is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + weight.float()) + return output.type_as(x) + + def get_rms_norm_fn() -> RMSNormProtocol: from xtuner.v1.utils import get_device @@ -40,4 +51,21 @@ def get_rms_norm_fn() -> RMSNormProtocol: raise NotImplementedError(f"RMSNorm is not implemented on {device}") +def get_zero_centered_rms_norm_fn() -> RMSNormProtocol: + from xtuner.v1.utils import get_device + + device = get_device() + if device in ["cpu", "cuda"]: + # TODO: control triton rmsnorm by model config rather than env var + if os.getenv("XTUNER_USE_NATIVE_RMSNORM", "1") == "0" and device == "cuda": + raise NotImplementedError("Zero-centered RMSNorm is not implemented in triton") + else: + return native_zero_centered_rms_norm + elif device == "npu": + raise NotImplementedError("Zero-centered RMSNorm is not implemented on NPU") + else: + raise NotImplementedError(f"RMSNorm is not implemented on {device}") + + rms_norm = get_rms_norm_fn() +zero_centered_rms_norm = get_zero_centered_rms_norm_fn() diff --git a/xtuner/v1/ops/rotary_emb.py b/xtuner/v1/ops/rotary_emb.py index e4508ca06..ebca8c184 100644 --- a/xtuner/v1/ops/rotary_emb.py +++ b/xtuner/v1/ops/rotary_emb.py @@ -49,6 +49,34 @@ def apply_rotary_pos_emb_cuda( return q_embed, k_embed +# Note: Although this function is compatible with apply_rotary_pos_emb_cuda, +# it is still recommended to separate them into two for efficiency considerations. +def apply_rotary_pos_emb_cuda_for_partial_rotary( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor | None = None, + unsqueeze_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + def apply_rotary_pos_emb_sep_cuda( q: torch.Tensor, k: torch.Tensor, @@ -120,16 +148,23 @@ def __call__( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -def get_apply_rotary_emb(fope_sep_head: bool | None = False) -> ApplyRotaryEmbProtocol: +def get_apply_rotary_emb( + fope_sep_head: bool | None = False, enable_partial_rotary: bool = False +) -> ApplyRotaryEmbProtocol: from xtuner.v1.utils.device import get_device device = get_device() if device == "npu": assert fope_sep_head is None or not fope_sep_head, "FoPE with sep head is not supported on NPU yet." + assert not enable_partial_rotary, "Partial rotary is not supported on NPU yet." return apply_rotary_pos_emb_npu else: if fope_sep_head: logger.debug("Using FoPE with fope_sep_head") + assert not enable_partial_rotary, "Partial rotary is not supported when using FoPE with sep head." return apply_rotary_pos_emb_sep_cuda else: - return apply_rotary_pos_emb_cuda + if enable_partial_rotary: + return apply_rotary_pos_emb_cuda_for_partial_rotary + else: + return apply_rotary_pos_emb_cuda
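Note (illustration, not part of the patch): the zero-centered RMSNorm added in xtuner/v1/ops/rms_norm/__init__.py scales the normalized activations by (1 + weight) rather than weight, so a checkpoint that stores norm weights near zero behaves like a default RMSNorm whose weights are near one. A minimal standalone sketch of that relationship follows; rms_norm_ref and zero_centered_rms_norm_ref are reference implementations written only for this note (they mirror native_zero_centered_rms_norm above and are not the library's kernels).

import torch


def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Default convention: y = normalize(x) * w, computed in float32 and cast back.
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + epsilon)
    return (normed * weight.float()).type_as(x)


def zero_centered_rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Zero-centered convention: y = normalize(x) * (1 + w), as in native_zero_centered_rms_norm.
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + epsilon)
    return (normed * (1.0 + weight.float())).type_as(x)


x = torch.randn(2, 8, dtype=torch.bfloat16)
w_zero_centered = torch.zeros(8)  # zero-centered checkpoints store weights near 0
w_default = torch.ones(8)  # the default convention stores weights near 1
# The two variants agree whenever w_zero_centered == w_default - 1.
assert torch.equal(
    zero_centered_rms_norm_ref(x, w_zero_centered, 1e-6),
    rms_norm_ref(x, w_default, 1e-6),
)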