
Support MTP weight reuse with unrolled steps #29

Open
MDR-EX1000 wants to merge 3 commits into modelscope:release/1.1 from MDR-EX1000:mtp-reuse-patch

Conversation

@MDR-EX1000


Background

The Qwen3.5 series already supports mtp-k inference with a single shared MTP layer, but the previous implementation did not support shared-weight mtp-k training. This change adds support for training by repeatedly unrolling one physical MTP layer for k steps, which has been validated to improve the speculative decoding acceptance rate in mtp-k inference.

P.S. A small follow-up change is still needed in the upstream MS-SWIFT repository, and I will provide it separately.

Summary

This PR adds support for MTP weight reuse through logical unrolling.

What Changed

  • Add mtp_unroll_steps to runtime config.
  • Reuse physical MTP layers in MultiTokenPredictionBlock.forward.
  • Pass explicit depth indices so unrolled steps keep the correct depth semantics.
  • Compute MTP loss with unrolled depth in GPT postprocess.
  • Add decoder_input_detach for multimodal MTP decoder input handling.
  • Add focused tests for shared-layer execution and unrolled loss handling.
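The reuse mechanism described in the list above boils down to modulo indexing over the physical layers. A minimal, framework-free sketch (the function name `select_layers` is illustrative, not from the PR):

```python
# Illustrative sketch of MTP weight reuse via logical unrolling: the physical
# layer used at each logical depth is chosen by modulo indexing, so a single
# shared layer can serve every unroll step.

def select_layers(physical_num_layers: int, unroll_steps: int) -> list[int]:
    """Return the physical layer index used at each logical depth."""
    if physical_num_layers == 0:
        return []  # a pipeline stage without MTP layers has nothing to run
    return [step % physical_num_layers for step in range(unroll_steps)]

# One shared layer unrolled three times reuses layer 0 at every step.
print(select_layers(1, 3))  # [0, 0, 0]
# With two physical layers, four unroll steps alternate between them.
print(select_layers(2, 4))  # [0, 1, 0, 1]
```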

Validation

  • Added tests/test_mtp_reuse.py
  • Verified with:

    ```shell
    PYTHONPATH=src:${PYTHONPATH} python -m pytest -q tests/test_mtp_reuse.py
    ```
  • Result: 2 passed

Notes

  • MTP weight reuse is enabled when mtp_num_layers is set to the physical layer count and mtp_unroll_steps is set to the logical unroll depth.
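As a hedged illustration of the note above, the following toy config class is a stand-in (only the attribute names `mtp_num_layers` and `mtp_unroll_steps` come from the PR); the fallback mirrors the `getattr(..., None) or config.mtp_num_layers` pattern visible in `block_forward`:

```python
# Toy stand-in config; only the attribute names come from the PR. When
# mtp_unroll_steps is unset, the logical depth equals the physical layer
# count and no weight reuse occurs.

class MTPConfig:
    def __init__(self, mtp_num_layers=1, mtp_unroll_steps=None):
        self.mtp_num_layers = mtp_num_layers
        self.mtp_unroll_steps = mtp_unroll_steps

def effective_unroll_steps(config) -> int:
    """Logical unroll depth, falling back to the physical layer count."""
    return getattr(config, 'mtp_unroll_steps', None) or config.mtp_num_layers

# Weight reuse enabled: 1 physical layer, unrolled to logical depth 3.
print(effective_unroll_steps(MTPConfig(mtp_num_layers=1, mtp_unroll_steps=3)))  # 3
# Without mtp_unroll_steps, behavior falls back to the physical count.
print(effective_unroll_steps(MTPConfig(mtp_num_layers=2)))  # 2
```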

- add mtp_unroll_steps config plumbing for reused MTP execution
- unroll MultiTokenPredictionBlock with shared physical layers and explicit depth indices
- compute MTP loss using the unrolled depth and keep multimodal decoder input detachable
- add focused tests covering shared-layer execution and unrolled loss handling

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces the ability to unroll Multi-Token Prediction (MTP) layers beyond the number of physical layers using a new mtp_unroll_steps configuration. Key changes include updating GPTModel._postprocess to support variable MTP depths, implementing a block_forward method to handle layer reuse, and adding configuration options for detaching decoder inputs. Review feedback identifies critical issues in the block_forward implementation regarding Pipeline Parallelism compatibility, specifically potential division by zero and incorrect depth calculations. Additionally, a potential crash was noted where rotary_pos_emb is rolled without checking if it is null.

Comment on lines +496 to +524
```python
        offset = get_mtp_layer_offset(self.config, self.vp_stage)
        hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
        hidden_states = hidden_states_list[offset]

        physical_num_layers = len(self.layers)
        unroll_steps = getattr(self.config, 'mtp_unroll_steps', None) or self.config.mtp_num_layers

        for step in range(unroll_steps):
            layer = self.layers[step % physical_num_layers]
            global_depth = offset + step + 1
            hidden_states, input_ids, position_ids = layer(
                input_ids=input_ids,
                position_ids=position_ids,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                embedding=embedding,
                depth_idx=global_depth,
                **(extra_block_kwargs or {}),
            )
            hidden_states_list.append(hidden_states)

        hidden_states = torch.cat(hidden_states_list, dim=0)
        return hidden_states
```

high

The block_forward implementation is not fully compatible with Pipeline Parallelism (PP) when mtp_unroll_steps is used:

  1. ZeroDivisionError: If a PP stage contains no MTP layers (len(self.layers) == 0), line 504 will crash due to step % physical_num_layers. A guard is needed to return early or skip the loop in such stages.
  2. Incorrect Chunking and Offset: torch.chunk(hidden_states, 1 + offset, dim=0) and global_depth = offset + step + 1 rely on the physical layer offset. If previous stages have already performed logical unrolling, the number of chunks in hidden_states and the starting logical depth will be higher than what the physical offset indicates.
  3. Redundant Execution: Every stage with MTP layers will attempt to execute the full unroll_steps. In a PP setup where MTP layers are distributed, this leads to an incorrect total number of logical steps and mismatched chunk counts in _postprocess.

If this feature is primarily intended for the mtp_num_layers=1 case (single shared layer), please add a check for physical_num_layers > 0 and consider documenting the PP limitations.
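The suggested guard could look like the following sketch (toy callables stand in for MTP layers; `unroll_forward` and its argument shapes are hypothetical, not the PR's actual signature):

```python
# Hypothetical guard sketched from the review suggestion: a pipeline stage
# that holds no physical MTP layers returns its input unchanged instead of
# crashing on `step % 0`.

def unroll_forward(layers, hidden_states, unroll_steps):
    physical_num_layers = len(layers)
    if physical_num_layers == 0:
        return hidden_states  # nothing to unroll on this pipeline stage
    for step in range(unroll_steps):
        layer = layers[step % physical_num_layers]  # reuse shared weights
        hidden_states = layer(hidden_states)
    return hidden_states

# Toy layer as a plain callable; a real MTP layer takes many more arguments.
double = lambda h: h * 2
print(unroll_forward([], 5, 3))        # 5 (empty stage, input unchanged)
print(unroll_forward([double], 5, 3))  # 40 (one shared layer applied 3 times)
```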

```diff
         else:
             # mrope or not packed_seq
-            rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0)
+            rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)
```

medium

The torch.roll operation will fail if rotary_pos_emb is None. This occurs in models that do not use RoPE or MRoPE (e.g., models using absolute position embeddings).

```python
            if rotary_pos_emb is not None:
                rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)
```
Author


Good catch. The assumption that rotary_pos_emb is non-None is pre-existing in the base branch; this PR only changes the shift depth from self.layer_number to effective_depth for the unrolled MTP case and does not introduce a new dereference path here. If needed, I can add the None guard separately in a follow-up cleanup.
