diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py index 253d12b..d032b7b 100644 --- a/src/mcore_bridge/config/model_config.py +++ b/src/mcore_bridge/config/model_config.py @@ -203,6 +203,10 @@ class ModelConfig(TransformerConfig): dsa_indexer_use_sparse_loss: bool = False dsa_indexer_rotary_interleaved: bool = False + # mtp + mtp_unroll_steps: Optional[int] = None + decoder_input_detach: bool = True + # visual hf_config: Optional[PretrainedConfig] = None vit_attn_impl: Optional[str] = None # e.g. 'flash_attention_2' diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 84f2d7a..a083133 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -47,6 +47,7 @@ 'linear_key_head_dim': ['linear_key_head_dim'], 'linear_value_head_dim': ['linear_value_head_dim'], 'linear_conv_kernel_dim': ['linear_conv_kernel_dim'], + 'mtp_unroll_steps': ['mtp_unroll_steps'], # dsa 'dsa_indexer_n_heads': ['index_n_heads'], 'dsa_indexer_head_dim': ['index_head_dim'], diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index bf10c03..1081526 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -404,8 +404,10 @@ def _postprocess( input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1) if self.mtp_process and labels is not None: + mtp_depth = getattr(self.config, 'mtp_unroll_steps', None) or self.config.mtp_num_layers if self.config.is_multimodal: - embedding_ = (self.embedding, decoder_input) + _decoder_input = decoder_input.detach() if self.config.decoder_input_detach else decoder_input + embedding_ = (self.embedding, _decoder_input) else: embedding_ = self.embedding hidden_states = self.mtp( @@ -423,12 +425,12 @@ def _postprocess( **(extra_block_kwargs or {}), ) mtp_labels = labels.clone() - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states_list = torch.chunk(hidden_states, 1 + mtp_depth, dim=0) hidden_states = hidden_states_list[0] if loss_mask is None: # if loss_mask is not provided, use all ones as loss_mask loss_mask = torch.ones_like(mtp_labels) - for mtp_layer_number in range(self.config.mtp_num_layers): + for mtp_layer_number in range(mtp_depth): # output mtp_logits, _ = self.output_layer( hidden_states_list[mtp_layer_number + 1], @@ -460,10 +462,10 @@ def _postprocess( MTPLossLoggingHelper.save_loss_to_tracker( mtp_loss_for_log, mtp_layer_number, - self.config.mtp_num_layers, + mtp_depth, avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), ) - mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers + mtp_loss_scale = self.config.mtp_loss_scaling_factor / mtp_depth if self.config.calculate_per_token_loss: hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) else: diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py index b9d184c..3763e59 100644 --- a/src/mcore_bridge/patcher.py +++ b/src/mcore_bridge/patcher.py @@ -16,7 +16,8 @@ scatter_to_sequence_parallel_region) from megatron.core.transformer import TransformerLayer from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention -from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer +from megatron.core.transformer.multi_token_prediction import (MultiTokenPredictionBlock, MultiTokenPredictionLayer, + get_mtp_layer_offset) from megatron.core.utils import deprecate_inference_params from packaging import version from peft.tuners.tuners_utils import BaseTuner @@ -394,6 +395,7 @@ def forward( packed_seq_params: PackedSeqParams = None, sequence_len_offset: torch.Tensor = None, embedding=None, + depth_idx: int = None, ): """ Execute the forward pass through the Multi-Token Prediction (MTP) layer. @@ -417,7 +419,9 @@ def forward( Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used. """ - # TODO: Multimodal compatible + # current unroll depth + effective_depth = self.layer_number if depth_idx is None else depth_idx + assert context is None, 'multi token prediction + cross attention is not yet supported.' input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( input_ids=input_ids, @@ -425,6 +429,7 @@ def forward( embedding=embedding, packed_seq_params=packed_seq_params, hidden_states=hidden_states, + depth=effective_depth, ) assert not self.transformer_layer.self_attention.config.apply_rope_fusion packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' @@ -433,7 +438,7 @@ def forward( rotary_pos_emb = rotary_pos_emb[position_ids[0]] else: # mrope or not packed_seq - rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0) + rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0) if self.config.recompute_granularity == 'full' and self.training: hidden_states = self._checkpointed_forward( partial( @@ -471,6 +476,57 @@ def forward( MultiTokenPredictionLayer.forward = forward + def block_forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + context: torch.Tensor = None, + context_mask: torch.Tensor = None, + rotary_pos_emb: torch.Tensor = None, + rotary_pos_cos: torch.Tensor = None, + rotary_pos_sin: torch.Tensor = None, + attention_bias: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + sequence_len_offset: torch.Tensor = None, + extra_block_kwargs: Optional[dict] = None, + embedding=None, + ): + """Perform the forward pass through all MTP modules with optional layer reuse.""" + offset = get_mtp_layer_offset(self.config, self.vp_stage) + hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0)) + hidden_states = hidden_states_list[offset] + + physical_num_layers = len(self.layers) + unroll_steps = getattr(self.config, 'mtp_unroll_steps', None) or self.config.mtp_num_layers + + for step in range(unroll_steps): + layer = self.layers[step % physical_num_layers] + global_depth = offset + step + 1 + hidden_states, input_ids, position_ids = layer( + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + embedding=embedding, + depth_idx=global_depth, + **(extra_block_kwargs or {}), + ) + hidden_states_list.append(hidden_states) + + hidden_states = torch.cat(hidden_states_list, dim=0) + return hidden_states + + MultiTokenPredictionBlock.forward = block_forward + def _get_embeddings( self, input_ids: torch.Tensor, @@ -478,6 +534,7 @@ def _get_embeddings( embedding: Callable, hidden_states: torch.Tensor, packed_seq_params: Optional[PackedSeqParams] = None, + depth: int = 1, ): from megatron.core.transformer.multi_token_prediction import roll_tensor from megatron.core.utils import make_viewless_tensor @@ -508,13 +565,17 @@ def _get_embeddings( enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1 if enable_sp: decoder_input = gather_from_sequence_parallel_region(decoder_input) - decoder_input, _ = roll_tensor( - decoder_input.transpose(0, 2), - shifts=-1, - dims=-1, - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) + decoder_input = decoder_input.transpose(0, 2) + # Megatron's roll_tensor is implemented around single-token left shifts, especially + # for packed sequences / CP, so apply depth as repeated -1 rolls instead of -depth. + for _ in range(depth): + decoder_input, _ = roll_tensor( + decoder_input, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) decoder_input = decoder_input.transpose(0, 2).contiguous() if enable_sp: decoder_input = scatter_to_sequence_parallel_region(decoder_input) diff --git a/tests/test_mtp_reuse.py b/tests/test_mtp_reuse.py new file mode 100644 index 0000000..7c2ea92 --- /dev/null +++ b/tests/test_mtp_reuse.py @@ -0,0 +1,233 @@ +import pytest +import torch +from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionBlock +from types import SimpleNamespace + +import mcore_bridge # noqa: F401 +import mcore_bridge.model.gpt_model as gpt_model_mod +from mcore_bridge.model.gpt_model import GPTModel + + +class RecordingLayer: + + def __init__(self): + self.depth_history = [] + + def __call__( + self, + *, + input_ids, + position_ids, + hidden_states, + attention_mask, + depth_idx=None, + **kwargs, + ): + self.depth_history.append(depth_idx) + return hidden_states + depth_idx, input_ids, position_ids + + +class RecordingOutputLayer: + + def __init__(self): + self.calls = [] + self.sequence_parallel = False + + def __call__(self, hidden_states, weight=None, runtime_gather_output=None): + self.calls.append(hidden_states.clone()) + return hidden_states.squeeze(-1).transpose(0, 1), None + + +def test_mtp_block_reuses_single_physical_layer_across_unroll_steps(): + layer = RecordingLayer() + block = SimpleNamespace( + config=SimpleNamespace( + mtp_num_layers=1, + mtp_unroll_steps=3, + pipeline_model_parallel_size=1, + pipeline_model_parallel_layout=None, + ), + vp_stage=None, + layers=[layer], + ) + + hidden_states = torch.zeros(2, 1, 1) + input_ids = torch.zeros(1, 2, dtype=torch.long) + position_ids = torch.zeros(1, 2, dtype=torch.long) + + output = MultiTokenPredictionBlock.forward( + block, + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + attention_mask=None, + ) + + assert layer.depth_history == [1, 2, 3] + + chunks = torch.chunk(output, 4, dim=0) + assert [chunk[0, 0, 0].item() for chunk in chunks] == [0.0, 1.0, 3.0, 6.0] + + +def test_mtp_block_rolls_multimodal_decoder_input_by_global_depth(monkeypatch): + from megatron.core.transformer import multi_token_prediction as mtp_mod + + monkeypatch.setattr( + mtp_mod, + 'roll_tensor', + lambda tensor, shifts, dims, cp_group=None, packed_seq_params=None: ( + torch.roll(tensor, shifts=shifts, dims=dims), tensor.numel()), + ) + + class RecordingMultimodalLayer: + + def __init__(self): + self.layer_number = 1 + self.cp_group = None + self.training = False + self.decoder_inputs = [] + self.config = SimpleNamespace( + position_embedding_type='rope', + recompute_granularity='selective', + sequence_parallel=False, + tensor_model_parallel_size=1, + ) + self.transformer_layer = SimpleNamespace( + self_attention=SimpleNamespace( + config=SimpleNamespace(apply_rope_fusion=False), + )) + + _get_embeddings = mtp_mod.MultiTokenPredictionLayer._get_embeddings + __call__ = mtp_mod.MultiTokenPredictionLayer.forward + + def _proj_and_transformer_layer( + self, + *, + hidden_states, + decoder_input, + attention_mask, + context, + context_mask, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + attention_bias, + inference_params, + packed_seq_params, + sequence_len_offset, + ): + self.decoder_inputs.append(decoder_input.clone()) + return hidden_states + + layer = RecordingMultimodalLayer() + block = SimpleNamespace( + config=SimpleNamespace( + mtp_num_layers=1, + mtp_unroll_steps=3, + pipeline_model_parallel_size=1, + pipeline_model_parallel_layout=None, + ), + vp_stage=None, + layers=[layer], + ) + + hidden_states = torch.zeros(4, 1, 1) + input_ids = torch.zeros(1, 4, dtype=torch.long) + position_ids = torch.zeros(1, 4, dtype=torch.long) + decoder_input = torch.arange(4, dtype=torch.float).view(4, 1, 1) + rotary_pos_emb = torch.zeros(4, 1, 1) + + MultiTokenPredictionBlock.forward( + block, + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + attention_mask=None, + rotary_pos_emb=rotary_pos_emb, + embedding=(lambda **kwargs: None, decoder_input), + ) + + assert [tensor[0, 0, 0].item() for tensor in layer.decoder_inputs] == [1.0, 2.0, 3.0] + + +def test_postprocess_uses_unroll_steps_for_mtp_loss(monkeypatch): + saved_losses = [] + monkeypatch.setattr( + gpt_model_mod, + 'roll_tensor', + lambda tensor, shifts, dims, cp_group=None, packed_seq_params=None: (tensor, tensor.numel()), + ) + monkeypatch.setattr( + gpt_model_mod.MTPLossAutoScaler, + 'apply', + lambda hidden_states, scaled_loss: hidden_states, + ) + monkeypatch.setattr( + gpt_model_mod.MTPLossLoggingHelper, + 'save_loss_to_tracker', + lambda loss, layer_number, total_layers, avg_group=None: saved_losses.append((layer_number, total_layers)), + ) + monkeypatch.setattr( + gpt_model_mod.parallel_state, + 'get_data_parallel_group', + lambda with_context_parallel=True: None, + ) + monkeypatch.setattr(gpt_model_mod, 'has_config_logger_enabled', lambda config: False) + + output_layer = RecordingOutputLayer() + + def mtp_forward(**kwargs): + hidden_states = kwargs['hidden_states'] + return torch.cat([hidden_states, hidden_states + 1, hidden_states + 2, hidden_states + 3], dim=0) + + model = SimpleNamespace( + post_process=True, + mtp_process=True, + training=True, + share_embeddings_and_output_weights=False, + cp_group=None, + embedding=lambda *args, **kwargs: None, + output_layer=output_layer, + mtp=mtp_forward, + config=SimpleNamespace( + task_type='causal_lm', + is_multimodal=False, + context_parallel_size=1, + mtp_num_layers=1, + mtp_unroll_steps=3, + decoder_input_detach=True, + calculate_per_token_loss=False, + mtp_loss_scaling_factor=0.3, + sequence_parallel=False, + tensor_model_parallel_size=1, + ), + compute_language_model_loss=lambda labels, logits: logits.float(), + ) + + loss = GPTModel._postprocess( + model, + hidden_states=torch.zeros(1, 1, 1), + input_ids=None, + position_ids=None, + labels=torch.ones(1, 1, dtype=torch.long), + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + loss_mask=torch.ones(1, 1, dtype=torch.bool), + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=False, + extra_block_kwargs=None, + inference_context=None, + ) + + assert loss.shape == (1, 1) + assert [call[0, 0, 0].item() for call in output_layer.calls] == [1.0, 2.0, 3.0, 0.0] + assert saved_losses == [(0, 3), (1, 3), (2, 3)] + + +if __name__ == '__main__': + raise SystemExit(pytest.main([__file__, '-q']))