From 59152efacb4b8b93362dfbb40a614e5944f71aed Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Wed, 29 Apr 2026 06:40:32 +0000
Subject: [PATCH] k25 dflash hardcode support

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/eagle_utils.py |  4 +-
 modelopt/torch/speculative/eagle/utils.py    | 60 +++++++++++++++++---
 .../torch/speculative/plugins/hf_dflash.py   |  9 +---
 3 files changed, 57 insertions(+), 16 deletions(-)

diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py
index a0a28d78a7..0e394255a3 100644
--- a/examples/speculative_decoding/eagle_utils.py
+++ b/examples/speculative_decoding/eagle_utils.py
@@ -108,7 +108,9 @@ def make_speculative_data_module(
         raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
     if data_args.sample_size > 0:
         dumped_files = dumped_files[: data_args.sample_size]
-    train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss)
+    train_dataset = OfflineSupervisedDataset(
+        dumped_files, answer_only_loss=answer_only_loss, tokenizer=tokenizer
+    )
     data_collator = EagleOfflineDataCollator(train_len=train_len)
 
     return {
diff --git a/modelopt/torch/speculative/eagle/utils.py b/modelopt/torch/speculative/eagle/utils.py
index f74fcb1e9f..09628aa0c8 100644
--- a/modelopt/torch/speculative/eagle/utils.py
+++ b/modelopt/torch/speculative/eagle/utils.py
@@ -78,6 +78,44 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
+def compute_assistant_mask_kimi(tokenizer, input_ids):
+    """Recover the assistant mask from already-tokenized Kimi chat IDs.
+
+    For every <|im_assistant|> token, find its matching <|im_end|> and mark
+    the inclusive span. An unmatched trailing assistant marker (i.e. a
+    generation prompt at the end of full_ids) is left as 0; this matches
+    the prefix-diff behavior in apply_chat_template_kimi.
+    """
+    ids_list = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids)
+
+    role_to_id = {
+        role: tokenizer.convert_tokens_to_ids(role)
+        for role in ("<|im_user|>", "<|im_assistant|>", "<|im_system|>")
+    }
+    assistant_id = role_to_id["<|im_assistant|>"]
+    other_role_ids = {tid for r, tid in role_to_id.items() if r != "<|im_assistant|>"}
+    end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+    mask = [0] * len(ids_list)
+    i = 0
+    n = len(ids_list)
+    while i < n:
+        if ids_list[i] != assistant_id:
+            i += 1
+            continue
+        j = i + 1
+        while j < n and ids_list[j] != end_id and ids_list[j] not in other_role_ids:
+            j += 1
+        if j < n and ids_list[j] == end_id:
+            for k in range(i, j + 1):
+                mask[k] = 1
+            i = j + 1
+        else:
+            i = j
+
+    return torch.tensor(mask, dtype=torch.long)
+
+
 class OfflineSupervisedDataset(Dataset):
     """Offline dataset for supervised fine-tuning with pre-dumped hidden states.
 
@@ -105,34 +143,40 @@ def __init__(
         self,
         dumped_files,
         answer_only_loss: bool = False,
+        tokenizer=None,
     ):
         """Initialize with a list of .pt file paths."""
         super().__init__()
         self.dumped_files = dumped_files
         self.answer_only_loss = answer_only_loss
+        self.tokenizer = tokenizer
 
     def __len__(self):
         return len(self.dumped_files)
 
     def __getitem__(self, i) -> dict[str, torch.Tensor]:
-        offline_data = torch.load(self.dumped_files[i], weights_only=True)
+        try:
+            offline_data = torch.load(self.dumped_files[i], weights_only=True)
+        except Exception as e:
+            # Corrupted or partially written dump: warn and fall back to the
+            # previous sample instead of aborting the whole run.
+            print(f"Error loading {self.dumped_files[i]}: {e}, trying to load previous file")
+            return self.__getitem__(i - 1)
 
         labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID)
         labels[..., :-1] = offline_data["input_ids"][..., 1:]
 
         if self.answer_only_loss:
             if "loss_mask" not in offline_data:
-                raise ValueError(
-                    f"answer_only_loss=True requires a 'loss_mask' entry in the offline "
-                    f".pt file, but {self.dumped_files[i]} does not have one. Re-dump "
-                    f"with --answer-only-loss in compute_hidden_states_*.py."
-                )
-            loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
+                # No pre-dumped loss mask: recover the assistant spans from the token IDs.
+                loss_mask = compute_assistant_mask_kimi(self.tokenizer, offline_data["input_ids"])
+            else:
+                loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
         else:
             loss_mask = torch.ones_like(offline_data["input_ids"])
 
         ret = {
-            "input_ids": offline_data["input_ids"],
+            "input_ids": offline_data["input_ids"].to(torch.long),
             "base_model_hidden_states": offline_data["hidden_states"],
             "aux_hidden_states": offline_data["aux_hidden_states"],
             "attention_mask": torch.ones_like(offline_data["input_ids"]),
diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py
index 1760cb2072..169b62f8d4 100644
--- a/modelopt/torch/speculative/plugins/hf_dflash.py
+++ b/modelopt/torch/speculative/plugins/hf_dflash.py
@@ -160,13 +160,8 @@ def modify(self, config):
         self.dflash_config.block_size = self.dflash_block_size
 
         # Target layer IDs
-        num_target_layers = (
-            base_config.num_orig_hidden_layers
-            if self.dflash_offline
-            else base_config.num_hidden_layers
-        )
-        num_draft_layers = self.dflash_config.num_hidden_layers
-        self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers)
+        # Hardcoded for the k25 base model; bypasses build_target_layer_ids.
+        self.target_layer_ids = [1, 12, 24, 35, 47, 58]
         self.dflash_config.target_layer_ids = self.target_layer_ids
 
         # mask_token_id: validated by DFlashConfig, auto-detected from tokenizer context
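
Not part of the patch: a minimal sanity check for compute_assistant_mask_kimi,
runnable once the patch is applied. The stub tokenizer and every token ID value
below are made up for illustration; only the special-token names come from the
function itself.

    import torch

    from modelopt.torch.speculative.eagle.utils import compute_assistant_mask_kimi

    class StubTokenizer:
        # Hypothetical IDs; a real Kimi tokenizer supplies its own.
        _vocab = {"<|im_user|>": 1, "<|im_assistant|>": 2, "<|im_system|>": 3, "<|im_end|>": 4}

        def convert_tokens_to_ids(self, token):
            return self._vocab[token]

    # A user turn (1 ... 4), an assistant turn (2 ... 4), then an unmatched
    # trailing assistant marker, i.e. a generation prompt.
    ids = torch.tensor([1, 10, 11, 4, 2, 20, 21, 4, 2])
    mask = compute_assistant_mask_kimi(StubTokenizer(), ids)
    assert mask.tolist() == [0, 0, 0, 0, 1, 1, 1, 1, 0]

Only the assistant turn is marked (inclusive of its <|im_assistant|> and
<|im_end|> tokens), and the trailing generation prompt stays 0, matching the
docstring's contract.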