Support MTP weight reuse with unrolled steps#29
MDR-EX1000 wants to merge 3 commits into modelscope:release/1.1
Conversation
- add mtp_unroll_steps config plumbing for reused MTP execution
- unroll MultiTokenPredictionBlock with shared physical layers and explicit depth indices
- compute MTP loss using the unrolled depth and keep multimodal decoder input detachable
- add focused tests covering shared-layer execution and unrolled loss handling
Code Review
This pull request introduces the ability to unroll Multi-Token Prediction (MTP) layers beyond the number of physical layers using a new mtp_unroll_steps configuration. Key changes include updating GPTModel._postprocess to support variable MTP depths, implementing a block_forward method to handle layer reuse, and adding configuration options for detaching decoder inputs. Review feedback identifies critical issues in the block_forward implementation regarding Pipeline Parallelism compatibility, specifically a potential division by zero and incorrect depth calculations. Additionally, a potential crash was noted where rotary_pos_emb is rolled without checking whether it is None.
```python
offset = get_mtp_layer_offset(self.config, self.vp_stage)
hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
hidden_states = hidden_states_list[offset]

physical_num_layers = len(self.layers)
unroll_steps = getattr(self.config, 'mtp_unroll_steps', None) or self.config.mtp_num_layers

for step in range(unroll_steps):
    layer = self.layers[step % physical_num_layers]
    global_depth = offset + step + 1
    hidden_states, input_ids, position_ids = layer(
        input_ids=input_ids,
        position_ids=position_ids,
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        inference_params=inference_params,
        rotary_pos_emb=rotary_pos_emb,
        rotary_pos_cos=rotary_pos_cos,
        rotary_pos_sin=rotary_pos_sin,
        packed_seq_params=packed_seq_params,
        sequence_len_offset=sequence_len_offset,
        embedding=embedding,
        depth_idx=global_depth,
        **(extra_block_kwargs or {}),
    )
    hidden_states_list.append(hidden_states)

hidden_states = torch.cat(hidden_states_list, dim=0)
return hidden_states
```
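For intuition, the reuse indexing in the loop above can be sketched in isolation (stand-in names; the real `self.layers` entries are modules, not strings): with one physical layer and three unroll steps, every step resolves to the same shared weights while the depth index keeps advancing.

```python
# Minimal sketch of the layer-reuse schedule (hypothetical names).
physical_layers = ["mtp_layer_0"]  # len(self.layers) == 1 on this stage
unroll_steps = 3                   # mtp_unroll_steps
offset = 0                         # get_mtp_layer_offset(...) result

schedule = []
for step in range(unroll_steps):
    layer = physical_layers[step % len(physical_layers)]  # shared weights
    global_depth = offset + step + 1                      # explicit depth index
    schedule.append((layer, global_depth))

print(schedule)
# [('mtp_layer_0', 1), ('mtp_layer_0', 2), ('mtp_layer_0', 3)]
```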
The block_forward implementation is not fully compatible with Pipeline Parallelism (PP) when mtp_unroll_steps is used:

- ZeroDivisionError: If a PP stage contains no MTP layers (`len(self.layers) == 0`), line 504 will crash due to `step % physical_num_layers`. A guard is needed to return early or skip the loop in such stages.
- Incorrect chunking and offset: `torch.chunk(hidden_states, 1 + offset, dim=0)` and `global_depth = offset + step + 1` rely on the physical layer offset. If previous stages have already performed logical unrolling, the number of chunks in `hidden_states` and the starting logical depth will be higher than what the physical offset indicates.
- Redundant execution: Every stage with MTP layers will attempt to execute the full `unroll_steps`. In a PP setup where MTP layers are distributed, this leads to an incorrect total number of logical steps and mismatched chunk counts in `_postprocess`.

If this feature is primarily intended for the `mtp_num_layers=1` case (single shared layer), please add a check for `physical_num_layers > 0` and consider documenting the PP limitations.
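A minimal sketch of the suggested guard (heavily simplified signature, with layers modeled as plain callables): stages that own no physical MTP layers pass the hidden states through unchanged instead of hitting `step % 0`.

```python
# Hedged sketch, not the actual Megatron implementation.
def block_forward_guarded(layers, unroll_steps, hidden_states):
    physical_num_layers = len(layers)
    if physical_num_layers == 0:
        # This PP stage holds no MTP layers; skip the unroll loop entirely.
        return hidden_states
    for step in range(unroll_steps):
        layer = layers[step % physical_num_layers]  # reuse shared weights
        hidden_states = layer(hidden_states)
    return hidden_states
```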
```diff
 else:
     # mrope or not packed_seq
-    rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-self.layer_number, dims=0)
+    rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)
```
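For reference, `torch.roll(x, shifts=-d, dims=0)` moves the first `d` rows to the end, so the embedding for position `d` lands at index 0. A torch-free sketch of the same shift semantics on a plain list:

```python
def roll(seq, shifts):
    """Plain-list analogue of torch.roll along dim 0."""
    shifts %= len(seq)
    return seq[-shifts:] + seq[:-shifts]

emb = ["pos0", "pos1", "pos2", "pos3"]
effective_depth = 2
print(roll(emb, -effective_depth))
# ['pos2', 'pos3', 'pos0', 'pos1']
```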
Good catch. The assumption that rotary_pos_emb is not None pre-exists in the base branch; this PR only changes the shift depth from self.layer_number to effective_depth for the unrolled MTP case and does not introduce a new dereference path here. If needed, I can add the None guard in a separate follow-up cleanup.
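A minimal sketch of what such a follow-up guard could look like (hypothetical helper name; not part of this PR), tolerating paths where `rotary_pos_emb` is never materialized:

```python
def shift_rotary_pos_emb(rotary_pos_emb, effective_depth):
    """Roll the rotary embedding by the effective MTP depth, tolerating None."""
    if rotary_pos_emb is None:
        # Some mrope/inference paths do not materialize rotary_pos_emb.
        return None
    import torch  # imported lazily; only needed when an embedding exists
    return torch.roll(rotary_pos_emb, shifts=-effective_depth, dims=0)

print(shift_rotary_pos_emb(None, 3))  # None
```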
Background
The Qwen3.5 series already supports mtp-k inference with a single shared MTP layer, but the previous implementation did not support shared-weight mtp-k training. This change adds support for training by repeatedly unrolling one physical MTP layer for k steps, and it has been validated to improve speculative decoding acceptance rate in mtp-k inference.
P.S. A small follow-up change is still needed in the upstream MS-SWIFT repository, and I will provide it separately.
Summary
This PR adds support for MTP weight reuse through logical unrolling.
What Changed
- Add `mtp_unroll_steps` to runtime config.
- Reuse physical MTP layers in `MultiTokenPredictionBlock.forward`.
- Add `decoder_input_detach` for multimodal MTP decoder input handling.

Validation
- Added `tests/test_mtp_reuse.py`.
- `PYTHONPATH=src:${PYTHONPATH} python -m pytest -q tests/test_mtp_reuse.py` reports 2 passed.

Notes
MTP weight reuse is enabled when `mtp_num_layers` is set to the physical layer count and `mtp_unroll_steps` is set to the logical unroll depth.
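A hypothetical illustration of that pairing (`MTPConfig` is a stand-in class; the real config object comes from the training runtime), mirroring the `getattr` fallback used in the diff:

```python
class MTPConfig:
    """Stand-in for the runtime config fields this PR touches."""
    def __init__(self, mtp_num_layers, mtp_unroll_steps=None):
        self.mtp_num_layers = mtp_num_layers      # physical MTP layer count
        self.mtp_unroll_steps = mtp_unroll_steps  # logical unroll depth (optional)

# Weight reuse: one physical layer unrolled for four logical steps.
config = MTPConfig(mtp_num_layers=1, mtp_unroll_steps=4)
unroll_steps = getattr(config, 'mtp_unroll_steps', None) or config.mtp_num_layers
print(unroll_steps)  # 4

# Without mtp_unroll_steps, behavior falls back to one step per physical layer.
legacy = MTPConfig(mtp_num_layers=1)
print(getattr(legacy, 'mtp_unroll_steps', None) or legacy.mtp_num_layers)  # 1
```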