Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/mcore_bridge/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ class ModelConfig(TransformerConfig):
dsa_indexer_use_sparse_loss: bool = False
dsa_indexer_rotary_interleaved: bool = False

# mtp
mtp_unroll_steps: Optional[int] = None
decoder_input_detach: bool = True

# visual
hf_config: Optional[PretrainedConfig] = None
vit_attn_impl: Optional[str] = None # e.g. 'flash_attention_2'
Expand Down
1 change: 1 addition & 0 deletions src/mcore_bridge/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
'linear_key_head_dim': ['linear_key_head_dim'],
'linear_value_head_dim': ['linear_value_head_dim'],
'linear_conv_kernel_dim': ['linear_conv_kernel_dim'],
'mtp_unroll_steps': ['mtp_unroll_steps'],
# dsa
'dsa_indexer_n_heads': ['index_n_heads'],
'dsa_indexer_head_dim': ['index_head_dim'],
Expand Down
12 changes: 7 additions & 5 deletions src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,10 @@ def _postprocess(
input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)

if self.mtp_process and labels is not None:
mtp_depth = getattr(self.config, 'mtp_unroll_steps', None) or self.config.mtp_num_layers
if self.config.is_multimodal:
embedding_ = (self.embedding, decoder_input)
_decoder_input = decoder_input.detach() if self.config.decoder_input_detach else decoder_input
embedding_ = (self.embedding, _decoder_input)
else:
embedding_ = self.embedding
hidden_states = self.mtp(
Expand All @@ -423,12 +425,12 @@ def _postprocess(
**(extra_block_kwargs or {}),
)
mtp_labels = labels.clone()
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states_list = torch.chunk(hidden_states, 1 + mtp_depth, dim=0)
hidden_states = hidden_states_list[0]
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_num_layers):
for mtp_layer_number in range(mtp_depth):
# output
mtp_logits, _ = self.output_layer(
hidden_states_list[mtp_layer_number + 1],
Expand Down Expand Up @@ -460,10 +462,10 @@ def _postprocess(
MTPLossLoggingHelper.save_loss_to_tracker(
mtp_loss_for_log,
mtp_layer_number,
self.config.mtp_num_layers,
mtp_depth,
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
mtp_loss_scale = self.config.mtp_loss_scaling_factor / mtp_depth
if self.config.calculate_per_token_loss:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
else:
Expand Down
81 changes: 71 additions & 10 deletions src/mcore_bridge/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
scatter_to_sequence_parallel_region)
from megatron.core.transformer import TransformerLayer
from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention
from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer
from megatron.core.transformer.multi_token_prediction import (MultiTokenPredictionBlock, MultiTokenPredictionLayer,
get_mtp_layer_offset)
from megatron.core.utils import deprecate_inference_params
from packaging import version
from peft.tuners.tuners_utils import BaseTuner
Expand Down Expand Up @@ -394,6 +395,7 @@ def forward(
packed_seq_params: PackedSeqParams = None,
sequence_len_offset: torch.Tensor = None,
embedding=None,
depth_idx: int = None,
):
"""
Execute the forward pass through the Multi-Token Prediction (MTP) layer.
Expand All @@ -417,14 +419,17 @@ def forward(
Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
[s, b, h], and optionally the updated context tensor if cross-attention is used.
"""
# TODO: Multimodal compatible
# current unroll depth
effective_depth = self.layer_number if depth_idx is None else depth_idx

assert context is None, 'multi token prediction + cross attention is not yet supported.'
input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
input_ids=input_ids,
position_ids=position_ids,
embedding=embedding,
packed_seq_params=packed_seq_params,
hidden_states=hidden_states,
depth=effective_depth,
)
assert not self.transformer_layer.self_attention.config.apply_rope_fusion
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
Expand All @@ -433,7 +438,7 @@ def forward(
rotary_pos_emb = rotary_pos_emb[position_ids[0]]
else:
# mrope or not packed_seq
rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0)
rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The torch.roll operation will fail if rotary_pos_emb is None. This occurs in models that do not use RoPE or MRoPE (e.g., models using absolute position embeddings).

            if rotary_pos_emb is not None:
                rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. The assumption that rotary_pos_emb is never None here is pre-existing in the base branch; this PR only changes the shift depth from self.layer_number to effective_depth for the unrolled MTP case, and does not introduce a new dereference path. If needed, I can add the None guard separately in a follow-up cleanup.

if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
partial(
Expand Down Expand Up @@ -471,13 +476,65 @@ def forward(

MultiTokenPredictionLayer.forward = forward

def block_forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    context: torch.Tensor = None,
    context_mask: torch.Tensor = None,
    rotary_pos_emb: torch.Tensor = None,
    rotary_pos_cos: torch.Tensor = None,
    rotary_pos_sin: torch.Tensor = None,
    attention_bias: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    sequence_len_offset: torch.Tensor = None,
    extra_block_kwargs: Optional[dict] = None,
    embedding=None,
):
    """Perform the forward pass through all MTP modules with optional layer reuse.

    When ``config.mtp_unroll_steps`` exceeds the number of physical MTP layers,
    the physical layers are reused cyclically (``step % len(self.layers)``) so
    that, e.g., a single shared MTP layer can be unrolled to several logical
    prediction depths. Each unroll step appends its output to the chunk list;
    the chunks are concatenated along dim 0 for downstream splitting in
    ``_postprocess``.

    NOTE(review): the chunking (``1 + offset``) and ``global_depth`` arithmetic
    are derived from the *physical* layer offset; full pipeline-parallel
    distribution of unrolled MTP layers is not handled here — confirm PP
    placement keeps all MTP layers on one stage when unrolling is enabled.

    Returns:
        torch.Tensor: concatenation of the main-path hidden states and one
        hidden-state chunk per unroll step, along dim 0.
    """
    offset = get_mtp_layer_offset(self.config, self.vp_stage)
    hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
    hidden_states = hidden_states_list[offset]

    physical_num_layers = len(self.layers)
    unroll_steps = getattr(self.config, 'mtp_unroll_steps', None) or self.config.mtp_num_layers

    # Guard: a stage that owns no MTP layers must not attempt to unroll —
    # `step % physical_num_layers` would raise ZeroDivisionError. Pass the
    # hidden states through unchanged.
    if physical_num_layers == 0 or unroll_steps == 0:
        return torch.cat(hidden_states_list, dim=0)

    for step in range(unroll_steps):
        # Reuse physical layers cyclically when unroll_steps > physical layers.
        layer = self.layers[step % physical_num_layers]
        # Logical prediction depth (1-based past the main path), used by the
        # layer to shift embeddings/RoPE by the correct number of tokens.
        global_depth = offset + step + 1
        hidden_states, input_ids, position_ids = layer(
            input_ids=input_ids,
            position_ids=position_ids,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            embedding=embedding,
            depth_idx=global_depth,
            **(extra_block_kwargs or {}),
        )
        hidden_states_list.append(hidden_states)

    hidden_states = torch.cat(hidden_states_list, dim=0)
    return hidden_states
Comment on lines +498 to +526
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The block_forward implementation is not fully compatible with Pipeline Parallelism (PP) when mtp_unroll_steps is used:

  1. ZeroDivisionError: If a PP stage contains no MTP layers (len(self.layers) == 0), line 504 will crash due to step % physical_num_layers. A guard is needed to return early or skip the loop in such stages.
  2. Incorrect Chunking and Offset: torch.chunk(hidden_states, 1 + offset, dim=0) and global_depth = offset + step + 1 rely on the physical layer offset. If previous stages have already performed logical unrolling, the number of chunks in hidden_states and the starting logical depth will be higher than what the physical offset indicates.
  3. Redundant Execution: Every stage with MTP layers will attempt to execute the full unroll_steps. In a PP setup where MTP layers are distributed, this leads to an incorrect total number of logical steps and mismatched chunk counts in _postprocess.

If this feature is primarily intended for the mtp_num_layers=1 case (single shared layer), please add a check for physical_num_layers > 0 and consider documenting the PP limitations.


MultiTokenPredictionBlock.forward = block_forward

def _get_embeddings(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
embedding: Callable,
hidden_states: torch.Tensor,
packed_seq_params: Optional[PackedSeqParams] = None,
depth: int = 1,
):
from megatron.core.transformer.multi_token_prediction import roll_tensor
from megatron.core.utils import make_viewless_tensor
Expand Down Expand Up @@ -508,13 +565,17 @@ def _get_embeddings(
enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
if enable_sp:
decoder_input = gather_from_sequence_parallel_region(decoder_input)
decoder_input, _ = roll_tensor(
decoder_input.transpose(0, 2),
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
decoder_input = decoder_input.transpose(0, 2)
# Megatron's roll_tensor is implemented around single-token left shifts, especially
# for packed sequences / CP, so apply depth as repeated -1 rolls instead of -depth.
for _ in range(depth):
decoder_input, _ = roll_tensor(
decoder_input,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
decoder_input = decoder_input.transpose(0, 2).contiguous()
if enable_sp:
decoder_input = scatter_to_sequence_parallel_region(decoder_input)
Expand Down
Loading