Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ all = [
"liger-kernel",
"parametrize",
"mathruler",
"pylatexenc"
"pylatexenc",
"flash-linear-attention"
]

[tool.mypy]
Expand Down
3 changes: 1 addition & 2 deletions xtuner/v1/data_proto/messages/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from pydantic import BaseModel, ConfigDict

from transformers import PreTrainedTokenizer
from xtuner.utils import IGNORE_INDEX
from xtuner.v1.data_proto.messages.base import BaseMessages
from xtuner.v1.data_proto.templates import ChatTemplate, HybridChatTemplate
from xtuner.v1.utils import get_logger
from xtuner.v1.utils import IGNORE_INDEX, get_logger


logger = get_logger()
Expand Down
2 changes: 2 additions & 0 deletions xtuner/v1/data_proto/sequence_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class SequenceContext:
block_table: torch.Tensor | None
device: str | torch.device # TODO: 这个地方有点乱,到处是 device
position_ids: torch.LongTensor | None
seq_idx: torch.IntTensor | None

# Qwen3VL
image_grid_thw: torch.Tensor | None
Expand Down Expand Up @@ -98,6 +99,7 @@ def __init__(
self.inputs_embeds = inputs_embeds
self.num_img_tokens = num_img_tokens
self.rollout_routed_experts = rollout_routed_experts
self.seq_idx = None

seq_lens_k = self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1]
seq_lens_q = self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1]
Expand Down
2 changes: 2 additions & 0 deletions xtuner/v1/datasets/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def qwen3_vl_sft_collator(
if len(position_ids_list) > 0:
position_ids = torch.cat(position_ids_list, dim=-1)
position_ids = position_ids[:, :, :-1]
if pack_to_max_length and pack_max_length - position_ids.shape[-1] > 0:
position_ids = pad_to_max_length(position_ids, 0, max_length=pack_max_length, dim=-1)

num_img_tokens: list[int] = []
for data in instance:
Expand Down
2 changes: 2 additions & 0 deletions xtuner/v1/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
InternVL3P5MoE30BA3Config,
InternVLBaseConfig,
)
from .compose.qwen3_5 import Qwen3_5_VLMoE35BA3Config
from .compose.qwen3_vl import (
Qwen3VLDense4BConfig,
Qwen3VLDense8BConfig,
Expand Down Expand Up @@ -98,4 +99,5 @@ def get_model_config_from_hf(model_path: Path):
"TorchCompileOption",
"DEFAULT_FLOAT8_CFG",
"XTunerBaseModelConfig",
"Qwen3_5_VLMoE35BA3Config",
]
4 changes: 3 additions & 1 deletion xtuner/v1/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
WeightWithDynamicTilewiseFloat8CastTensor,
)
from xtuner.v1.loss import BaseLossContext
from xtuner.v1.module.attention import MHAConfig, MLAConfig
from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig, MLAConfig
from xtuner.v1.module.rope import RopeScalingConfig
from xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather
from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory
Expand Down Expand Up @@ -147,9 +147,11 @@ class TransformerConfig(XTunerBaseModelConfig):
hidden_size: Annotated[int, Parameter(group="model")]
intermediate_size: Annotated[int, Parameter(group="model")]
rms_norm_eps: Annotated[float, Parameter(group="model")]
rms_norm_type: Annotated[Literal["default", "zero_centered"], Parameter(group="model")] = "default"
rope_theta: Annotated[float, Parameter(group="model")] # required by transformers's build rope
hidden_act: Annotated[str, Parameter(group="model")] # key defined in `transformers.activations.ACT2CLS`
attention: MLAConfig | MHAConfig
linear_attention: Annotated[GatedDeltaNetConfig | None, Parameter(group="model")] = None
mlp_bias: Annotated[bool, Parameter(group="model")] = False
tie_word_embeddings: Annotated[bool, Parameter(group="model")] = False
model_type: Annotated[str | None, Parameter(group="model")] = None # TODO: yehaochen maybe should be removed
Expand Down
6 changes: 6 additions & 0 deletions xtuner/v1/model/compose/qwen3_5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Package entry point: re-export the Qwen3.5 VL MoE config so callers can
# import it from `xtuner.v1.model.compose.qwen3_5` directly.
from .qwen3_5_config import Qwen3_5_VLMoE35BA3Config


# Explicit public API of this package.
__all__ = [
    "Qwen3_5_VLMoE35BA3Config",
]
33 changes: 33 additions & 0 deletions xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from xtuner.v1.model.base import TransformerConfig
from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig
from xtuner.v1.utils import get_logger

from ..qwen3_vl.qwen3_vl_config import Qwen3VLBaseConfig, Qwen3VLProjectorConfig, Qwen3VLVisionConfig


logger = get_logger()


class Qwen3_5_VisionConfig(Qwen3VLVisionConfig):
    """Vision-tower config for Qwen3.5 VL.

    Identical to the Qwen3VL vision config except that
    ``deepstack_visual_indexes`` is overridden to an empty list —
    presumably the deepstack feature-injection path is disabled for
    Qwen3.5; confirm against the Qwen3.5 modeling code.
    """

    # Empty list disables deepstack visual indexes inherited from Qwen3VL.
    deepstack_visual_indexes: list[int] = []


class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig):
    """Projector config for Qwen3.5 VL.

    Same as the Qwen3VL projector config but with
    ``deepstack_visual_indexes`` overridden to an empty list, mirroring
    the vision-config override — presumably deepstack is unused in
    Qwen3.5; confirm against the Qwen3.5 modeling code.
    """

    # Empty list disables deepstack visual indexes inherited from Qwen3VL.
    deepstack_visual_indexes: list[int] = []


class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
    """Base composite config for Qwen3.5 VL models.

    Bundles the three sub-module configs (vision tower, projector, text
    backbone) and the special token ids that delimit multimodal content
    in the token stream.
    """

    # Sub-module configs; concrete model variants supply defaults.
    vision_config: Qwen3_5_VisionConfig
    projector_config: Qwen3_5_ProjectorConfig
    text_config: TransformerConfig

    # Special token ids marking image/video spans.
    # NOTE(review): these override the Qwen3VL base values — verify they
    # match the Qwen3.5 tokenizer's vocabulary.
    image_token_id: int = 248056
    video_token_id: int = 248057
    vision_start_token_id: int = 248053
    vision_end_token_id: int = 248054


class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig):
    """Concrete config for the Qwen3.5 VL MoE 35B-A3B model.

    Fills in default instances for all three sub-configs; the text
    backbone uses the 35B-A3B MoE text config.
    """

    vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig()
    projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig()
    text_config: TransformerConfig = Qwen3_5_VLTextMoE35BA3BConfig()
6 changes: 4 additions & 2 deletions xtuner/v1/model/compose/qwen3_vl/modeling_qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,10 @@ def forward(

if deepstack_visual_embeds is not None and len(deepstack_visual_embeds) == 0:
assert seq_ctx.position_ids is not None
assert seq_ctx.position_ids.ndim == 2, f"position_ids must be 2-dim when deepstack_visual_embeds is None," \
f" but got {seq_ctx.position_ids.ndim}"
assert seq_ctx.position_ids.ndim in (2, 3), (
f"position_ids must be 2-dim or 3-dim when deepstack_visual_embeds is None,"
f" but got {seq_ctx.position_ids.ndim}"
)
deepstack_visual_embeds = None
visual_pos_masks = None

Expand Down
27 changes: 24 additions & 3 deletions xtuner/v1/model/dense/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,15 @@
TransformerConfig,
)
from xtuner.v1.model.utils import checkpoint_wrapper
from xtuner.v1.module import LMHead, RMSNorm, RotaryEmbeddingProtocol, get_rope_embedding
from xtuner.v1.module import (
GatedDeltaNetConfig,
LMHead,
MHAConfig,
MLAConfig,
RMSNorm,
RotaryEmbeddingProtocol,
get_rope_embedding,
)
from xtuner.v1.module.decoder_layer.dense_decoder_layer import DenseDecoderLayer
from xtuner.v1.utils import (
get_device,
Expand All @@ -53,7 +61,7 @@ class Dense(BaseModel):
def __init__(self, config: TransformerConfig):
super().__init__(config)

self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, type=config.rms_norm_type)
self.lm_head = LMHead(config.hidden_size, config.vocab_size, bias=False)
self.layers = self.build_layers(config)
self.rotary_emb = self.build_rotary_embedding(config)
Expand Down Expand Up @@ -116,14 +124,27 @@ def build_layers(self, config: TransformerConfig) -> nn.ModuleDict:
# 让 layers 是一个 nn.ModuleDict 方便做 pipeline parallel 的参数切分,
# 这样可以保证部分 layer 被切掉后,idx 保持不变
layers = nn.ModuleDict()
attention_config: GatedDeltaNetConfig | MLAConfig | MHAConfig | None = None
for layer_idx in range(config.num_hidden_layers):
if config.layers_type[layer_idx] in ["full_attention", "sliding_attention"]:
attention_config = config.attention
elif config.layers_type[layer_idx] == "linear_attention":
attention_config = config.linear_attention
assert attention_config is not None, (
"linear_attention config must be provided for linear_attention layer"
)
else:
raise ValueError(
f"Unsupported layer type {config.layers_type[layer_idx]} at layer {layer_idx}. Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported."
)

layers[str(layer_idx)] = DenseDecoderLayer(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
mlp_bias=config.mlp_bias,
hidden_act=config.hidden_act,
rms_norm_eps=config.rms_norm_eps,
attention_config=config.attention,
attention_config=attention_config,
generate_config=config.generate_config,
rope_scaling_cfg=config.rope_scaling_cfg,
float8_cfg=config.float8_cfg,
Expand Down
26 changes: 23 additions & 3 deletions xtuner/v1/model/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@
)
from xtuner.v1.model.utils import ModelForwardExtraLogInfo, checkpoint_wrapper, module_dict_repr
from xtuner.v1.module import (
GatedDeltaNetConfig,
GreedyRouterConfig,
LMHead,
MHAConfig,
MLAConfig,
NoAuxRouter,
NoAuxRouterConfig,
RMSNorm,
Expand Down Expand Up @@ -116,6 +119,7 @@ class MoEConfig(TransformerConfig):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
n_routed_experts: Annotated[int, Parameter(group="moe")]
n_shared_experts: Annotated[int, Parameter(group="moe")]
with_shared_expert_gate: bool = False # enable when n_shared_experts > 0
num_experts_per_tok: Annotated[int, Parameter(group="moe")]
first_k_dense_replace: Annotated[int, Parameter(group="moe")] = 0
hidden_factor: Annotated[float, Parameter(group="moe")] = 1.0
Expand Down Expand Up @@ -166,7 +170,7 @@ def __init__(self, config: MoEConfig):
else:
self.ep_mesh = None

self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, type=config.rms_norm_type)
self.lm_head = LMHead(config.hidden_size, config.vocab_size, bias=False)

self.layers = self.build_layers(config)
Expand Down Expand Up @@ -592,15 +596,29 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict:
# 让 layers 是一个 nn.ModuleDict 方便做 pipeline parallel 的参数切分,
# 这样可以保证部分 layer 被切掉后,idx 保持不变
layers = nn.ModuleDict()
attention_config: GatedDeltaNetConfig | MLAConfig | MHAConfig | None = None
for layer_idx in range(config.num_hidden_layers):
if config.layers_type[layer_idx] in ["full_attention", "sliding_attention"]:
attention_config = config.attention
elif config.layers_type[layer_idx] == "linear_attention":
attention_config = config.linear_attention
assert attention_config is not None, (
"linear_attention config must be provided for linear_attention layer"
)
else:
raise ValueError(
f"Unsupported layer type {config.layers_type[layer_idx]} at layer {layer_idx}. Only 'full_attention', 'sliding_attention' and 'linear_attention' are supported."
)

if layer_idx < config.first_k_dense_replace:
layers[str(layer_idx)] = DenseDecoderLayer(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
mlp_bias=config.mlp_bias,
hidden_act=config.hidden_act,
rms_norm_eps=config.rms_norm_eps,
attention_config=config.attention,
rms_norm_type=config.rms_norm_type,
attention_config=attention_config,
layer_type=config.layers_type[layer_idx],
rope_scaling_cfg=config.rope_scaling_cfg,
generate_config=config.generate_config,
Expand All @@ -617,12 +635,14 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict:
moe_bias=config.moe_bias,
hidden_act=config.hidden_act,
rms_norm_eps=config.rms_norm_eps,
rms_norm_type=config.rms_norm_type,
num_experts_per_tok=config.num_experts_per_tok,
n_routed_experts=config.n_routed_experts,
n_shared_experts=config.n_shared_experts,
with_shared_expert_gate=config.with_shared_expert_gate,
hidden_factor=config.hidden_factor,
layer_type=config.layers_type[layer_idx],
attention_config=config.attention,
attention_config=attention_config,
rope_scaling_cfg=config.rope_scaling_cfg,
generate_config=config.generate_config,
router_config=config.router,
Expand Down
Loading
Loading