From 36105382a157e283a4b1130396f2c9a0775363ae Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 16:58:06 +0800
Subject: [PATCH 01/13] update

---
 src/mcore_bridge/model/modules/gated_delta_net.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index 87580fd..a8d99f3 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
 from typing import List, Optional
 
@@ -17,7 +18,7 @@
 try:
     from megatron.core.ssm.gated_delta_net import GatedDeltaNet as _GatedDeltaNet
-    from megatron.core.ssm.gated_delta_net import torch_chunk_gated_delta_rule
+    from megatron.core.ssm.gated_delta_net import GatedDeltaNetSubmodules, torch_chunk_gated_delta_rule
 except ImportError:
     _GatedDeltaNet = object
 
@@ -83,6 +84,9 @@ class GatedDeltaNet(_GatedDeltaNet):
 
+    def __init__(self, config: TransformerConfig, submodules: 'GatedDeltaNetSubmodules', *args, **kwargs):
+        super().__init__(config, submodules, *args, **kwargs)
+
     def forward(
         self,
         hidden_states: torch.Tensor,

From 9ef352011fd9eafe3b8e5d55603980a3f766c5f0 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 17:52:33 +0800
Subject: [PATCH 02/13] update

---
 src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py |  6 ++-
 .../model/modules/gated_delta_net.py          | 38 ++++++++++++++++++-
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py b/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
index 780e0fd..dfb8113 100644
--- a/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
+++ b/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
@@ -1,4 +1,5 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm
 from megatron.core.transformer.attention import SelfAttention
 from typing import Optional
 
@@ -20,8 +21,7 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
         if is_linear_attention:
             hf_state_dict.update(
                 self._set_linear_attn_state(mg_attn, hf_state_dict, 'linear_attn.', layer_idx, to_mcore))
-            self._set_state_dict(mg_layer, 'self_attention.in_proj.layer_norm_weight', hf_state_dict,
-                                 'input_layernorm.weight', to_mcore)
+            self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore)
         else:
             hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore))
             self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict,
@@ -43,7 +43,9 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
         if issubclass(attn_module, SelfAttention):
             layer_spec.submodules.self_attention.module = GatedSelfAttention
         else:
+            layer_spec.submodules.input_layernorm = TENorm
             layer_spec.submodules.self_attention.module = GatedDeltaNet
+            layer_spec.submodules.self_attention.submodules.in_proj = TEColumnParallelLinear
     return layer_specs
 
 def build_model(

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index a8d99f3..b14c3c3 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -4,6 +4,8 @@
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer import TransformerConfig
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
 from typing import List, Optional
 
@@ -85,7 +87,39 @@ class GatedDeltaNet(_GatedDeltaNet):
 
     def __init__(self, config: TransformerConfig, submodules: 'GatedDeltaNetSubmodules', *args, **kwargs):
+        in_proj = submodules.in_proj
+        submodules.in_proj = IdentityOp
         super().__init__(config, submodules, *args, **kwargs)
+        submodules.in_proj = in_proj
+        self.in_proj_qkvz_dim = self.qk_dim * 2 + self.v_dim * 2
+        self.in_proj_ba_dim = self.num_value_heads * 2
+        del self.in_proj
+        self.in_proj_qkvz = build_module(
+            submodules.in_proj,
+            self.hidden_size,
+            self.in_proj_qkvz_dim,
+            config=self.config,
+            init_method=self.config.init_method,
+            gather_output=False,
+            bias=self.bias,
+            skip_bias_add=False,
+            is_expert=False,
+            tp_comm_buffer_name='fc1_qkvz',
+            tp_group=self.pg_collection.tp,
+        )
+        self.in_proj_qkvz = build_module(
+            submodules.in_proj,
+            self.hidden_size,
+            self.in_proj_ba_dim,
+            config=self.config,
+            init_method=self.config.init_method,
+            gather_output=False,
+            bias=self.bias,
+            skip_bias_add=False,
+            is_expert=False,
+            tp_comm_buffer_name='fc1_ba',
+            tp_group=self.pg_collection.tp,
+        )
 
     def forward(
         self,
@@ -137,7 +171,9 @@ def forward(
         cu_seqlens = None if packed_seq_params is None else packed_seq_params.cu_seqlens_q
         # Input projection
         nvtx_range_push(suffix='in_proj')
-        qkvzba, _ = self.in_proj(hidden_states)
+        qkvz, _ = self.in_proj_qkvz(hidden_states)
+        ba, _ = self.in_proj_ba(hidden_states)
+        qkvzba = torch.concat([qkvz, ba], dim=0)
         nvtx_range_pop(suffix='in_proj')
 
         if cp_size > 1:
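The two patches above replace the single fused `in_proj` of Megatron's `GatedDeltaNet` with two separate column-parallel projections, `in_proj_qkvz` (q, k, v and the gate z) and `in_proj_ba` (beta and alpha). A minimal standalone sketch of that split, using plain `nn.Linear` in place of the tensor-parallel `TEColumnParallelLinear` and made-up dimensions (none of these sizes come from the repo):

```python
import torch
import torch.nn as nn

# Toy dimensions, not taken from the repo.
hidden_size, num_key_heads, num_value_heads = 64, 4, 8
qk_dim, v_dim = 32, 64                  # total q/k width and total v width
qkvz_dim = qk_dim * 2 + v_dim * 2       # q, k, v, z  (mirrors in_proj_qkvz_dim)
ba_dim = num_value_heads * 2            # b, a        (mirrors in_proj_ba_dim)

in_proj_qkvz = nn.Linear(hidden_size, qkvz_dim, bias=False)
in_proj_ba = nn.Linear(hidden_size, ba_dim, bias=False)

x = torch.randn(3, 2, hidden_size)      # [seq, batch, hidden]
qkvz, ba = in_proj_qkvz(x), in_proj_ba(x)
print(qkvz.shape, ba.shape)             # (3, 2, 192) and (3, 2, 16)
```

Splitting the projection this way is what later patches in the series rely on when the two outputs are re-interleaved per key head before the usual q/k/v/gate/beta/alpha split.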
From 3bc7c9cfeccfb1715a3cb878084e49d6e0f42435 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 19:05:26 +0800
Subject: [PATCH 03/13] fix

---
 src/mcore_bridge/bridge/gpt_bridge.py | 108 ++++++++++++------
 .../model/modules/gated_delta_net.py  |   2 +-
 2 files changed, 75 insertions(+), 35 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 01c61cc..15bc291 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -114,7 +114,8 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
         dim0_keys = {
             'word_embeddings',
             'linear_qkv',
-            'in_proj',
+            'in_proj_qkvz',
+            'in_proj_ba',
             'conv1d',
             # mla
             'linear_q_proj',
@@ -1241,38 +1242,34 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
         key_dim = config.linear_key_head_dim
         value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
         if to_mcore:
-            if isinstance(mg_attn.in_proj, LoraParallelLinear):
-                lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
-                assert (lora_A == hf_state_dict['in_proj_z.lora_A.weight'].load()).all() and \
-                    (lora_A == hf_state_dict['in_proj_b.lora_A.weight'].load()).all() and \
-                    (lora_A == hf_state_dict['in_proj_a.lora_A.weight'].load()).all(), \
-                    'Need to ensure QKVZBA\'s lora_A are consistent'
-                qkv_lora_B = hf_state_dict['in_proj_qkv.lora_B.weight'].load()
+            if isinstance(mg_attn.in_proj_qkvz, LoraParallelLinear):
+                lora_A = hf_state_dict['in_proj_qkvz.lora_A.weight'].load()
+                assert (lora_A == hf_state_dict['in_proj_z.lora_A.weight'].load()).all(), \
+                    'Need to ensure QKVZ\'s lora_A are consistent'
+                qkv_lora_B = hf_state_dict['in_proj_qkvz.lora_B.weight'].load()
                 q_lora_B, k_lora_B, v_lora_B = torch.split(
                     qkv_lora_B, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 lora_B = torch.cat([
                     *(x.reshape(num_key_heads, -1, qkv_lora_B.shape[-1]) for x in [q_lora_B, k_lora_B, v_lora_B]),
-                    *(hf_state_dict[f'{key}.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1])
-                      for key in ['in_proj_z', 'in_proj_b', 'in_proj_a'])
+                    hf_state_dict['in_proj_z.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1])
                 ],
                                    dim=1).reshape(-1, qkv_lora_B.shape[-1])
-                self._set_weight(mg_attn.in_proj.lora_A[self._adapter_name].weight, lora_A, 'in_proj.lora_A.weight')
-                self._set_weight(mg_attn.in_proj.lora_B[self._adapter_name].weight, lora_B, 'in_proj.lora_B.weight')
+                self._set_weight(mg_attn.in_proj_qkvz.lora_A[self._adapter_name].weight, lora_A,
+                                 'in_proj_qkvz.lora_A.weight')
+                self._set_weight(mg_attn.in_proj_qkvz.lora_B[self._adapter_name].weight, lora_B,
+                                 'in_proj_qkvz.lora_B.weight')
             elif not self._peft_format:
                 qkv = hf_state_dict['in_proj_qkv.weight'].load()
                 q, k, v = torch.split(
                     qkv, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 in_proj_weight = torch.cat([
                     *(x.reshape(num_key_heads, -1, config.hidden_size) for x in [q, k, v]),
-                    *(hf_state_dict[f'{key}.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
-                      for key in ['in_proj_z', 'in_proj_b', 'in_proj_a']),
+                    hf_state_dict['in_proj_z.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
                 ],
                                            dim=1).reshape((-1, config.hidden_size))
-                self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight')
+                self._set_weight(mg_attn.in_proj_qkvz.weight, in_proj_weight, 'in_proj_qkvz.weight')
         else:
             qkv_dim = key_dim * 2 + value_dim
-            z_dim = value_dim
-            a_dim = config.linear_num_value_heads // num_key_heads
             is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj,
                                                                LoraParallelLinear) and self._peft_format
             is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
@@ -1280,38 +1277,81 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                 dist.all_reduce(is_lora, group=self.pp_group)
             if is_lora:
                 lora_A, _ = self._get_weight(
-                    None if mg_attn is None else mg_attn.in_proj.lora_A[self._adapter_name].weight.data,
-                    f'in_proj.lora_A.{self._adapter_name}.weight')
+                    None if mg_attn is None else mg_attn.in_proj_qkvz.lora_A[self._adapter_name].weight.data,
+                    f'in_proj_qkvz.lora_A.{self._adapter_name}.weight')
                 lora_B, _ = self._get_weight(
-                    None if mg_attn is None else mg_attn.in_proj.lora_B[self._adapter_name].weight.data,
-                    f'in_proj.lora_B.{self._adapter_name}.weight')
+                    None if mg_attn is None else mg_attn.in_proj_qkvz.lora_B[self._adapter_name].weight.data,
+                    f'in_proj_qkvz.lora_B.{self._adapter_name}.weight')
                 if lora_A is not None:
                     lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
-                    self._peft_target_modules.update({'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a'})
-                    for key in ['in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a']:
+                    self._peft_target_modules.update({'in_proj_qkv', 'in_proj_z'})
+                    for key in ['in_proj_qkv', 'in_proj_z']:
                         hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
                     q_lora_B = lora_B[:, :key_dim].reshape(-1, lora_B.shape[-1])
                     k_lora_B = lora_B[:, key_dim:2 * key_dim].reshape(-1, lora_B.shape[-1])
                     v_lora_B = lora_B[:, 2 * key_dim:qkv_dim].reshape(-1, lora_B.shape[-1])
                     hf_state_dict['in_proj_qkv.lora_B.weight'] = torch.concat([q_lora_B, k_lora_B, v_lora_B], dim=0)
-                    hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:qkv_dim + z_dim].reshape(
-                        -1, lora_B.shape[-1]).clone()
-                    hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[:, qkv_dim + z_dim:-a_dim].reshape(
-                        -1, lora_B.shape[-1]).clone()
-                    hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[:, -a_dim:].reshape(-1, lora_B.shape[-1]).clone()
+                    hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:].reshape(-1, lora_B.shape[-1]).clone()
             elif not self._peft_format:
-                in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj.weight.data,
-                                                     'in_proj.weight')
+                in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj_qkvz.weight.data,
+                                                     'in_proj_qkvz.weight')
                 if in_proj_weight is not None:
                     in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
                     q = in_proj_weight[:, :key_dim].reshape(-1, config.hidden_size)
                     k = in_proj_weight[:, key_dim:2 * key_dim].reshape(-1, config.hidden_size)
                     v = in_proj_weight[:, 2 * key_dim:qkv_dim].reshape(-1, config.hidden_size)
                     hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
-                    hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim:(qkv_dim + z_dim)].reshape(
-                        -1, config.hidden_size).clone()
-                    hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, (qkv_dim + z_dim):-a_dim].reshape(
-                        -1, config.hidden_size).clone()
+                    hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim:].reshape(-1,
+                                                                                            config.hidden_size).clone()
+        if to_mcore:
+            if isinstance(mg_attn.in_proj_ba, LoraParallelLinear):
+                lora_A = hf_state_dict['in_proj_b.lora_A.weight'].load()
+                assert (lora_A == hf_state_dict['in_proj_a.lora_A.weight'].load()).all(), \
+                    'Need to ensure BA\'s lora_A are consistent'
+                b_lora_B = hf_state_dict['in_proj_b.lora_B.weight'].load()
+                lora_B = torch.cat([
+                    b_lora_B.reshape(num_key_heads, -1, b_lora_B.shape[-1]),
+                    hf_state_dict['in_proj_a.lora_B.weight'].load().reshape(num_key_heads, -1, b_lora_B.shape[-1]),
+                ],
+                                   dim=1).reshape(-1, b_lora_B.shape[-1])
+                self._set_weight(mg_attn.in_proj_ba.lora_A[self._adapter_name].weight, lora_A,
+                                 'in_proj_ba.lora_A.weight')
+                self._set_weight(mg_attn.in_proj_ba.lora_B[self._adapter_name].weight, lora_B,
+                                 'in_proj_ba.lora_B.weight')
+            elif not self._peft_format:
+                in_proj_weight = torch.cat([
+                    hf_state_dict[f'{key}.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
+                    for key in ['in_proj_b', 'in_proj_a']
+                ],
+                                           dim=1).reshape((-1, config.hidden_size))
+                self._set_weight(mg_attn.in_proj_ba.weight, in_proj_weight, 'in_proj_ba.weight')
+        else:
+            a_dim = config.linear_num_value_heads // num_key_heads
+            is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj_ba,
+                                                               LoraParallelLinear) and self._peft_format
+            is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
+            if self.pp_size > 1:
+                dist.all_reduce(is_lora, group=self.pp_group)
+            if is_lora:
+                lora_A, _ = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj_ba.lora_A[self._adapter_name].weight.data,
+                    f'in_proj_ba.lora_A.{self._adapter_name}.weight')
+                lora_B, _ = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj_ba.lora_B[self._adapter_name].weight.data,
+                    f'in_proj_ba.lora_B.{self._adapter_name}.weight')
+                if lora_A is not None:
+                    lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
+                    self._peft_target_modules.update({'in_proj_b', 'in_proj_a'})
+                    for key in ['in_proj_b', 'in_proj_a']:
+                        hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
+                    hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[:, :-a_dim].reshape(-1, lora_B.shape[-1]).clone()
+                    hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[:, -a_dim:].reshape(-1, lora_B.shape[-1]).clone()
+            elif not self._peft_format:
+                in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj_ba.weight.data,
+                                                     'in_proj_ba.weight')
+                if in_proj_weight is not None:
+                    hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, :-a_dim].reshape(-1,
+                                                                                           config.hidden_size).clone()
                     hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
                                                                                            config.hidden_size).clone()
         if not self._peft_format:

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index b14c3c3..3845882 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -107,7 +107,7 @@ def __init__(self, config: TransformerConfig, submodules: 'GatedDeltaNetSubmodul
             tp_comm_buffer_name='fc1_qkvz',
             tp_group=self.pg_collection.tp,
         )
-        self.in_proj_qkvz = build_module(
+        self.in_proj_ba = build_module(
             submodules.in_proj,
             self.hidden_size,
             self.in_proj_ba_dim,
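Patch 03 maps the per-projection HuggingFace weights (`in_proj_qkv`, `in_proj_z`, `in_proj_b`, `in_proj_a`) onto the fused, head-interleaved Megatron layouts by reshaping each piece to `(num_key_heads, -1, hidden_size)` and concatenating along dim 1. A toy sketch of that interleave for the qkvz side (sizes are made up; the real code also handles LoRA adapters and the separate b/a projection):

```python
import torch

# Illustrative sizes only, not the real model's.
num_key_heads, key_dim, value_dim, hidden = 2, 4, 8, 16

q = torch.randn(num_key_heads * key_dim, hidden)    # HF in_proj_qkv stores [q; k; v] stacked on dim 0
k = torch.randn(num_key_heads * key_dim, hidden)
v = torch.randn(num_key_heads * value_dim, hidden)
z = torch.randn(num_key_heads * value_dim, hidden)  # HF in_proj_z

# Interleave per key head: each head's rows become [q_h, k_h, v_h, z_h].
fused = torch.cat([x.reshape(num_key_heads, -1, hidden) for x in (q, k, v, z)],
                  dim=1).reshape(-1, hidden)
print(fused.shape)  # (num_key_heads * (2*key_dim + 2*value_dim), hidden) -> (48, 16)
```

The export direction in the patch simply inverts this: reshape the fused weight back to `(num_key_heads, -1, hidden)` and slice out the q/k/v/z blocks per head.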
From 41d5162ce499581d0d3f376f03dca0d8999298aa Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 19:53:16 +0800
Subject: [PATCH 04/13] qwen3_5 fp8

---
 .gitignore                            |  1 +
 requirements.txt                      |  2 +-
 src/mcore_bridge/bridge/gpt_bridge.py | 28 +++++++++++++++++--
 .../model/modules/gated_delta_net.py  | 11 ++++----
 src/mcore_bridge/patcher.py           |  1 -
 5 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/.gitignore b/.gitignore
index a98c3bd..9dc57fb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@
 __pycache__/
 *.log
 *.out
+/megatron_output/

diff --git a/requirements.txt b/requirements.txt
index 9c0d007..56e5eb8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,4 @@ modelscope
 peft>=0.11,<0.19
 safetensors
 tqdm
-transformers>=4.33,<5.4.0
+transformers>=4.33,<5.6.0

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 15bc291..bd447ce 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1241,6 +1241,7 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
         num_key_heads = config.linear_num_key_heads
         key_dim = config.linear_key_head_dim
         value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
+        hidden_size_block = config.hidden_size // self.fp8_block_size
         if to_mcore:
             if isinstance(mg_attn.in_proj_qkvz, LoraParallelLinear):
                 lora_A = hf_state_dict['in_proj_qkvz.lora_A.weight'].load()
@@ -1267,7 +1268,19 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                     hf_state_dict['in_proj_z.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
                 ],
                                            dim=1).reshape((-1, config.hidden_size))
-                self._set_weight(mg_attn.in_proj_qkvz.weight, in_proj_weight, 'in_proj_qkvz.weight')
+                in_scale_inv = None
+                if 'in_proj_qkv.weight_scale_inv' in hf_state_dict:
+                    qkv_scale_inv = hf_state_dict['in_proj_qkv.weight_scale_inv'].load()
+                    q_si, k_si, v_si = torch.split(
+                        qkv_scale_inv, [x * num_key_heads // 128 for x in [key_dim, key_dim, value_dim]], dim=0)
+                    in_scale_inv = torch.cat([
+                        *(x.reshape(num_key_heads, -1, hidden_size_block) for x in [q_si, k_si, v_si]),
+                        hf_state_dict['in_proj_z.weight_scale_inv'].load().reshape(num_key_heads, -1,
+                                                                                   hidden_size_block),
+                    ],
+                                             dim=1).reshape((-1, hidden_size_block))
+                self._set_weight(
+                    mg_attn.in_proj_qkvz.weight, in_proj_weight, 'in_proj_qkvz.weight', hf_scale_inv=in_scale_inv)
         else:
             qkv_dim = key_dim * 2 + value_dim
             is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj,
@@ -1293,8 +1306,8 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                     hf_state_dict['in_proj_qkv.lora_B.weight'] = torch.concat([q_lora_B, k_lora_B, v_lora_B], dim=0)
                     hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:].reshape(-1, lora_B.shape[-1]).clone()
             elif not self._peft_format:
-                in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj_qkvz.weight.data,
-                                                     'in_proj_qkvz.weight')
+                in_proj_weight, scale_inv = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj_qkvz.weight.data, 'in_proj_qkvz.weight')
                 if in_proj_weight is not None:
                     in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
                     q = in_proj_weight[:, :key_dim].reshape(-1, config.hidden_size)
@@ -1303,6 +1316,15 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                     hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
                     hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim:].reshape(-1,
                                                                                             config.hidden_size).clone()
+                    if scale_inv is not None:
+                        key_block = key_dim // self.fp8_block_size
+                        qkv_block = qkv_dim // self.fp8_block_size
+                        scale_inv = scale_inv.reshape(num_key_heads, -1, hidden_size_block)
+                        q = scale_inv[:, :key_block].reshape(-1, hidden_size_block)
+                        k = scale_inv[:, key_block:2 * key_block].reshape(-1, hidden_size_block)
+                        v = scale_inv[:, 2 * key_block:qkv_block].reshape(-1, hidden_size_block)
+                        hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
+                        hf_state_dict['in_proj_z.weight'] = scale_inv[:, qkv_block:].reshape(-1, hidden_size_block).clone()
         if to_mcore:
             if isinstance(mg_attn.in_proj_ba, LoraParallelLinear):
                 lora_A = hf_state_dict['in_proj_b.lora_A.weight'].load()

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index 3845882..3f7d10e 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -173,7 +173,10 @@ def forward(
         nvtx_range_push(suffix='in_proj')
         qkvz, _ = self.in_proj_qkvz(hidden_states)
         ba, _ = self.in_proj_ba(hidden_states)
-        qkvzba = torch.concat([qkvz, ba], dim=0)
+        num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
+        qkvz = qkvz.view(qkvz.shape[:-1] + (num_key_heads_per_device, qkvz.shape[-1] // num_key_heads_per_device))
+        ba = ba.view(ba.shape[:-1] + (num_key_heads_per_device, ba.shape[-1] // num_key_heads_per_device))
+        qkvzba = torch.concat([qkvz, ba], dim=-1).view(*qkvz.shape[:2], -1)
         nvtx_range_pop(suffix='in_proj')
 
         if cp_size > 1:
@@ -198,15 +201,11 @@ def forward(
             head_dim=-1,
             cp_group=self.pg_collection.cp,
         )
-        # Transpose: s b x --> b s x
         # From sbhd to bshd format
-        qkvzba = qkvzba.transpose(0, 1)
-
-        # Split, reorder, and reshape the tensor into q, k, v, gate, beta, alpha
-        num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
         qkvzba = qkvzba.view(qkvzba.shape[:-1] + (num_key_heads_per_device,
                                                   qkvzba.shape[-1] // num_key_heads_per_device))
+        qkvzba = qkvzba.transpose(0, 1)
         qkv, gate, beta, alpha = torch.split(
             qkvzba,
             [

diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index 4b17274..a4dd6d6 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -418,7 +418,6 @@ def forward(
             Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape [s, b, h],
             and optionally the updated context tensor if cross-attention is used.
         """
-        # TODO: Multimodal compatible
         assert context is None, 'multi token prediction + cross attention is not yet supported.'
 
         input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
             input_ids=input_ids,

From 70490e5fced88d5addd149a47a0e775388fda30d Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 22:40:49 +0800
Subject: [PATCH 05/13] update

---
 src/mcore_bridge/bridge/gpt_bridge.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 15bc291..5af530d 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1243,10 +1243,10 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
         value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
         if to_mcore:
             if isinstance(mg_attn.in_proj_qkvz, LoraParallelLinear):
-                lora_A = hf_state_dict['in_proj_qkvz.lora_A.weight'].load()
+                lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
                 assert (lora_A == hf_state_dict['in_proj_z.lora_A.weight'].load()).all(), \
                     'Need to ensure QKVZ\'s lora_A are consistent'
-                qkv_lora_B = hf_state_dict['in_proj_qkvz.lora_B.weight'].load()
+                qkv_lora_B = hf_state_dict['in_proj_qkv.lora_B.weight'].load()
                 q_lora_B, k_lora_B, v_lora_B = torch.split(
                     qkv_lora_B, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 lora_B = torch.cat([

From e41ff5365d3b71653bec12526284e56af1389fdd Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 22:42:52 +0800
Subject: [PATCH 06/13] fix

---
 src/mcore_bridge/model/modules/gated_delta_net.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index 3f7d10e..4470c31 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -104,7 +104,7 @@ def __init__(self, config: TransformerConfig, submodules: 'GatedDeltaNetSubmodul
             bias=self.bias,
             skip_bias_add=False,
             is_expert=False,
-            tp_comm_buffer_name='fc1_qkvz',
+            tp_comm_buffer_name='fc1',
             tp_group=self.pg_collection.tp,
         )

From b804add71f653f17a2917d33dc61b9b941855bcf Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 23:06:57 +0800
Subject: [PATCH 07/13] fix

---
 src/mcore_bridge/bridge/gpt_bridge.py | 11 ++++---
 .../model/modules/gated_delta_net.py  | 33 +++++++++++--------
 2 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 7b6bd27..2822eac 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1272,7 +1272,9 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                 if 'in_proj_qkv.weight_scale_inv' in hf_state_dict:
                     qkv_scale_inv = hf_state_dict['in_proj_qkv.weight_scale_inv'].load()
                     q_si, k_si, v_si = torch.split(
-                        qkv_scale_inv, [x * num_key_heads // 128 for x in [key_dim, key_dim, value_dim]], dim=0)
+                        qkv_scale_inv,
+                        [x * num_key_heads // self.fp8_block_size for x in [key_dim, key_dim, value_dim]],
+                        dim=0)
                     in_scale_inv = torch.cat([
                         *(x.reshape(num_key_heads, -1, hidden_size_block) for x in [q_si, k_si, v_si]),
                         hf_state_dict['in_proj_z.weight_scale_inv'].load().reshape(num_key_heads, -1,
                                                                                    hidden_size_block),
                     ],
                                              dim=1).reshape((-1, hidden_size_block))
                 self._set_weight(
                     mg_attn.in_proj_qkvz.weight, in_proj_weight, 'in_proj_qkvz.weight', hf_scale_inv=in_scale_inv)
         else:
             qkv_dim = key_dim * 2 + value_dim
-            is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj,
+            is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj_qkvz,
                                                                LoraParallelLinear) and self._peft_format
             is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
             if self.pp_size > 1:
@@ -1323,8 +1325,9 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                         q = scale_inv[:, :key_block].reshape(-1, hidden_size_block)
                         k = scale_inv[:, key_block:2 * key_block].reshape(-1, hidden_size_block)
                         v = scale_inv[:, 2 * key_block:qkv_block].reshape(-1, hidden_size_block)
-                        hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
-                        hf_state_dict['in_proj_z.weight'] = scale_inv[:, qkv_block:].reshape(-1, hidden_size_block).clone()
+                        hf_state_dict['in_proj_qkv.weight_scale_inv'] = torch.concat([q, k, v], dim=0)
+                        hf_state_dict['in_proj_z.weight_scale_inv'] = scale_inv[:, qkv_block:].reshape(
+                            -1, hidden_size_block).clone()
         if to_mcore:
             if isinstance(mg_attn.in_proj_ba, LoraParallelLinear):
                 lora_A = hf_state_dict['in_proj_b.lora_A.weight'].load()

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index 4470c31..b48f8d9 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -8,7 +8,7 @@
 from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
 from typing import List, Optional
-
+import transformer_engine
 try:
     from fla.modules.convolution import causal_conv1d
     from fla.modules.l2norm import l2norm
@@ -107,19 +107,24 @@ def __init__(self, config: TransformerConfig, submodules: 'GatedDeltaNetSubmodul
             tp_comm_buffer_name='fc1',
             tp_group=self.pg_collection.tp,
         )
-        self.in_proj_ba = build_module(
-            submodules.in_proj,
-            self.hidden_size,
-            self.in_proj_ba_dim,
-            config=self.config,
-            init_method=self.config.init_method,
-            gather_output=False,
-            bias=self.bias,
-            skip_bias_add=False,
-            is_expert=False,
-            tp_comm_buffer_name='fc1_ba',
-            tp_group=self.pg_collection.tp,
-        )
+        if config.fp8_param:
+            fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False)
+        else:
+            fp8_context = nullcontext()
+        with fp8_context:
+            self.in_proj_ba = build_module(
+                submodules.in_proj,
+                self.hidden_size,
+                self.in_proj_ba_dim,
+                config=self.config,
+                init_method=self.config.init_method,
+                gather_output=False,
+                bias=self.bias,
+                skip_bias_add=False,
+                is_expert=False,
+                tp_comm_buffer_name='fc1_ba',
+                tp_group=self.pg_collection.tp,
+            )
 
     def forward(
         self,

From 2a03a117c01beb07554629eca25d4e362669ed1b Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 23:12:25 +0800
Subject: [PATCH 08/13] fix

---
 src/mcore_bridge/bridge/gpt_bridge.py             | 1 +
 src/mcore_bridge/model/modules/gated_delta_net.py | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 2822eac..63e809d 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1375,6 +1375,7 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                 in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj_ba.weight.data,
                                                      'in_proj_ba.weight')
                 if in_proj_weight is not None:
+                    in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
                     hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, :-a_dim].reshape(-1,
                                                                                            config.hidden_size).clone()
                     hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
                                                                                            config.hidden_size).clone()
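The FP8 handling added in patches 04 and 07 has to keep the per-block `weight_scale_inv` tensors in sync whenever weight rows are split or re-interleaved: with block-wise quantization there is one scale per `fp8_block_size x fp8_block_size` tile, so any split size on the weight becomes `size // fp8_block_size` on the scales. A small sketch of that bookkeeping, assuming 128x128 blocks as the patch's arithmetic suggests (toy shapes, plain tensors instead of the bridge's lazy loaders):

```python
import torch

block = 128                                # assumed fp8_block_size
rows, cols = 512, 256                      # weight shape, multiples of the block size
weight = torch.randn(rows, cols)
scale_inv = torch.rand(rows // block, cols // block)   # one scale per 128x128 tile -> 4 x 2

# Splitting the weight along dim 0 implies splitting scale_inv with sizes // block.
w_a, w_b = torch.split(weight, [384, 128], dim=0)
s_a, s_b = torch.split(scale_inv, [384 // block, 128 // block], dim=0)
print(w_a.shape, s_a.shape)                # (384, 256) and (3, 2)
```

Patch 07 additionally fixes the export path so the re-split scales land under `*.weight_scale_inv` keys rather than overwriting the `*.weight` entries.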
diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index b48f8d9..cb166ae 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -1,6 +1,7 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import torch
 import torch.nn.functional as F
+import transformer_engine
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer import TransformerConfig
@@ -8,7 +9,7 @@
 from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
 from typing import List, Optional
-import transformer_engine
+
 try:
     from fla.modules.convolution import causal_conv1d
     from fla.modules.l2norm import l2norm

From a760a2ef8ae92d6e2326d15db1ea699896d8c306 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 15 Apr 2026 17:17:42 +0800
Subject: [PATCH 09/13] fix

---
 src/mcore_bridge/model/modules/gated_delta_net.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index cb166ae..4255e9c 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -178,7 +178,13 @@ def forward(
         # Input projection
         nvtx_range_push(suffix='in_proj')
         qkvz, _ = self.in_proj_qkvz(hidden_states)
-        ba, _ = self.in_proj_ba(hidden_states)
+        if self.config.fp8_param:
+            fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False)
+        else:
+            fp8_context = nullcontext()
+        fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False)
+        with fp8_context:
+            ba, _ = self.in_proj_ba(hidden_states)
         num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
         qkvz = qkvz.view(qkvz.shape[:-1] + (num_key_heads_per_device, qkvz.shape[-1] // num_key_heads_per_device))
         ba = ba.view(ba.shape[:-1] + (num_key_heads_per_device, ba.shape[-1] // num_key_heads_per_device))

From 164675c8195f436008d6172c4fe76d968bb98e9a Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 15 Apr 2026 22:46:49 +0800
Subject: [PATCH 10/13] fix

---
 src/mcore_bridge/model/modules/gated_delta_net.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index 4255e9c..da1bbbb 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -2,6 +2,7 @@
 import torch
 import torch.nn.functional as F
 import transformer_engine
+from contextlib import nullcontext
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer import TransformerConfig
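Patches 07 through 10 keep the very narrow `in_proj_ba` projection out of FP8: it is built under `fp8_model_init(enabled=False)` when FP8 parameters are on, and its forward runs under `fp8_autocast(enabled=False)`, presumably because a projection with only `2 * num_value_heads` outputs does not meet FP8 GEMM alignment requirements. A stripped-down sketch of that pattern with TransformerEngine (toy sizes, requires a CUDA build of TE; `use_fp8_params` is a stand-in for `config.fp8_param`):

```python
from contextlib import nullcontext

import torch
import transformer_engine.pytorch as te

use_fp8_params = True  # stand-in for config.fp8_param

# Build the small projection with FP8 weights disabled (no-op context if FP8 params are off anyway).
init_ctx = te.fp8_model_init(enabled=False) if use_fp8_params else nullcontext()
with init_ctx:
    in_proj_ba = te.Linear(1024, 16, bias=False, device='cuda')  # tiny b/a projection, toy sizes

x = torch.randn(8, 1024, device='cuda')
# Run it outside FP8 as well, matching the fp8_autocast(enabled=False) wrapper in the patches.
with te.fp8_autocast(enabled=False):
    ba = in_proj_ba(x)
print(ba.shape)  # torch.Size([8, 16])
```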
From 9f66381d326ca68333f59bfbd272a74be61104a2 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 15 Apr 2026 23:46:35 +0800
Subject: [PATCH 11/13] update

---
 src/mcore_bridge/bridge/gpt_bridge.py         | 114 +++++++++++++++++-
 src/mcore_bridge/config/model_config.py       |   1 +
 src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py |  13 +-
 .../model/modules/gated_delta_net.py          |  36 +++---
 4 files changed, 144 insertions(+), 20 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 63e809d..b403d90 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -15,6 +15,7 @@
 from transformers.utils import ContextManagers
 from typing import Callable, List, Optional, Union
 
+from mcore_bridge.config import ModelConfig
 from mcore_bridge.tuners import LoraParallelLinear
 from mcore_bridge.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr,
                                 gc_collect, get_logger, is_master, unwrap_model)
@@ -45,7 +46,7 @@ class GPTBridge:
     hf_shared_expert_key = None
     hf_expert_bias_key = 'gate.e_score_correction_bias'
 
-    def __init__(self, config):
+    def __init__(self, config: ModelConfig):
         self.config = config
         self._disable_tqdm = False
         self._target_device = None
@@ -114,6 +115,7 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
         dim0_keys = {
             'word_embeddings',
             'linear_qkv',
+            'in_proj',
             'in_proj_qkvz',
             'in_proj_ba',
             'conv1d',
@@ -1232,7 +1234,7 @@ def _set_indexer(self, mg_indexer, hf_state_dict, hf_prefix: str, to_mcore: bool
             hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
-    def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
+    def _set_linear_decoupled_in_proj(self, mg_attn, hf_state_dict, hf_prefix: str, to_mcore: bool):
         if to_mcore:
             hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
         else:
@@ -1380,6 +1382,114 @@ def _set_linear_decoupled_in_proj(self, mg_attn, hf_state_dict, hf_prefix: str,
                                                                                            config.hidden_size).clone()
                     hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
                                                                                            config.hidden_size).clone()
+        if to_mcore:
+            hf_state_dict = {}
+        else:
+            hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+        return hf_state_dict
+
+    def _set_linear_in_proj(self, mg_attn, hf_state_dict, hf_prefix: str, to_mcore: bool):
+        if to_mcore:
+            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+        else:
+            hf_state_dict = {}
+        config = self.config
+        num_key_heads = config.linear_num_key_heads
+        key_dim = config.linear_key_head_dim
+        value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
+        if to_mcore:
+            if isinstance(mg_attn.in_proj, LoraParallelLinear):
+                lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
+                assert (lora_A == hf_state_dict['in_proj_z.lora_A.weight'].load()).all() and \
+                    (lora_A == hf_state_dict['in_proj_b.lora_A.weight'].load()).all() and \
+                    (lora_A == hf_state_dict['in_proj_a.lora_A.weight'].load()).all(), \
+                    'Need to ensure QKVZBA\'s lora_A are consistent'
+                qkv_lora_B = hf_state_dict['in_proj_qkv.lora_B.weight'].load()
+                q_lora_B, k_lora_B, v_lora_B = torch.split(
+                    qkv_lora_B, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
+                lora_B = torch.cat([
+                    *(x.reshape(num_key_heads, -1, qkv_lora_B.shape[-1]) for x in [q_lora_B, k_lora_B, v_lora_B]),
+                    *(hf_state_dict[f'{key}.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1])
+                      for key in ['in_proj_z', 'in_proj_b', 'in_proj_a'])
+                ],
+                                   dim=1).reshape(-1, qkv_lora_B.shape[-1])
+                self._set_weight(mg_attn.in_proj.lora_A[self._adapter_name].weight, lora_A, 'in_proj.lora_A.weight')
+                self._set_weight(mg_attn.in_proj.lora_B[self._adapter_name].weight, lora_B, 'in_proj.lora_B.weight')
+            elif not self._peft_format:
+                qkv = hf_state_dict['in_proj_qkv.weight'].load()
+                q, k, v = torch.split(
+                    qkv, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
+                in_proj_weight = torch.cat([
+                    *(x.reshape(num_key_heads, -1, config.hidden_size) for x in [q, k, v]),
+                    *(hf_state_dict[f'{key}.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
+                      for key in ['in_proj_z', 'in_proj_b', 'in_proj_a']),
+                ],
+                                           dim=1).reshape((-1, config.hidden_size))
+                self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight')
+        else:
+            qkv_dim = key_dim * 2 + value_dim
+            z_dim = value_dim
+            a_dim = config.linear_num_value_heads // num_key_heads
+            is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj,
+                                                               LoraParallelLinear) and self._peft_format
+            is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
+            if self.pp_size > 1:
+                dist.all_reduce(is_lora, group=self.pp_group)
+            if is_lora:
+                lora_A, _ = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj.lora_A[self._adapter_name].weight.data,
+                    f'in_proj.lora_A.{self._adapter_name}.weight')
+                lora_B, _ = self._get_weight(
+                    None if mg_attn is None else mg_attn.in_proj.lora_B[self._adapter_name].weight.data,
+                    f'in_proj.lora_B.{self._adapter_name}.weight')
+                if lora_A is not None:
+                    lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
+                    self._peft_target_modules.update({'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a'})
+                    for key in ['in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a']:
+                        hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
+                    q_lora_B = lora_B[:, :key_dim].reshape(-1, lora_B.shape[-1])
+                    k_lora_B = lora_B[:, key_dim:2 * key_dim].reshape(-1, lora_B.shape[-1])
+                    v_lora_B = lora_B[:, 2 * key_dim:qkv_dim].reshape(-1, lora_B.shape[-1])
+                    hf_state_dict['in_proj_qkv.lora_B.weight'] = torch.concat([q_lora_B, k_lora_B, v_lora_B], dim=0)
+                    hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:qkv_dim + z_dim].reshape(
+                        -1, lora_B.shape[-1]).clone()
+                    hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[:, qkv_dim + z_dim:-a_dim].reshape(
+                        -1, lora_B.shape[-1]).clone()
+                    hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[:, -a_dim:].reshape(-1, lora_B.shape[-1]).clone()
+            elif not self._peft_format:
+                in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj.weight.data,
+                                                     'in_proj.weight')
+                if in_proj_weight is not None:
+                    in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
+                    q = in_proj_weight[:, :key_dim].reshape(-1, config.hidden_size)
+                    k = in_proj_weight[:, key_dim:2 * key_dim].reshape(-1, config.hidden_size)
+                    v = in_proj_weight[:, 2 * key_dim:qkv_dim].reshape(-1, config.hidden_size)
+                    hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
+                    hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim:(qkv_dim + z_dim)].reshape(
+                        -1, config.hidden_size).clone()
+                    hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, (qkv_dim + z_dim):-a_dim].reshape(
+                        -1, config.hidden_size).clone()
+                    hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1, config.hidden_size).clone()
+        if to_mcore:
+            hf_state_dict = {}
+        else:
+            hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+        return hf_state_dict
+
+    def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
+        if to_mcore:
+            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+        else:
+            hf_state_dict = {}
+        config = self.config
+        num_key_heads = config.linear_num_key_heads
+        key_dim = config.linear_key_head_dim
+        value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
+        if config.linear_decoupled_in_proj:
+            hf_state_dict.update(self._set_linear_decoupled_in_proj(mg_attn, hf_state_dict, hf_prefix, to_mcore))
+        else:
+            hf_state_dict.update(self._set_linear_in_proj(mg_attn, hf_state_dict, hf_prefix, to_mcore))
         if not self._peft_format:
             if to_mcore:
                 conv1d = hf_state_dict['conv1d.weight'].load()

diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py
index 253d12b..10db680 100644
--- a/src/mcore_bridge/config/model_config.py
+++ b/src/mcore_bridge/config/model_config.py
@@ -193,6 +193,7 @@ class ModelConfig(TransformerConfig):
     linear_conv_kernel_dim: Optional[int] = None
     layernorm_zero_centered_gamma: bool = False
     attention_output_gate: bool = False
+    linear_decoupled_in_proj: bool = False
     # dsa
     experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa']] = None

diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py b/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
index dfb8113..d3c40a8 100644
--- a/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
+++ b/src/mcore_bridge/model/mm_gpts/qwen3_5_gdn.py
@@ -21,7 +21,13 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
         if is_linear_attention:
             hf_state_dict.update(
                 self._set_linear_attn_state(mg_attn, hf_state_dict, 'linear_attn.', layer_idx, to_mcore))
-            self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore)
+
+            if self.config.linear_decoupled_in_proj:
+                self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight',
+                                     to_mcore)
+            else:
+                self._set_state_dict(mg_layer, 'self_attention.in_proj.layer_norm_weight', hf_state_dict,
+                                     'input_layernorm.weight', to_mcore)
         else:
             hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore))
             self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict,
@@ -43,9 +49,10 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
         if issubclass(attn_module, SelfAttention):
             layer_spec.submodules.self_attention.module = GatedSelfAttention
         else:
-            layer_spec.submodules.input_layernorm = TENorm
             layer_spec.submodules.self_attention.module = GatedDeltaNet
+            if self.config.linear_decoupled_in_proj:
+                layer_spec.submodules.input_layernorm = TENorm
+                layer_spec.submodules.self_attention.submodules.in_proj = TEColumnParallelLinear
     return layer_specs
 
 def build_model(

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index da1bbbb..79fc08f 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -5,12 +5,13 @@
 from contextlib import nullcontext
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
 from typing import List, Optional
 
+from mcore_bridge.config import ModelConfig
+
 try:
     from fla.modules.convolution import causal_conv1d
     from fla.modules.l2norm import l2norm
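Patch 11 gates the whole decoupled-projection path behind a new `linear_decoupled_in_proj` flag on `ModelConfig`: when it is set, the layer spec uses a separate `TENorm` input layernorm plus a plain `TEColumnParallelLinear` in_proj and the module builds `in_proj_qkvz`/`in_proj_ba`; otherwise the original fused `in_proj` (with the layernorm folded into it) is kept. The following toy module is only an illustration of that switch and of the per-key-head re-interleaving done in the forward pass; it uses `nn.Linear` and invented sizes rather than the repo's tensor-parallel layers:

```python
import torch
import torch.nn as nn


class ToyGDNInProj(nn.Module):
    """Illustration of the two in_proj layouts selected by linear_decoupled_in_proj."""

    def __init__(self, hidden: int, qkvz_dim: int, ba_dim: int, decoupled: bool):
        super().__init__()
        self.decoupled = decoupled
        if decoupled:
            self.in_proj_qkvz = nn.Linear(hidden, qkvz_dim, bias=False)
            self.in_proj_ba = nn.Linear(hidden, ba_dim, bias=False)
        else:
            self.in_proj = nn.Linear(hidden, qkvz_dim + ba_dim, bias=False)

    def forward(self, x: torch.Tensor, num_key_heads: int) -> torch.Tensor:
        if self.decoupled:
            qkvz, ba = self.in_proj_qkvz(x), self.in_proj_ba(x)
            # Re-interleave per key head so downstream splitting sees the fused layout.
            qkvz = qkvz.view(*qkvz.shape[:-1], num_key_heads, -1)
            ba = ba.view(*ba.shape[:-1], num_key_heads, -1)
            return torch.cat([qkvz, ba], dim=-1).flatten(-2)
        return self.in_proj(x)


m = ToyGDNInProj(hidden=32, qkvz_dim=96, ba_dim=8, decoupled=True)
print(m(torch.randn(5, 2, 32), num_key_heads=4).shape)  # torch.Size([5, 2, 104])
```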
@@ -88,10 +89,13 @@ class GatedDeltaNet(_GatedDeltaNet):
 
-    def __init__(self, config: TransformerConfig, submodules: 'GatedDeltaNetSubmodules', *args, **kwargs):
-        in_proj = submodules.in_proj
-        submodules.in_proj = IdentityOp
+    def __init__(self, config: ModelConfig, submodules: 'GatedDeltaNetSubmodules', *args, **kwargs):
+        if config.linear_decoupled_in_proj:
+            in_proj = submodules.in_proj
+            submodules.in_proj = IdentityOp
         super().__init__(config, submodules, *args, **kwargs)
+        if not config.linear_decoupled_in_proj:
+            return
         submodules.in_proj = in_proj
         self.in_proj_qkvz_dim = self.qk_dim * 2 + self.v_dim * 2
         self.in_proj_ba_dim = self.num_value_heads * 2
@@ -178,18 +182,20 @@ def forward(
         cu_seqlens = None if packed_seq_params is None else packed_seq_params.cu_seqlens_q
         # Input projection
         nvtx_range_push(suffix='in_proj')
-        qkvz, _ = self.in_proj_qkvz(hidden_states)
-        if self.config.fp8_param:
-            fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False)
+        if self.config.linear_decoupled_in_proj:
+            qkvz, _ = self.in_proj_qkvz(hidden_states)
+            if self.config.fp8_param:
+                fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False)
+            else:
+                fp8_context = nullcontext()
+            with fp8_context:
+                ba, _ = self.in_proj_ba(hidden_states)
+            num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
+            qkvz = qkvz.view(qkvz.shape[:-1] + (num_key_heads_per_device, qkvz.shape[-1] // num_key_heads_per_device))
+            ba = ba.view(ba.shape[:-1] + (num_key_heads_per_device, ba.shape[-1] // num_key_heads_per_device))
+            qkvzba = torch.concat([qkvz, ba], dim=-1).view(*qkvz.shape[:2], -1)
         else:
-            fp8_context = nullcontext()
-        fp8_context = transformer_engine.pytorch.fp8_autocast(enabled=False)
-        with fp8_context:
-            ba, _ = self.in_proj_ba(hidden_states)
-        num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
-        qkvz = qkvz.view(qkvz.shape[:-1] + (num_key_heads_per_device, qkvz.shape[-1] // num_key_heads_per_device))
-        ba = ba.view(ba.shape[:-1] + (num_key_heads_per_device, ba.shape[-1] // num_key_heads_per_device))
-        qkvzba = torch.concat([qkvz, ba], dim=-1).view(*qkvz.shape[:2], -1)
+            qkvzba, _ = self.in_proj(hidden_states)
         nvtx_range_pop(suffix='in_proj')
 
         if cp_size > 1:

From 1a786dfbf1704c609f2f0e1da80f2edaa6948249 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 16 Apr 2026 00:06:24 +0800
Subject: [PATCH 12/13] fix

---
 src/mcore_bridge/bridge/gpt_bridge.py | 24 ++++---------------
 .../model/modules/gated_delta_net.py  |  2 +-
 2 files changed, 5 insertions(+), 21 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index b403d90..bd9118c 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1234,11 +1234,7 @@ def _set_indexer(self, mg_indexer, hf_state_dict, hf_prefix: str, to_mcore: bool
             hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
-    def _set_linear_decoupled_in_proj(self, mg_attn, hf_state_dict, hf_prefix: str, to_mcore: bool):
-        if to_mcore:
-            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
-        else:
-            hf_state_dict = {}
+    def _set_linear_decoupled_in_proj(self, mg_attn, hf_state_dict, to_mcore: bool):
         config = self.config
         num_key_heads = config.linear_num_key_heads
         key_dim = config.linear_key_head_dim
@@ -1382,17 +1378,9 @@ def _set_linear_decoupled_in_proj(self, mg_attn, hf_state_dict, hf_prefix: str,
                                                                                            config.hidden_size).clone()
                     hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
                                                                                            config.hidden_size).clone()
-        if to_mcore:
-            hf_state_dict = {}
-        else:
-            hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
-    def _set_linear_in_proj(self, mg_attn, hf_state_dict, hf_prefix: str, to_mcore: bool):
-        if to_mcore:
-            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
-        else:
-            hf_state_dict = {}
+    def _set_linear_in_proj(self, mg_attn, hf_state_dict, to_mcore: bool):
         config = self.config
         num_key_heads = config.linear_num_key_heads
         key_dim = config.linear_key_head_dim
@@ -1471,10 +1459,6 @@ def _set_linear_in_proj(self, mg_attn, hf_state_dict, hf_prefix: str, to_mcore:
                         -1, config.hidden_size).clone()
                     hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1, config.hidden_size).clone()
-        if to_mcore:
-            hf_state_dict = {}
-        else:
-            hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
         return hf_state_dict
 
     def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
@@ -1487,9 +1471,9 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
         key_dim = config.linear_key_head_dim
         value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
         if config.linear_decoupled_in_proj:
-            hf_state_dict.update(self._set_linear_decoupled_in_proj(mg_attn, hf_state_dict, hf_prefix, to_mcore))
+            hf_state_dict.update(self._set_linear_decoupled_in_proj(mg_attn, hf_state_dict, to_mcore))
         else:
-            hf_state_dict.update(self._set_linear_in_proj(mg_attn, hf_state_dict, hf_prefix, to_mcore))
+            hf_state_dict.update(self._set_linear_in_proj(mg_attn, hf_state_dict, to_mcore))
         if not self._peft_format:
             if to_mcore:
                 conv1d = hf_state_dict['conv1d.weight'].load()

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index 79fc08f..9d372d0 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -181,6 +181,7 @@ def forward(
         cu_seqlens = None if packed_seq_params is None else packed_seq_params.cu_seqlens_q
         # Input projection
+        num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
         nvtx_range_push(suffix='in_proj')
         if self.config.linear_decoupled_in_proj:
             qkvz, _ = self.in_proj_qkvz(hidden_states)
@@ -190,7 +191,6 @@ def forward(
                 fp8_context = nullcontext()
             with fp8_context:
                 ba, _ = self.in_proj_ba(hidden_states)
-            num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
             qkvz = qkvz.view(qkvz.shape[:-1] + (num_key_heads_per_device, qkvz.shape[-1] // num_key_heads_per_device))
             ba = ba.view(ba.shape[:-1] + (num_key_heads_per_device, ba.shape[-1] // num_key_heads_per_device))
             qkvzba = torch.concat([qkvz, ba], dim=-1).view(*qkvz.shape[:2], -1)

From 0c1a840f6c1c9c0e9292ba6ae49586273e2e0774 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 16 Apr 2026 00:27:56 +0800
Subject: [PATCH 13/13] fix

---
 src/mcore_bridge/bridge/gpt_bridge.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index bd9118c..8864739 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -708,9 +708,9 @@ def _set_moe_state(
 
     def _get_hf_experts_attr(self, is_mtp: bool = False):
         # return hf_grouped, is_gate_up
-        if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe'} or self.llm_model_type in {
+        if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe', 'qwen3_5_moe'} or self.llm_model_type in {
                 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'dots1', 'ernie4_5_moe', 'glm4_moe',
-                'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'qwen3_5_moe', 'glm_moe_dsa', 'deepseek_v32'
+                'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32'
         }:
             return False, False
         elif self.model_type in {'qwen3_vl_moe', 'llama4'} or self.llm_model_type in {'gpt_oss'}:
@@ -1757,7 +1757,8 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
             hf_state_dict = {}
         self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict)
         transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer
-        if not to_mcore and not self.llm_model_type == 'qwen3_next':
+        # TODO: check
+        if not to_mcore and self.llm_model_type in {'deepseek_v3', 'deepseek_v32', 'glm4_moe', 'glm4_moe_lite'}:
             self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight',
                                  to_mcore)
             if self.config.untie_embeddings_and_output_weights: