diff --git a/docs/source/Instruction/Supported-models-and-datasets.md b/docs/source/Instruction/Supported-models-and-datasets.md
index eff0b982ec..4661e7dba7 100644
--- a/docs/source/Instruction/Supported-models-and-datasets.md
+++ b/docs/source/Instruction/Supported-models-and-datasets.md
@@ -773,6 +773,7 @@
 |[Qwen/Qwen3-VL-Embedding-8B](https://modelscope.cn/models/Qwen/Qwen3-VL-Embedding-8B)|qwen3_vl_emb|qwen3_vl_emb|transformers>=4.57, qwen_vl_utils>=0.0.14, decord|✔|vision, video|[Qwen/Qwen3-VL-Embedding-8B](https://huggingface.co/Qwen/Qwen3-VL-Embedding-8B)|
 |[Qwen/Qwen3-VL-Reranker-2B](https://modelscope.cn/models/Qwen/Qwen3-VL-Reranker-2B)|qwen3_vl_reranker|qwen3_vl_reranker|transformers>=4.57, qwen_vl_utils>=0.0.14, decord|✔|vision, video|[Qwen/Qwen3-VL-Reranker-2B](https://huggingface.co/Qwen/Qwen3-VL-Reranker-2B)|
 |[Qwen/Qwen3-VL-Reranker-8B](https://modelscope.cn/models/Qwen/Qwen3-VL-Reranker-8B)|qwen3_vl_reranker|qwen3_vl_reranker|transformers>=4.57, qwen_vl_utils>=0.0.14, decord|✔|vision, video|[Qwen/Qwen3-VL-Reranker-8B](https://huggingface.co/Qwen/Qwen3-VL-Reranker-8B)|
+|[Qwen/Qwen3.5-397B-A17B](https://modelscope.cn/models/Qwen/Qwen3.5-397B-A17B)|qwen3_5_moe|qwen3_5|transformers>=5.2.0.dev, qwen_vl_utils>=0.0.14, decord|✔|vision, video|[Qwen/Qwen3.5-397B-A17B](https://huggingface.co/Qwen/Qwen3.5-397B-A17B)|
 |[iic/gme-Qwen2-VL-2B-Instruct](https://modelscope.cn/models/iic/gme-Qwen2-VL-2B-Instruct)|qwen2_gme|qwen2_gme|-|✘|vision|[Alibaba-NLP/gme-Qwen2-VL-2B-Instruct](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct)|
 |[iic/gme-Qwen2-VL-7B-Instruct](https://modelscope.cn/models/iic/gme-Qwen2-VL-7B-Instruct)|qwen2_gme|qwen2_gme|-|✘|vision|[Alibaba-NLP/gme-Qwen2-VL-7B-Instruct](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct)|
 |[AIDC-AI/Ovis1.6-Gemma2-9B](https://modelscope.cn/models/AIDC-AI/Ovis1.6-Gemma2-9B)|ovis1_6|ovis1_6|transformers>=4.42|✘|vision|[AIDC-AI/Ovis1.6-Gemma2-9B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B)|
diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md
index c9e5f3a078..9897c3e607 100644
--- a/docs/source_en/Instruction/Supported-models-and-datasets.md
+++ b/docs/source_en/Instruction/Supported-models-and-datasets.md
@@ -774,6 +774,7 @@ The table below introduces the models integrated with ms-swift:
 |[Qwen/Qwen3-VL-Embedding-8B](https://modelscope.cn/models/Qwen/Qwen3-VL-Embedding-8B)|qwen3_vl_emb|qwen3_vl_emb|transformers>=4.57, qwen_vl_utils>=0.0.14, decord|✔|vision, video|[Qwen/Qwen3-VL-Embedding-8B](https://huggingface.co/Qwen/Qwen3-VL-Embedding-8B)|
 |[Qwen/Qwen3-VL-Reranker-2B](https://modelscope.cn/models/Qwen/Qwen3-VL-Reranker-2B)|qwen3_vl_reranker|qwen3_vl_reranker|transformers>=4.57, qwen_vl_utils>=0.0.14, decord|✔|vision, video|[Qwen/Qwen3-VL-Reranker-2B](https://huggingface.co/Qwen/Qwen3-VL-Reranker-2B)|
 |[Qwen/Qwen3-VL-Reranker-8B](https://modelscope.cn/models/Qwen/Qwen3-VL-Reranker-8B)|qwen3_vl_reranker|qwen3_vl_reranker|transformers>=4.57, qwen_vl_utils>=0.0.14, decord|✔|vision, video|[Qwen/Qwen3-VL-Reranker-8B](https://huggingface.co/Qwen/Qwen3-VL-Reranker-8B)|
+|[Qwen/Qwen3.5-397B-A17B](https://modelscope.cn/models/Qwen/Qwen3.5-397B-A17B)|qwen3_5_moe|qwen3_5|transformers>=5.2.0.dev, qwen_vl_utils>=0.0.14, decord|✔|vision, video|[Qwen/Qwen3.5-397B-A17B](https://huggingface.co/Qwen/Qwen3.5-397B-A17B)|
 |[iic/gme-Qwen2-VL-2B-Instruct](https://modelscope.cn/models/iic/gme-Qwen2-VL-2B-Instruct)|qwen2_gme|qwen2_gme|-|✘|vision|[Alibaba-NLP/gme-Qwen2-VL-2B-Instruct](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct)|
 |[iic/gme-Qwen2-VL-7B-Instruct](https://modelscope.cn/models/iic/gme-Qwen2-VL-7B-Instruct)|qwen2_gme|qwen2_gme|-|✘|vision|[Alibaba-NLP/gme-Qwen2-VL-7B-Instruct](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct)|
 |[AIDC-AI/Ovis1.6-Gemma2-9B](https://modelscope.cn/models/AIDC-AI/Ovis1.6-Gemma2-9B)|ovis1_6|ovis1_6|transformers>=4.42|✘|vision|[AIDC-AI/Ovis1.6-Gemma2-9B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B)|
diff --git a/swift/megatron/arguments/export_args.py b/swift/megatron/arguments/export_args.py
index 6982284120..b661b395d7 100644
--- a/swift/megatron/arguments/export_args.py
+++ b/swift/megatron/arguments/export_args.py
@@ -34,8 +34,8 @@ def _init_output_dir(self):
             raise FileExistsError(f'args.output_dir: `{self.output_dir}` already exists.')
         logger.info(f'args.output_dir: `{self.output_dir}`')
 
-    def __post_init__(self):
-        super().__post_init__()
+    def _init_megatron_args(self):
+        self._init_output_dir()
         self.test_convert_dtype = HfConfigFactory.to_torch_dtype(self.test_convert_dtype)
         extra_config = MegatronArguments.load_args_config(self.ckpt_dir)
         extra_config['mcore_adapter'] = self.mcore_adapter
@@ -50,6 +50,7 @@ def __post_init__(self):
             logger.info('Settting args.sequence_parallel: True')
         if self.merge_lora is None:
             self.merge_lora = self.to_hf
+        super()._init_megatron_args()
 
     def _init_convert(self):
         convert_kwargs = {
diff --git a/swift/megatron/arguments/megatron_base_args.py b/swift/megatron/arguments/megatron_base_args.py
index 9dcec37016..f88b23d01b 100644
--- a/swift/megatron/arguments/megatron_base_args.py
+++ b/swift/megatron/arguments/megatron_base_args.py
@@ -12,16 +12,15 @@
 
 @dataclass
 class MegatronBaseArguments(MegatronArguments, BaseArguments):
-    def _init_output_dir(self):
-        pass
+    def _init_megatron_args(self):
+        MegatronArguments.__post_init__(self)
 
     def __post_init__(self):
         self.sequence_parallel_size = self.context_parallel_size
         if self.packing:
             self.padding_free = True
         BaseArguments.__post_init__(self)
-        self._init_output_dir()
-        MegatronArguments.__post_init__(self)
+        self._init_megatron_args()
         if self.streaming:
             if self.dataloader_num_workers > 1:
                 self.dataloader_num_workers = 1
diff --git a/swift/megatron/arguments/sft_args.py b/swift/megatron/arguments/sft_args.py
index d127940584..1f93388f2e 100644
--- a/swift/megatron/arguments/sft_args.py
+++ b/swift/megatron/arguments/sft_args.py
@@ -36,6 +36,10 @@ def _init_ckpt_dir(self, adapters=None):
                 old_args = json.load(f)
             self.model = old_args.get('model')
 
+    def _init_megatron_args(self):
+        self._init_output_dir()
+        super()._init_megatron_args()
+
     def __post_init__(self):
         self.mcore_model = to_abspath(self.mcore_model, check_path_exist=True)
         super().__post_init__()
diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py
index f511a91601..1f80b7badc 100644
--- a/swift/megatron/model/gpt_bridge.py
+++ b/swift/megatron/model/gpt_bridge.py
@@ -674,6 +674,7 @@ def _set_moe_state(
         hf_prefix: str,
         layer_idx: int,
         to_mcore: bool,
+        is_mtp_layer: bool = False,
     ):
         if to_mcore:
             hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
@@ -719,30 +720,47 @@ def _set_moe_state(
         else:
             mg_experts = None
         hf_state_dict.update(
-            self._set_mlp_state(mg_experts, hf_state_dict, 'experts.', layer_idx, to_mcore, ep_rank=ep_rank))
+            self._set_mlp_state(
+                mg_experts,
+                hf_state_dict,
+                'experts.',
+                layer_idx,
+                to_mcore,
+                ep_rank=ep_rank,
+                is_mtp_layer=is_mtp_layer))
         if to_mcore:
             hf_state_dict = {}
         else:
             hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
-    def _get_hf_grouped(self):
+    def _get_hf_grouped(self, is_mtp_layer: bool = False):
         if self.model_type in {
                 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'dots1', 'ernie4_5_moe', 'glm4_moe',
-                'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe',
-                'qwen3_vl_moe', 'qwen3_5_moe'
+                'glm4_moe_lite', 'glm4v_moe', 'minimax_m2', 'olmoe', 'qwen3_next', 'kimi_vl', 'qwen3_omni_moe'
         }:
             return False, False
+        elif self.model_type == 'qwen3_5_moe' and is_mtp_layer:
+            return False, False
         return None, None
 
-    def _set_mlp_state(self,
-                       mg_mlp,
-                       hf_state_dict,
-                       hf_prefix: str,
-                       layer_idx: int,
-                       to_mcore: bool,
-                       ep_rank: Optional[int] = None,
-                       hf_mlp=None):
+    def _get_transpose(self):
+        if self.model_type in {'qwen3_vl_moe', 'gpt_oss', 'llama4'}:
+            return True
+        else:
+            return False
+
+    def _set_mlp_state(
+        self,
+        mg_mlp,
+        hf_state_dict,
+        hf_prefix: str,
+        layer_idx: int,
+        to_mcore: bool,
+        ep_rank: Optional[int] = None,
+        hf_mlp=None,
+        is_mtp_layer: bool = False,
+    ):
         if hf_mlp is None:
             hf_mlp = self._get_hf_mlp(layer_idx)
         is_expert = ep_rank is not None
@@ -756,11 +774,14 @@ def _set_mlp_state(self,
         is_gate_up = hasattr(hf_mlp, 'gate_up_proj')
         # transformers 5.0 compatibility
         if self.is_transformers_5:
-            _hf_grouped, _is_gate_up = self._get_hf_grouped()
+            _hf_grouped, _is_gate_up = self._get_hf_grouped(is_mtp_layer)
             if _hf_grouped is not None:
                 hf_grouped = _hf_grouped
             if _is_gate_up is not None:
                 is_gate_up = _is_gate_up
+        need_transpose = True
+        if self.is_transformers_5:
+            need_transpose = self._get_transpose()
 
         if to_mcore or hf_grouped:
             hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
@@ -835,11 +856,14 @@ def _set_mlp_state(self,
                 gate_up_proj_weight = self.mxfp4_quantizer.convert(blocks, scales)
             else:
                 gate_up_proj_weight = hf_state_dict['gate_up_proj'].load()
-            gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2)
+            if need_transpose:
+                gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2)
             gate_up_proj_weight = gate_up_proj_weight[ep_rank * num_local_experts:(ep_rank + 1)
                                                       * num_local_experts]
             if has_scale_inv:
-                gate_up_scale_inv = hf_state_dict['gate_up_proj_scale_inv'].load().transpose(1, 2)
+                gate_up_scale_inv = hf_state_dict['gate_up_proj_scale_inv'].load()
+                if need_transpose:
+                    gate_up_scale_inv = gate_up_scale_inv.transpose(1, 2)
                 gate_up_scale_inv = gate_up_scale_inv[ep_rank * num_local_experts:(ep_rank + 1)
                                                       * num_local_experts]
             if fc1_bias is not None:
@@ -995,7 +1019,8 @@ def _set_mlp_state(self,
             if is_gate_up:
                 if is_expert:
                     if hf_grouped:
-                        gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2)
+                        if need_transpose:
+                            gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2)
                         if 'gate_up_proj' in hf_state_dict:
                             gate_up_proj_weight = torch.concat(
                                 [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0)
@@ -1009,7 +1034,8 @@ def _set_mlp_state(self,
                             del new_gate_up_proj_weight, gate_proj_weight, up_proj_weight
                         hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone()
                         if scale_inv is not None:
-                            scale_inv = scale_inv.transpose(1, 2)
+                            if need_transpose:
+                                scale_inv = scale_inv.transpose(1, 2)
                             if 'gate_up_proj_scale_inv' in hf_state_dict:
                                 scale_inv = torch.concat([hf_state_dict['gate_up_proj_scale_inv'], scale_inv],
                                                          dim=0)
@@ -1101,12 +1127,15 @@ def _set_mlp_state(self,
                 down_proj_weight = self.mxfp4_quantizer.convert(blocks, scales)
             else:
                 down_proj_weight = hf_state_dict['down_proj'].load()
-            down_proj_weight = down_proj_weight.transpose(1, 2)
+            if need_transpose:
+                down_proj_weight = down_proj_weight.transpose(1, 2)
             down_proj_weight = down_proj_weight[ep_rank * num_local_experts:(ep_rank + 1)
                                                 * num_local_experts].reshape(
                                                     -1, down_proj_weight.shape[-1])
             if has_scale_inv:
-                down_scale_inv = hf_state_dict['down_proj_scale_inv'].load().transpose(1, 2)
+                down_scale_inv = hf_state_dict['down_proj_scale_inv'].load()
+                if need_transpose:
+                    down_scale_inv = down_scale_inv.transpose(1, 2)
                 down_scale_inv = down_scale_inv[ep_rank * num_local_experts:(ep_rank + 1)
                                                 * num_local_experts].reshape(-1, down_scale_inv.shape[-1])
             if fc2_bias is not None:
@@ -1186,12 +1215,14 @@ def _set_mlp_state(self,
                 del fc2_weight, fc2_bias
             if down_proj_weight is not None:
                 if hf_grouped:
-                    down_proj_weight = down_proj_weight.transpose(1, 2)
+                    if need_transpose:
+                        down_proj_weight = down_proj_weight.transpose(1, 2)
                     if 'down_proj' in hf_state_dict:
                         down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0)
                     hf_state_dict['down_proj'] = down_proj_weight.clone()
                     if scale_inv is not None:
-                        scale_inv = scale_inv.transpose(1, 2)
+                        if need_transpose:
+                            scale_inv = scale_inv.transpose(1, 2)
                         if 'down_proj_scale_inv' in hf_state_dict:
                             scale_inv = torch.concat([hf_state_dict['down_proj_scale_inv'], scale_inv], dim=0)
                         hf_state_dict['down_proj_scale_inv'] = scale_inv.clone()
@@ -1259,13 +1290,15 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
                              'input_layernorm.weight', to_mcore)
         return hf_state_dict
 
-    def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool):
+    def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp_layer: bool = False):
         hf_mlp_prefix = self.get_hf_mlp_prefix(layer_idx)
         hf_mlp = self._get_hf_mlp(layer_idx)
         is_moe = self._is_moe(hf_mlp.state_dict())
         mg_mlp = None if mg_layer is None else mg_layer.mlp
         if is_moe:
-            hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore))
+            hf_state_dict.update(
+                self._set_moe_state(
+                    mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp_layer=is_mtp_layer))
             self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict,
                                  'post_attention_layernorm.weight', to_mcore)
         else:
@@ -1453,7 +1486,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
             self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore)
 
         hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore))
-        hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore))
+        hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore, is_mtp_layer=True))
         if to_mcore:
             hf_state_dict = {}
         else:
diff --git a/swift/megatron/model/mm_gpts/qwen3_5.py b/swift/megatron/model/mm_gpts/qwen3_5.py
index 212531f123..d326b2d6a5 100644
--- a/swift/megatron/model/mm_gpts/qwen3_5.py
+++ b/swift/megatron/model/mm_gpts/qwen3_5.py
@@ -78,7 +78,7 @@ def __init__(self, config):
         super().__init__(config, [Qwen3_5TextModel, Qwen3_5MoeTextModel])
 
     def get_inputs_embeds(self, inputs_embeds, **kwargs):
-        return Template._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.model_config)
+        return Template._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.hf_config)
 
 
 class Qwen3_5Bridge(Qwen3NextBridge):
diff --git a/swift/model/models/qwen.py b/swift/model/models/qwen.py
index 60b058df61..5dbb1551ca 100644
--- a/swift/model/models/qwen.py
+++ b/swift/model/models/qwen.py
@@ -1122,12 +1122,14 @@ def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrain
 
 register_model(
     ModelMeta(
         MLLMModelType.qwen3_5_moe, [
-            ModelGroup([], TemplateType.qwen3_5),
+            ModelGroup([
+                Model('Qwen/Qwen3.5-397B-A17B', 'Qwen/Qwen3.5-397B-A17B'),
+            ], TemplateType.qwen3_5),
         ],
         Qwen3_5MoeLoader,
         model_arch=ModelArch.qwen2_vl,
         architectures=['Qwen3_5MoeForConditionalGeneration'],
-        requires=['transformers>=5.0.0.dev', 'qwen_vl_utils>=0.0.14', 'decord'],
+        requires=['transformers>=5.2.0.dev', 'qwen_vl_utils>=0.0.14', 'decord'],
         tags=['vision', 'video']))
 
diff --git a/swift/template/templates/qwen.py b/swift/template/templates/qwen.py
index c4963ef3ab..5ce8fe3ad7 100644
--- a/swift/template/templates/qwen.py
+++ b/swift/template/templates/qwen.py
@@ -562,7 +562,12 @@ def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
 
 register_template(
     QwenTemplateMeta(
-        MLLMTemplateType.qwen3_5, template_cls=Qwen3_5Template, default_system=None, thinking_prefix='<think>\n'))
+        MLLMTemplateType.qwen3_5,
+        template_cls=Qwen3_5Template,
+        default_system=None,
+        thinking_prefix='<think>\n',
+        non_thinking_prefix='<think>\n\n</think>\n\n',
+        is_thinking=True))
 
 
 class Qwen3VLEmbTemplate(Qwen3VLTemplate):