From ba9b1296339c0a7dafc4539df09ff24a9b01696e Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Fri, 27 Feb 2026 12:49:43 -0800
Subject: [PATCH 1/9] implement mix hidden_states for eagle3; deprecate eagle1

Signed-off-by: Ye Yu
---
 examples/speculative_decoding/main.py         |  10 +-
 modelopt/torch/speculative/config.py          |  17 +-
 .../torch/speculative/eagle/conversion.py     |   2 +
 .../torch/speculative/eagle/eagle_model.py    |   4 +
 .../speculative/plugins/megatron_eagle.py     | 166 +++++++++++-------
 .../torch/speculative/plugins/transformers.py |  76 +++++---
 .../speculative_decoding/test_eagle.py        |   9 +-
 .../test_speculative_megatron_modules.py      |  52 +-----
 .../plugins/test_hf_speculative.py            |  26 +--
 9 files changed, 189 insertions(+), 173 deletions(-)

diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index 6821111849..25817ee94f 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -43,7 +43,6 @@
     make_eagle_supervised_data_module,
     patch_ring_attention_for_ttt,
 )
-from medusa_utils import make_medusa_supervised_data_module
 from transformers.trainer_utils import get_last_checkpoint
 
 import modelopt.torch.opt as mto
@@ -127,6 +126,10 @@ class EagleArguments:
         default="llama",
         metadata={"help": "The class of eagle decoder to use. Available options: llama, kimik2"},
     )
+    mix_hidden_states: bool = field(
+        default=False,
+        metadata={"help": "Whether to mix hidden states from the previous TTT step."},
+    )
 
 
 def train():
@@ -204,6 +207,7 @@ def train():
     config = {
         "eagle_decoder_type": eagle_args.eagle_decoder_type,
         "eagle_offline": use_offline_training,
+        "eagle_mix_hidden_states": eagle_args.mix_hidden_states,
         "eagle_architecture_config": custom_config,
     }
 
@@ -221,9 +225,7 @@ def train():
         raise Exception(f"{training_args.mode} is not supported!")
 
     print_rank_0("Loading dataset...")
-    if training_args.mode == "medusa":
-        data_module = make_medusa_supervised_data_module(tokenizer, data_args)
-    elif training_args.mode == "eagle3":
+    if training_args.mode == "eagle3":
         data_module = make_eagle_supervised_data_module(
             tokenizer, data_args, train_len=training_args.training_seq_len
         )
diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py
index 41987d4e41..b28dca61f5 100644
--- a/modelopt/torch/speculative/config.py
+++ b/modelopt/torch/speculative/config.py
@@ -29,12 +29,6 @@
 eagle3_default_config.update({"use_aux_hidden_state": True, "use_last_layernorm": True})
 eagle_mtp_default_config.update({"use_last_layernorm": True, "use_mtp_layernorm": True})
 
-EAGLE1_DEFAULT_CFG = {
-    "algorithm": "eagle",
-    "config": {
-        "eagle_architecture_config": deepcopy(default_eagle_config),
-    },
-}
 
 EAGLE3_DEFAULT_CFG = {
     "algorithm": "eagle",
@@ -105,3 +99,14 @@ class EagleConfig(ModeloptBaseConfig):
         default="llama",
         description=("The class of eagle decoder to use. Available options: llama, kimik2"),
     )
+
+    eagle_ttt_steps: int = ModeloptField(
+        default=4, description=("The number of training-time-test (TTT) steps used in training.")
+    )
+
+    eagle_mix_hidden_states: bool = ModeloptField(
+        default=False,
+        description=(
+            "Whether to mix hidden states across multiple TTT steps. This technique reduces training cost."
+ ), + ) diff --git a/modelopt/torch/speculative/eagle/conversion.py b/modelopt/torch/speculative/eagle/conversion.py index 2b085d5e35..5f1cbfedb8 100644 --- a/modelopt/torch/speculative/eagle/conversion.py +++ b/modelopt/torch/speculative/eagle/conversion.py @@ -58,6 +58,8 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu eagle_loss_decay_factor=config.eagle_loss_decay_factor, eagle_architecture_config=config.eagle_architecture_config, eagle_decoder_type=config.eagle_decoder_type, + eagle_ttt_steps=config.eagle_ttt_steps, + eagle_mix_hidden_states=config.eagle_mix_hidden_states, ) # no metadata, all specified via config. diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index d54fdc843d..41ee83a3ac 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -35,6 +35,8 @@ def modify( eagle_loss_decay_factor, eagle_architecture_config, eagle_decoder_type, + eagle_ttt_steps, + eagle_mix_hidden_states, ): """Base Eagle Model modify function. Child class should implement the details.""" self.eagle_offline = eagle_offline @@ -45,3 +47,5 @@ def modify( self.eagle_reuse_base_decoder = eagle_reuse_base_decoder self.eagle_loss_decay_factor = eagle_loss_decay_factor self.eagle_decoder_type = eagle_decoder_type + self.eagle_ttt_steps = eagle_ttt_steps + self.eagle_mix_hidden_states = eagle_mix_hidden_states diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index e37e8f931c..3e2dc04f7d 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -15,6 +15,7 @@ """Plugin to add EAGLE support for Megatron-Core GPT model.""" +import contextlib import copy import warnings from contextlib import contextmanager @@ -448,20 +449,21 @@ def __init__( self._num_aux_hidden_states if self._num_aux_hidden_states > 0 else 2 ) - # This linear was previously a ColumnParallelLinear. We changed it to a TELinear - # since ColumnParallelLinear will have try to gather the input sequence when sequence - # parallel is used and does not allow gathering the outputs. - with torch.device(device): - self.fc = TELinear( - config.hidden_size * fc_input_size_multiplier, - config.hidden_size, - parallel_mode="duplicated", - config=config, - init_method=(lambda w: None), # not used - bias=bias, - skip_bias_add=False, - skip_weight_param_allocation=False, - ) + if config.use_aux_hidden_state: + # This linear was previously a ColumnParallelLinear. We changed it to a TELinear + # since ColumnParallelLinear will have try to gather the input sequence when sequence + # parallel is used and does not allow gathering the outputs. 
+ with torch.device(device): + self.fc = TELinear( + config.hidden_size * fc_input_size_multiplier, + config.hidden_size, + parallel_mode="duplicated", + config=config, + init_method=(lambda w: None), # not used + bias=bias, + skip_bias_add=False, + skip_weight_param_allocation=False, + ) self.rotary_pos_emb = rotary_pos_emb @@ -606,13 +608,9 @@ def forward( embeddings = self.enorm(embeddings) hidden_states = self.hnorm(hidden_states) - # EAGLE-1 uses [s, b, h] input but EAGLE-3 uses [s, b, 2h] input if self._num_aux_hidden_states == 0: - # [s, b, 2h] - decoder_input = torch.cat((embeddings, hidden_states), dim=-1) - decoder_input = self.fc(decoder_input)[0] + decoder_input = hidden_states else: - # EAGLE-3 forward # EAGLE-3 uses self.fc outside eagle_module forward to convert hidden_states from [s, b, 3h] self._embeddings = self.enorm(embeddings) decoder_input = hidden_states @@ -693,6 +691,8 @@ def modify( eagle_loss_decay_factor, eagle_architecture_config, eagle_decoder_type, + eagle_ttt_steps, + eagle_mix_hidden_states, ): if self.config.pipeline_model_parallel_size > 1: warnings.warn( @@ -715,6 +715,8 @@ def modify( eagle_loss_decay_factor=eagle_loss_decay_factor, eagle_architecture_config=eagle_architecture_config, eagle_decoder_type=eagle_decoder_type, + eagle_ttt_steps=eagle_ttt_steps, + eagle_mix_hidden_states=eagle_mix_hidden_states, ) # sequence_parallel is not used in offline eagle @@ -824,7 +826,7 @@ def modify( self.kld = logits_kld_loss def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True): - """When _aux_hidden_states is not empty for online, then this is EAGLE-3. + """Get input hidden_states for EAGLE. Args: hidden_states: last hidden_states @@ -879,18 +881,6 @@ def _get_eagle_module_inputs( # [s,b] -> [b,s] padded_input_ids = padded_input_ids.transpose(0, 1).contiguous() - attn_mask = attention_mask.clone().detach() - # [b, 1, sq, sk] -> [sq, 1, b, sk] - attn_mask = attn_mask.transpose(0, 2).contiguous() - attn_mask = gather_from_sequence_parallel_region( - attn_mask, group=get_context_parallel_group() - ) - # [sq, 1, b, sk] -> [b, 1, sq, sk] - attn_mask = attn_mask.transpose(0, 2).contiguous() - attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] - attn_mask[:, :, -1, :] = True - attn_mask[:, :, :, -1] = True - eagle_inputs = {} eagle_inputs["input_ids"] = padded_input_ids @@ -903,14 +893,29 @@ def _get_eagle_module_inputs( eagle_inputs["hidden_states"] = hidden_states - attn_mask = set_multi_step_attention_mask(attn_mask, ttt_step) - # [b, 1, sq, sk] -> [sq, 1, b, sk] - attn_mask = attn_mask.transpose(0, 2).contiguous() - attn_mask = scatter_to_sequence_parallel_region( - attn_mask, group=get_context_parallel_group() - ) - # [sq, 1, b, sk] -> [b, 1, sq, sk] - eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous() + if self.eagle_mix_hidden_states: + eagle_inputs["attention_mask"] = attention_mask + else: + attn_mask = attention_mask.clone().detach() + # [b, 1, sq, sk] -> [sq, 1, b, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask = gather_from_sequence_parallel_region( + attn_mask, group=get_context_parallel_group() + ) + # [sq, 1, b, sk] -> [b, 1, sq, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] + attn_mask[:, :, -1, :] = True + attn_mask[:, :, :, -1] = True + + attn_mask = set_multi_step_attention_mask(attn_mask, ttt_step) + # [b, 1, sq, sk] -> [sq, 1, b, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask = 
scatter_to_sequence_parallel_region( + attn_mask, group=get_context_parallel_group() + ) + # [sq, 1, b, sk] -> [b, 1, sq, sk] + eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous() eagle_inputs["rotary_pos_emb"] = torch.cat( [rotary_pos_emb] * (ttt_step + 1), @@ -1055,7 +1060,6 @@ def forward( packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict | None = None, return_eagle_inputs: bool = False, - ttt_steps=4, **kwargs, ) -> torch.Tensor: if position_ids is None or attention_mask is None: @@ -1068,11 +1072,6 @@ def forward( raise ValueError("return_eagle_inputs is unsupported in EAGLE offline mode.") aux_hidden_states = kwargs.get("aux_hidden_states") hidden_states = kwargs.get("hidden_states") - if aux_hidden_states is None or hidden_states is None: - raise ValueError( - "EAGLE offline mode requires kwargs: aux_hidden_states=[s,b,k*h], " - "hidden_states=[s,b,h]." - ) else: # When return_eagle_inputs is True, return decoder_input_for_eagle. # For LLM, decoder_input_for_eagle is just the text embeddings. However, for VLM @@ -1097,15 +1096,17 @@ def forward( output_weight = self.shared_embedding_or_output_weight() logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight) - # EAGLE kv cache - eagle_inference_context = StaticInferenceContext( - input_ids.shape[0], - input_ids.shape[1] * ttt_steps, - ) + if not self.eagle_mix_hidden_states: + # EAGLE kv cache + eagle_inference_context = StaticInferenceContext( + input_ids.shape[0], + input_ids.shape[1] * self.eagle_ttt_steps, + ) if self.eagle_offline: eagle_module_input_hidden_states = self._get_eagle_input_hidden_states( - aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state + aux_hidden_states if self.eagle_config.use_aux_hidden_state else hidden_states, + apply_fc=self.eagle_config.use_aux_hidden_state, ) # If EAGLE-3, aux_hidden_states are gathered by the forward_hook elif return_eagle_inputs: @@ -1122,12 +1123,14 @@ def forward( # In case of VLM, there will be other fields for pixels. 
return { "input_ids": input_ids.squeeze(0).cpu(), - "aux_hidden_states": eagle_module_input_hidden_states.squeeze(1).cpu(), + "aux_hidden_states": eagle_module_input_hidden_states.squeeze(1).cpu() + if self.eagle_config.use_aux_hidden_state + else None, "hidden_states": hidden_states.squeeze(1).cpu(), } else: eagle_module_input_hidden_states = self._get_eagle_input_hidden_states( - hidden_states, apply_fc=True + hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state ) if labels is not None: @@ -1150,7 +1153,7 @@ def forward( loss = 0.0 * loss acc = [] - for ttt_step in range(ttt_steps): + for ttt_step in range(self.eagle_ttt_steps): eagle_inputs = self._get_eagle_module_inputs( input_ids=input_ids, hidden_states=eagle_module_input_hidden_states, @@ -1159,36 +1162,65 @@ def forward( ttt_step=ttt_step, ) - with te_dot_product_attention_with_cp( - eagle_inputs["attention_mask"], self.eagle_config.num_attention_heads + with ( + te_dot_product_attention_with_cp( + eagle_inputs["attention_mask"], self.eagle_config.num_attention_heads + ) + if not self.eagle_mix_hidden_states + else contextlib.nullcontext() ): - _, eagle_logits, eagle_module_input_hidden_states = self._eagle_forward( + _, eagle_logits, eagle_module_output_hidden_states = self._eagle_forward( eagle_inputs, output_weight, inference_params=inference_params, packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, + inference_context=None + if self.eagle_mix_hidden_states + else eagle_inference_context, **(extra_block_kwargs or {}), ) if self.config.sequence_parallel: - eagle_module_input_hidden_states = gather_from_sequence_parallel_region( - eagle_module_input_hidden_states + eagle_module_output_hidden_states = gather_from_sequence_parallel_region( + eagle_module_output_hidden_states ) - eagle_module_input_hidden_states = torch.cat( + eagle_module_output_hidden_states = torch.cat( ( torch.zeros( ( 1, - eagle_module_input_hidden_states.shape[1], - eagle_module_input_hidden_states.shape[2], + eagle_module_output_hidden_states.shape[1], + eagle_module_output_hidden_states.shape[2], ), - dtype=eagle_module_input_hidden_states.dtype, - device=eagle_module_input_hidden_states.device, + dtype=eagle_module_output_hidden_states.dtype, + device=eagle_module_output_hidden_states.device, ), - eagle_module_input_hidden_states[:-1, :, :], + eagle_module_output_hidden_states[:-1, :, :], ) ) + + if self.eagle_mix_hidden_states: + seq_len_s, batch_size, _ = eagle_module_output_hidden_states.shape + num_to_replace = max(1, seq_len_s // (2**ttt_step + 1)) + + # Randomly select positions for each batch to replace + rand_indices = torch.stack( + [ + torch.randperm(seq_len_s, device=eagle_module_output_hidden_states.device)[ + :num_to_replace + ] + for _ in range(batch_size) + ], + dim=0, + ) + + for batch_idx in range(batch_size): + eagle_module_input_hidden_states[rand_indices[batch_idx], batch_idx, :] = ( + eagle_module_output_hidden_states[rand_indices[batch_idx], batch_idx, :] + ) + else: + eagle_module_input_hidden_states = eagle_module_output_hidden_states + if self.config.sequence_parallel: eagle_module_input_hidden_states = scatter_to_sequence_parallel_region( eagle_module_input_hidden_states diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index c21594afe9..a51dbd88f0 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -260,10 +260,7 @@ def __init__(self, config, 
decoder_layer_cls, bias=False):
             bias=False,
         )
 
-        if not config.use_aux_hidden_state:
-            # In Eagle-1, the FC concentrate input embeddings and hidden states
-            self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=bias)
-        else:
+        if config.use_aux_hidden_state:
             # In EAGLE-3, the FC concentrate hidden states from multiple base model layers
             self.fc = nn.Linear(
                 len(config.eagle_aux_hidden_state_layer_ids) * config.hidden_size,
@@ -369,13 +366,10 @@ def forward(
             position_ids = position_ids.view(-1, seq_length).long()
 
         inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device)
-        if self.config.use_aux_hidden_state:
-            # In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function
-            # Also, we normalize input embeddings and hidden states before concatenating them.
-            # The default input norm in first layer attn will be disabled.
-            self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
-        else:  # EAGLE-1
-            hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
+        # In EAGLE-3, we save the input embeddings to an attribute and use them in the first decoder layer via a hook
+        # Also, we normalize input embeddings and hidden states before concatenating them.
+        # The default input norm in the first layer's attention will be disabled.
+        self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
 
         if self.config.eagle_decoder_type == "llama":
             # Lazy init rope to avoid save/load meta tensor error
@@ -564,6 +558,8 @@ def modify(
         eagle_loss_decay_factor,
         eagle_architecture_config,
         eagle_decoder_type,
+        eagle_ttt_steps,
+        eagle_mix_hidden_states,
     ):
         """Constructor.
 
@@ -580,6 +576,8 @@
             eagle_loss_decay_factor=eagle_loss_decay_factor,
             eagle_architecture_config=eagle_architecture_config,
             eagle_decoder_type=eagle_decoder_type,
+            eagle_ttt_steps=eagle_ttt_steps,
+            eagle_mix_hidden_states=eagle_mix_hidden_states,
         )
 
         if eagle_decoder_type == "llama":
@@ -642,7 +640,6 @@ def modify(
         # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
         self.is_quantized = False
 
-        self.num_ttt_steps = 4  # NOTE: (hg) hardcoded for now. Might add to config later.
         self._cached_attn_blk_masks = {}
 
     def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step):
@@ -704,10 +701,10 @@ def _prepare_eagle_inputs(
 
         # Prepare eagle_input_hiddens
         if self.eagle_config.use_aux_hidden_state:
-            # Eagle3: concat base model intermediate (pre-norm) hiddens
+            # concat base model intermediate (pre-norm) hiddens
             eagle_input_hiddens = self.eagle_module.fc(base_outputs.aux_hiddens)
         else:
-            # Eagle1: use base model output (post-norm)hiddens
+            # use base model output (post-norm) hiddens
             eagle_input_hiddens = base_outputs.out_hiddens
 
         # Prepare attention_mask
@@ -931,22 +928,52 @@ def forward(
         )
 
         # ====Run eagle forward with extra training-time-test steps====
-        for ttt_step in range(self.num_ttt_steps):
+        for ttt_step in range(self.eagle_ttt_steps):
             # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then. 
- eagle_attention_mask = ( - eagle_attn_mask_0 - if ttt_step == 0 - else self._get_ttt_attention_mask(b, seq_length, ttt_step) - ) - with enable_cp_ttt_patch() if self.training else contextlib.nullcontext(): - _, eagle_input_hiddens, eagle_logits, eagle_cache = self._eagle_forward( + if self.eagle_mix_hidden_states: + eagle_attention_mask = eagle_attn_mask_0 + else: + eagle_attention_mask = ( + eagle_attn_mask_0 + if ttt_step == 0 + else self._get_ttt_attention_mask(b, seq_length, ttt_step) + ) + with ( + enable_cp_ttt_patch() + if self.training and not self.eagle_mix_hidden_states + else contextlib.nullcontext() + ): + _, eagle_output_hiddens, eagle_logits, eagle_cache = self._eagle_forward( eagle_input_hiddens, eagle_input_embeds, eagle_attention_mask, eagle_position_ids, - eagle_cache, + None if self.eagle_mix_hidden_states else eagle_cache, ) - eagle_input_hiddens = eagle_input_hiddens.roll(1, 1) + eagle_output_hiddens = eagle_output_hiddens.roll(1, 1) + + if self.eagle_mix_hidden_states: + batch_size, seq_len_s, _ = eagle_input_hiddens.shape + num_to_replace = max(1, seq_len_s // (2**ttt_step + 1)) + + # Randomly select positions for each batch to replace + rand_indices = torch.stack( + [ + torch.randperm(seq_len_s, device=eagle_input_hiddens.device)[ + :num_to_replace + ] + for _ in range(batch_size) + ], + dim=0, + ) + + for batch_idx in range(batch_size): + eagle_input_hiddens[batch_idx, rand_indices[batch_idx], :] = ( + eagle_output_hiddens[batch_idx, rand_indices[batch_idx], :] + ) + else: + eagle_input_hiddens = eagle_output_hiddens + for i in range(self.eagle_config.parallel_draft_step): eagle_logit = eagle_logits[i] classification_loss, acc = self._eagle_loss( @@ -1037,7 +1064,6 @@ def pseudo_speculative_generate( eagle_ids = torch.cat((input_ids[:, 1:], base_token), dim=-1) if self.eagle_config.use_aux_hidden_state: - # EAGLE-3 # Only the first iteration input_hidden_states are from aux_hidden_state layers # Gather _aux_hidden_states from all devices before concatenation eagle_input_hidden_states = self.eagle_module.fc(self.pop_and_gather_aux_hiddens()) diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 9c73ea96a2..a3542fa25f 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -61,8 +61,12 @@ def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft # fmt: off -@pytest.mark.parametrize("cp_size", [1, 2]) -def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagle_output_dir, cp_size): +@pytest.mark.parametrize(("cp_size", "mix_hidden_states"), [(1, "false"), (2, "false"), (1, "true"), (2, "true")]) +def test_llama_eagle3(tiny_llama_path, + tiny_daring_anteater_path, + tmp_path, eagle_output_dir, + cp_size, + mix_hidden_states): """Test Eagle3 training with a tiny llama model, using different cp_size values.""" available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 if cp_size == 2 and available_gpus < 2: @@ -96,6 +100,7 @@ def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagl "--output_dir", eagle_output_dir / f"eagle-tinyllama-cp{cp_size}", "--training_seq_len", "128", # Match max_position_embeddings "--cp_size", str(cp_size), + "--mix_hidden_states", mix_hidden_states, ], "speculative_decoding", ) diff --git a/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py 
b/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py index f1144bea4e..3101a43d48 100644 --- a/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py +++ b/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py @@ -21,18 +21,14 @@ import modelopt.torch.speculative as mtsp from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel -from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel ALGO_TO_CONFIG = { - "eagle1": mtsp.config.EAGLE1_DEFAULT_CFG, "eagle3": mtsp.config.EAGLE3_DEFAULT_CFG, "eagle-mtp": mtsp.config.EAGLE_MTP_DEFAULT_CFG, } -def _test_speculative_gpt_model( - algo, num_medusa_heads_or_eagle_layers, activation_func, normalization, rank, size -): +def _test_speculative_gpt_model(algo, num_layers, activation_func, normalization, rank, size): num_attention_heads = 8 num_query_groups = size max_sequence_length = 32 @@ -51,22 +47,10 @@ def _test_speculative_gpt_model( normalization=normalization, ).cuda() - if algo == "medusa": - config = { - "medusa_num_heads": num_medusa_heads_or_eagle_layers, - "medusa_num_layers": 1, - } - - model = mtsp.convert(model, [("medusa", config)]) - - # Type checking - assert isinstance(model, _DynamicMedusaGPTModel) - elif algo in {"eagle1", "eagle3"}: - mtsp_config = copy.deepcopy(ALGO_TO_CONFIG[algo]) + if algo == "eagle3": + mtsp_config = ALGO_TO_CONFIG[algo] - mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] = ( - num_medusa_heads_or_eagle_layers - ) + mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] = num_layers mtsp_config["config"]["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size mtsp_config["config"]["eagle_architecture_config"]["vocab_size"] = model.vocab_size mtsp_config["config"]["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size @@ -89,14 +73,6 @@ def _test_speculative_gpt_model( assert len(first_layer.self_attention._forward_pre_hooks) > 0 # Eagle3 last layer has a forward hook to extrat the pre_norm hidden_state assert len(last_layer._forward_hooks) > 0 - elif algo == "eagle1": - first_layer = model.eagle_module.decoder.layers[0] - last_layer = model.eagle_module.decoder.layers[-1] - # Eagle1 QKV input_dim the same as hidden_size - assert first_layer.self_attention.linear_qkv.weight.shape[-1] == model.config.hidden_size - # No forward_hook or forward_pre_hook are needed - assert len(first_layer.self_attention._forward_pre_hooks) == 0 - assert len(last_layer._forward_hooks) == 0 # Bfloat16 model = model.to(torch.bfloat16) @@ -113,19 +89,7 @@ def _test_speculative_gpt_model( assert logits.shape[1] == max_sequence_length assert logits.shape[2] == vocab_size / size - if algo == "medusa": - # When label provided, model.forward should return - # medusa_loss[b, s * (num_medusa_heads + 1), b] - labels = torch.randint( - 0, - vocab_size, - (batch_size, max_sequence_length), - ).cuda() - medusa_loss = model(prompt_tokens, position_ids, attention_mask, labels=labels) - - assert medusa_loss.shape[0] == batch_size - assert medusa_loss.shape[1] == max_sequence_length - elif algo in {"eagle1", "eagle3"}: + if algo == "eagle3": labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() eagle_loss = model(prompt_tokens, position_ids, attention_mask, labels=labels) @@ -134,14 +98,10 @@ def _test_speculative_gpt_model( @pytest.mark.parametrize( - ("algo", "num_medusa_heads_or_eagle_layers", "activation_func", 
"normalization"), + ("algo", "num_layers", "activation_func", "normalization"), [ - ("eagle1", 1, "squared_relu", "LayerNorm"), # MHA - ("eagle1", 2, "swiglu", "RMSNorm"), # GQA ("eagle3", 1, "swiglu", "RMSNorm"), # GQA ("eagle3", 2, "swiglu", "RMSNorm"), # GQA - ("medusa", 1, "squared_relu", "LayerNorm"), # MHA - ("medusa", 2, "swiglu", "RMSNorm"), # GQA ], ) def test_speculative_gpt_model( diff --git a/tests/unit/torch/speculative/plugins/test_hf_speculative.py b/tests/unit/torch/speculative/plugins/test_hf_speculative.py index 3810f31c58..b41b7fae2b 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_speculative.py +++ b/tests/unit/torch/speculative/plugins/test_hf_speculative.py @@ -18,36 +18,16 @@ import pytest from _test_utils.torch.transformers_models import ( - create_tiny_llama_dir, get_tiny_llama, tf_modelopt_state_and_output_tester, ) -from transformers import AutoModelForCausalLM, LlamaForCausalLM +from transformers import AutoModelForCausalLM import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.config import EAGLE1_DEFAULT_CFG, EAGLE3_DEFAULT_CFG +from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG -def test_medusa_model_convert_save_and_restore(tmp_path): - tiny_llama_dir = create_tiny_llama_dir(tmp_path) - model_ref = LlamaForCausalLM.from_pretrained(tiny_llama_dir) - - config = { - "medusa_num_heads": 2, - "medusa_num_layers": 1, - } - mtsp.convert(model_ref, mode=[("medusa", config)]) - assert isinstance(model_ref, mtsp.plugins.HFMedusaModel) - - model_ref.save_pretrained(tiny_llama_dir / "modelopt_model") - assert os.path.exists(tiny_llama_dir / "modelopt_model/modelopt_state.pth") - - model_test = AutoModelForCausalLM.from_pretrained(tiny_llama_dir / "modelopt_model") - assert isinstance(model_test, mtsp.plugins.HFMedusaModel) - tf_modelopt_state_and_output_tester(model_ref, model_test) - - -@pytest.mark.parametrize("eagle_config", [EAGLE1_DEFAULT_CFG, EAGLE3_DEFAULT_CFG]) +@pytest.mark.parametrize("eagle_config", [EAGLE3_DEFAULT_CFG]) def test_eagle_model_convert_save_and_restore(tmp_path, eagle_config): model_ref = get_tiny_llama(num_hidden_layers=8) From 2dd7392ded56e9fdd88215785fab94f52a0ee593 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Fri, 27 Feb 2026 13:48:13 -0800 Subject: [PATCH 2/9] formatting Signed-off-by: Ye Yu --- .../torch/speculative/plugins/megatron_eagle.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 3e2dc04f7d..996c1dcf7d 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1121,12 +1121,18 @@ def forward( hidden_states = gather_from_sequence_parallel_region(hidden_states) logits_sbh = gather_from_tensor_model_parallel_region(logits_sbh) # In case of VLM, there will be other fields for pixels. 
aux_hidden = None
+            if self.eagle_config.use_aux_hidden_state:
+                aux_hidden = eagle_module_input_hidden_states.squeeze(1).cpu()
+
+            hidden_states_cpu = None
+            if hidden_states is not None:
+                hidden_states_cpu = hidden_states.squeeze(1).cpu()
+
             return {
                 "input_ids": input_ids.squeeze(0).cpu(),
-                "aux_hidden_states": eagle_module_input_hidden_states.squeeze(1).cpu()
-                if self.eagle_config.use_aux_hidden_state
-                else None,
-                "hidden_states": hidden_states.squeeze(1).cpu(),
+                "aux_hidden_states": aux_hidden,
+                "hidden_states": hidden_states_cpu,
             }
         else:
             eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(

From 096987364a00ffa5b742610d89533e4be1e4373c Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Fri, 27 Feb 2026 14:36:14 -0800
Subject: [PATCH 3/9] fix tests

Signed-off-by: Ye Yu
---
 .../plugins/test_speculative_megatron_modules.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py b/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py
index 3101a43d48..3bac372f72 100644
--- a/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py
+++ b/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 from functools import partial
 
 import pytest
@@ -104,14 +103,12 @@ def _test_speculative_gpt_model(algo, num_layers, activation_func, normalization
         ("eagle3", 2, "swiglu", "RMSNorm"),  # GQA
     ],
 )
-def test_speculative_gpt_model(
-    dist_workers, algo, num_medusa_heads_or_eagle_layers, activation_func, normalization
-):
+def test_speculative_gpt_model(dist_workers, algo, num_layers, activation_func, normalization):
     dist_workers.run(
         partial(
             _test_speculative_gpt_model,
             algo,
-            num_medusa_heads_or_eagle_layers,
+            num_layers,
             activation_func,
             normalization,
         ),

From ad1d17c976844d9a8d190261299ee4c84a00f196 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Mon, 2 Mar 2026 09:26:23 -0800
Subject: [PATCH 4/9] pass mix_hidden_states through launch_train.sh

Signed-off-by: Ye Yu
---
 examples/speculative_decoding/launch_train.sh | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh
index bdfc4ee383..074151c5a0 100755
--- a/examples/speculative_decoding/launch_train.sh
+++ b/examples/speculative_decoding/launch_train.sh
@@ -110,6 +110,10 @@ while [ $# -gt 0 ]; do
     if [[ "$1" != *=* ]]; then shift; fi
     HEAD_NODE_IP="${1#*=}"
     ;;
+  --mix_hidden_states*)
+    if [[ "$1" != *=* ]]; then shift; fi
+    MIX_HIDDEN_STATES="${1#*=}"
+    ;;
   *)
     >&2 printf "Error: Invalid argument ${1#*=}\n"
     exit 1
@@ -149,6 +153,7 @@
 CP_SIZE=${CP_SIZE:-1}
 DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
 LOG_STEPS=${LOG_STEPS:-100}
 DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
+MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}
 
 
 if [[ "$MODE" == "eagle3" ]]; then
@@ -234,6 +239,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
     --disable_tqdm $DISABLE_TQDM \
     --estimate_ar $ESTIMATE_AR \
     --ar_validate_steps $AR_VALIDATE_STEPS \
+    --mix_hidden_states $MIX_HIDDEN_STATES \
     $DRAFT_VOCAB_CACHE_ARGS \
     $VLM_ARGS \
     $OFFLINE_TRAINING_ARGS \

From 48cc15fb39149bdf89e60602163345d46e104d64 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Mon, 2 Mar 2026 11:06:49 -0800
Subject: [PATCH 5/9] 
refactor conversion API Signed-off-by: Ye Yu --- .../torch/speculative/eagle/conversion.py | 14 +------ .../torch/speculative/eagle/eagle_model.py | 32 ++++++---------- .../speculative/plugins/megatron_eagle.py | 30 ++------------- .../torch/speculative/plugins/transformers.py | 38 ++++--------------- 4 files changed, 24 insertions(+), 90 deletions(-) diff --git a/modelopt/torch/speculative/eagle/conversion.py b/modelopt/torch/speculative/eagle/conversion.py index 5f1cbfedb8..b7cb59f649 100644 --- a/modelopt/torch/speculative/eagle/conversion.py +++ b/modelopt/torch/speculative/eagle/conversion.py @@ -48,19 +48,7 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu config.eagle_architecture_config = {**default_arch_config, **custom_config} eagle_model = EagleDMRegistry.convert(model) - eagle_model.modify( - eagle_offline=config.eagle_offline, - eagle_hidden_state_distillation=config.eagle_hidden_state_distillation, - eagle_self_logit_distillation=config.eagle_self_logit_distillation, - eagle_freeze_base_model=config.eagle_freeze_base_model, - eagle_report_acc=config.eagle_report_acc, - eagle_reuse_base_decoder=config.eagle_reuse_base_decoder, - eagle_loss_decay_factor=config.eagle_loss_decay_factor, - eagle_architecture_config=config.eagle_architecture_config, - eagle_decoder_type=config.eagle_decoder_type, - eagle_ttt_steps=config.eagle_ttt_steps, - eagle_mix_hidden_states=config.eagle_mix_hidden_states, - ) + eagle_model.modify(config) # no metadata, all specified via config. metadata = {} diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index 41ee83a3ac..85251c86a2 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -26,26 +26,16 @@ def _setup(self): def modify( self, - eagle_offline, - eagle_hidden_state_distillation, - eagle_self_logit_distillation, - eagle_freeze_base_model, - eagle_report_acc, - eagle_reuse_base_decoder, - eagle_loss_decay_factor, - eagle_architecture_config, - eagle_decoder_type, - eagle_ttt_steps, - eagle_mix_hidden_states, + config, ): """Base Eagle Model modify function. 
Child class should implement the details.""" - self.eagle_offline = eagle_offline - self.eagle_hidden_state_distillation = eagle_hidden_state_distillation - self.eagle_self_logit_distillation = eagle_self_logit_distillation - self.eagle_freeze_base_model = eagle_freeze_base_model - self.eagle_report_acc = eagle_report_acc - self.eagle_reuse_base_decoder = eagle_reuse_base_decoder - self.eagle_loss_decay_factor = eagle_loss_decay_factor - self.eagle_decoder_type = eagle_decoder_type - self.eagle_ttt_steps = eagle_ttt_steps - self.eagle_mix_hidden_states = eagle_mix_hidden_states + self.eagle_offline = config.eagle_offline + self.eagle_hidden_state_distillation = config.eagle_hidden_state_distillation + self.eagle_self_logit_distillation = config.eagle_self_logit_distillation + self.eagle_freeze_base_model = config.eagle_freeze_base_model + self.eagle_report_acc = config.eagle_report_acc + self.eagle_reuse_base_decoder = config.eagle_reuse_base_decoder + self.eagle_loss_decay_factor = config.eagle_loss_decay_factor + self.eagle_decoder_type = config.eagle_decoder_type + self.eagle_ttt_steps = config.eagle_ttt_steps + self.eagle_mix_hidden_states = config.eagle_mix_hidden_states diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 996c1dcf7d..09499e3091 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -682,17 +682,7 @@ def _setup(self): def modify( self, - eagle_offline, - eagle_hidden_state_distillation, - eagle_self_logit_distillation, - eagle_freeze_base_model, - eagle_report_acc, - eagle_reuse_base_decoder, - eagle_loss_decay_factor, - eagle_architecture_config, - eagle_decoder_type, - eagle_ttt_steps, - eagle_mix_hidden_states, + config, ): if self.config.pipeline_model_parallel_size > 1: warnings.warn( @@ -705,26 +695,14 @@ def modify( if hasattr(self.config, "hetereogenous_dist_checkpoint"): self.config.hetereogenous_dist_checkpoint = True - super().modify( - eagle_offline=eagle_offline, - eagle_hidden_state_distillation=eagle_hidden_state_distillation, - eagle_self_logit_distillation=eagle_self_logit_distillation, - eagle_freeze_base_model=eagle_freeze_base_model, - eagle_report_acc=eagle_report_acc, - eagle_reuse_base_decoder=eagle_reuse_base_decoder, - eagle_loss_decay_factor=eagle_loss_decay_factor, - eagle_architecture_config=eagle_architecture_config, - eagle_decoder_type=eagle_decoder_type, - eagle_ttt_steps=eagle_ttt_steps, - eagle_mix_hidden_states=eagle_mix_hidden_states, - ) + super().modify(config) # sequence_parallel is not used in offline eagle if self.eagle_offline: self.config.sequence_parallel = False self.eagle_config = dict_to_config( - eagle_architecture_config, + config.eagle_architecture_config, self.config.use_cpu_initialization, self.config.fp16, self.config.bf16, @@ -740,7 +718,7 @@ def modify( ) if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - assert eagle_self_logit_distillation, ( + assert self.eagle_self_logit_distillation, ( "Only logit distillation is supported when draft_vocab_size != vocab_size!" 
) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index a51dbd88f0..4c497475e2 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -549,45 +549,23 @@ def _get_eagle_device(self): def modify( self, - eagle_offline, - eagle_hidden_state_distillation, - eagle_self_logit_distillation, - eagle_freeze_base_model, - eagle_report_acc, - eagle_reuse_base_decoder, - eagle_loss_decay_factor, - eagle_architecture_config, - eagle_decoder_type, - eagle_ttt_steps, - eagle_mix_hidden_states, + config, ): """Constructor. Args: config: The config for eagle decoder layers. """ - super().modify( - eagle_offline=eagle_offline, - eagle_hidden_state_distillation=eagle_hidden_state_distillation, - eagle_self_logit_distillation=eagle_self_logit_distillation, - eagle_freeze_base_model=eagle_freeze_base_model, - eagle_report_acc=eagle_report_acc, - eagle_reuse_base_decoder=eagle_reuse_base_decoder, - eagle_loss_decay_factor=eagle_loss_decay_factor, - eagle_architecture_config=eagle_architecture_config, - eagle_decoder_type=eagle_decoder_type, - eagle_ttt_steps=eagle_ttt_steps, - eagle_mix_hidden_states=eagle_mix_hidden_states, - ) + super().modify(config) - if eagle_decoder_type == "llama": + if self.eagle_decoder_type == "llama": # Use default eagle config decoder_cls = LlamaDecoderLayer - elif eagle_decoder_type == "kimik2": + elif self.eagle_decoder_type == "kimik2": decoder_cls = _setup_kimi_k2_decoder() - self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config) - self.eagle_config.eagle_decoder_type = eagle_decoder_type + self.eagle_config = PretrainedConfig.from_dict(config.eagle_architecture_config) + self.eagle_config.eagle_decoder_type = self.eagle_decoder_type # Hidden size and vocab size must match base model self.eagle_config.hidden_size = self._base_llm_config.hidden_size self.eagle_config.vocab_size = self._base_llm_config.vocab_size @@ -626,14 +604,14 @@ def modify( self.eagle_module.to(self._base_model.dtype).to(self._get_eagle_device()) # EAGLE-3 auxiliary hidden_states - if (not eagle_offline) and self.eagle_config.use_aux_hidden_state: + if (not self.eagle_offline) and self.eagle_config.use_aux_hidden_state: self._aux_hidden_states = [] for layer_idx, layer in enumerate(self._base_model.layers): if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids: layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook) # delete base model layers for offline training - if eagle_offline: + if self.eagle_offline: self._base_model._modules.pop("layers") # NOTE: this is a temporary hack to bypass hf trainer check: From ef832c447a876c5240f17db692d1c0a528617a42 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 4 Mar 2026 13:15:58 -0800 Subject: [PATCH 6/9] address comments Signed-off-by: Ye Yu --- .../torch/speculative/plugins/transformers.py | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 4c497475e2..1b85c342e7 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -908,14 +908,11 @@ def forward( # ====Run eagle forward with extra training-time-test steps==== for ttt_step in range(self.eagle_ttt_steps): # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then. 
if self.eagle_mix_hidden_states:
-                eagle_attention_mask = eagle_attn_mask_0
-            else:
-                eagle_attention_mask = (
-                    eagle_attn_mask_0
-                    if ttt_step == 0
-                    else self._get_ttt_attention_mask(b, seq_length, ttt_step)
-                )
+            eagle_attention_mask = (
+                eagle_attn_mask_0
+                if self.eagle_mix_hidden_states or ttt_step == 0
+                else self._get_ttt_attention_mask(b, seq_length, ttt_step)
+            )
             with (
                 enable_cp_ttt_patch()
                 if self.training and not self.eagle_mix_hidden_states
@@ -935,20 +932,14 @@ def forward(
                 num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))
 
                 # Randomly select positions for each batch to replace
-                rand_indices = torch.stack(
-                    [
-                        torch.randperm(seq_len_s, device=eagle_input_hiddens.device)[
-                            :num_to_replace
-                        ]
-                        for _ in range(batch_size)
-                    ],
-                    dim=0,
-                )
+                rand_indices = torch.rand(
+                    batch_size, seq_len_s, device=eagle_input_hiddens.device
+                ).argsort(dim=1)[:, :num_to_replace]
 
-                for batch_idx in range(batch_size):
-                    eagle_input_hiddens[batch_idx, rand_indices[batch_idx], :] = (
-                        eagle_output_hiddens[batch_idx, rand_indices[batch_idx], :]
-                    )
+                batch_indices = torch.arange(batch_size)[:, None]
+                eagle_input_hiddens[batch_indices, rand_indices] = eagle_output_hiddens[
+                    batch_indices, rand_indices
+                ]
             else:
                 eagle_input_hiddens = eagle_output_hiddens
 

From 486001b92e25750147d62f902cdccc5f9c888e87 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Thu, 5 Mar 2026 13:17:56 -0800
Subject: [PATCH 7/9] use single-step rotary_pos_emb when mixing hidden states

Signed-off-by: Ye Yu
---
 modelopt/torch/speculative/plugins/megatron_eagle.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py
index 09499e3091..f04c1adf62 100644
--- a/modelopt/torch/speculative/plugins/megatron_eagle.py
+++ b/modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -895,10 +895,13 @@ def _get_eagle_module_inputs(
             # [sq, 1, b, sk] -> [b, 1, sq, sk]
             eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous()
 
-        eagle_inputs["rotary_pos_emb"] = torch.cat(
-            [rotary_pos_emb] * (ttt_step + 1),
-            dim=0,
-        )
+        if self.eagle_mix_hidden_states:
+            eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
+        else:
+            eagle_inputs["rotary_pos_emb"] = torch.cat(
+                [rotary_pos_emb] * (ttt_step + 1),
+                dim=0,
+            )
 
         return eagle_inputs
 

From 00bf3e958143f7938bd7e9e27c389f8f4cbb1689 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Thu, 5 Mar 2026 13:23:31 -0800
Subject: [PATCH 8/9] clone hidden states before in-place mixing

Signed-off-by: Ye Yu
---
 modelopt/torch/speculative/plugins/megatron_eagle.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py
index f04c1adf62..81388fbfd6 100644
--- a/modelopt/torch/speculative/plugins/megatron_eagle.py
+++ b/modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -1201,6 +1201,8 @@ def forward(
                         dim=0,
                     )
 
+                # Clone to avoid in-place modification of a view created in no_grad mode
+                eagle_module_input_hidden_states = eagle_module_input_hidden_states.clone()
                 for batch_idx in range(batch_size):
                     eagle_module_input_hidden_states[rand_indices[batch_idx], batch_idx, :] = (
                         eagle_module_output_hidden_states[rand_indices[batch_idx], batch_idx, :]

From e2b7e0e8f9971a2bf134f26211040e48e6d8b7f3 Mon Sep 17 00:00:00 2001
From: Ye Yu
Date: Thu, 5 Mar 2026 13:30:58 -0800
Subject: [PATCH 9/9] fold rotary_pos_emb handling into the mix_hidden_states branch

Signed-off-by: Ye Yu
---
 modelopt/torch/speculative/plugins/megatron_eagle.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git 
a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 81388fbfd6..f9ee9873ec 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -873,6 +873,7 @@ def _get_eagle_module_inputs( if self.eagle_mix_hidden_states: eagle_inputs["attention_mask"] = attention_mask + eagle_inputs["rotary_pos_emb"] = rotary_pos_emb else: attn_mask = attention_mask.clone().detach() # [b, 1, sq, sk] -> [sq, 1, b, sk] @@ -894,10 +895,6 @@ def _get_eagle_module_inputs( ) # [sq, 1, b, sk] -> [b, 1, sq, sk] eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous() - - if self.eagle_mix_hidden_states: - eagle_inputs["rotary_pos_emb"] = rotary_pos_emb - else: eagle_inputs["rotary_pos_emb"] = torch.cat( [rotary_pos_emb] * (ttt_step + 1), dim=0,
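
Note on the mix_hidden_states technique implemented in this series: instead of unrolling every training-time-test step with a multi-step attention mask and a growing EAGLE KV cache, each step reuses the plain causal mask and single-step rotary embeddings, and the draft module's own outputs (shifted right by one token) are written back into a randomly chosen subset of its input hidden states. The subset shrinks with the step index t as seq_len // (2**t + 1), so later steps train on a mixture of base-model and draft-model hidden states. The sketch below restates the mixing step in isolation; the function name, argument layout, and batch-first shapes are illustrative and do not reproduce the exact ModelOpt call signatures.

import torch


def mix_hidden_states(
    input_hiddens: torch.Tensor,   # [batch, seq_len, hidden], inputs to the EAGLE module
    output_hiddens: torch.Tensor,  # [batch, seq_len, hidden], EAGLE outputs already rolled by one token
    ttt_step: int,
) -> torch.Tensor:
    batch_size, seq_len, _ = input_hiddens.shape
    # Fewer positions are replaced at later steps: seq/2, seq/3, seq/5, seq/9, ...
    num_to_replace = max(1, seq_len // (2**ttt_step + 1))
    # Distinct random positions per batch row, chosen without a Python loop.
    rand_indices = torch.rand(
        batch_size, seq_len, device=input_hiddens.device
    ).argsort(dim=1)[:, :num_to_replace]
    batch_indices = torch.arange(batch_size, device=input_hiddens.device)[:, None]
    # Clone first so the in-place write does not alias a tensor produced under no_grad.
    mixed = input_hiddens.clone()
    mixed[batch_indices, rand_indices] = output_hiddens[batch_indices, rand_indices]
    return mixed

With the default eagle_ttt_steps=4 and a sequence length of 128, this replaces 64, 42, 25, and 14 positions at steps 0 through 3.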