4 changes: 3 additions & 1 deletion examples/speculative_decoding/eagle_utils.py
@@ -108,7 +108,9 @@ def make_speculative_data_module(
raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
if data_args.sample_size > 0:
dumped_files = dumped_files[: data_args.sample_size]
train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss)
train_dataset = OfflineSupervisedDataset(
dumped_files, answer_only_loss=answer_only_loss, tokenizer=tokenizer
)
data_collator = EagleOfflineDataCollator(train_len=train_len)

return {
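The eagle_utils.py change threads the tokenizer through make_speculative_data_module into OfflineSupervisedDataset, which now needs it to rebuild the assistant loss mask from raw token IDs (see the utils.py hunk below). A minimal construction sketch; the checkpoint name and dump directory are illustrative assumptions, not taken from this diff:

```python
from pathlib import Path

from transformers import AutoTokenizer

from modelopt.torch.speculative.eagle.utils import OfflineSupervisedDataset

# Hypothetical tokenizer checkpoint and dump directory, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("moonshotai/Kimi-K2-Instruct", trust_remote_code=True)
dumped_files = sorted(str(p) for p in Path("hidden_states_dump").glob("*.pt"))

dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=True, tokenizer=tokenizer)
print(len(dataset))  # one sample per dumped .pt file
```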
62 changes: 54 additions & 8 deletions modelopt/torch/speculative/eagle/utils.py
@@ -78,6 +78,44 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


+def compute_assistant_mask_kimi(tokenizer, input_ids):
+    """Recover the assistant mask from already-tokenized Kimi chat IDs.
+
+    For every <|im_assistant|> token, find its matching <|im_end|> and mark
+    the inclusive span. An unmatched trailing assistant marker (i.e. a
+    generation prompt at the end of full_ids) is left as 0 — this matches
+    the prefix-diff behavior in apply_chat_template_kimi.
+    """
+    ids_list = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids)
+
+    role_to_id = {
+        role: tokenizer.convert_tokens_to_ids(role)
+        for role in ("<|im_user|>", "<|im_assistant|>", "<|im_system|>")
+    }
+    assistant_id = role_to_id["<|im_assistant|>"]
+    other_role_ids = {tid for r, tid in role_to_id.items() if r != "<|im_assistant|>"}
+    end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+    mask = [0] * len(ids_list)
+    i = 0
+    n = len(ids_list)
+    while i < n:
+        if ids_list[i] != assistant_id:
+            i += 1
+            continue
+        j = i + 1
+        while j < n and ids_list[j] != end_id and ids_list[j] not in other_role_ids:
+            j += 1
+        if j < n and ids_list[j] == end_id:
+            for k in range(i, j + 1):
+                mask[k] = 1
+            i = j + 1
+        else:
+            i = j
+
+    return torch.tensor(mask, dtype=torch.long)
+
+
 class OfflineSupervisedDataset(Dataset):
     """Offline dataset for supervised fine-tuning with pre-dumped hidden states.

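To check compute_assistant_mask_kimi in isolation, here is a self-contained sketch with a stub tokenizer; the vocabulary and token IDs are invented for illustration, and only the four special-token names come from the function above:

```python
import torch

from modelopt.torch.speculative.eagle.utils import compute_assistant_mask_kimi

class StubTokenizer:
    """Maps the four Kimi special tokens to made-up IDs."""

    _vocab = {"<|im_user|>": 1, "<|im_assistant|>": 2, "<|im_system|>": 3, "<|im_end|>": 4}

    def convert_tokens_to_ids(self, token):
        return self._vocab[token]

# A user turn, a closed assistant turn, then a trailing generation prompt.
ids = torch.tensor([1, 10, 11, 4, 2, 20, 21, 4, 2, 30])
mask = compute_assistant_mask_kimi(StubTokenizer(), ids)
print(mask.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0]: only the closed assistant span is marked
```

The inclusive span covers the <|im_assistant|> marker and its <|im_end|>, while the unmatched trailing marker stays 0, matching the docstring.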
@@ -105,34 +143,42 @@ def __init__(
         self,
         dumped_files,
         answer_only_loss: bool = False,
+        tokenizer=None,
     ):
         """Initialize with a list of .pt file paths."""
         super().__init__()
         self.dumped_files = dumped_files
         self.answer_only_loss = answer_only_loss
+        self.tokenizer = tokenizer

     def __len__(self):
         return len(self.dumped_files)

     def __getitem__(self, i) -> dict[str, torch.Tensor]:
-        offline_data = torch.load(self.dumped_files[i], weights_only=True)
+        try:
+            offline_data = torch.load(self.dumped_files[i], weights_only=True)
+        except Exception as e:
+            print(f"Error loading {self.dumped_files[i]}: {e}, trying to load previous file")
+            return self.__getitem__(i - 1)

         labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID)
         labels[..., :-1] = offline_data["input_ids"][..., 1:]

         if self.answer_only_loss:
-            if "loss_mask" not in offline_data:
-                raise ValueError(
-                    f"answer_only_loss=True requires a 'loss_mask' entry in the offline "
-                    f".pt file, but {self.dumped_files[i]} does not have one. Re-dump "
-                    f"with --answer-only-loss in compute_hidden_states_*.py."
-                )
-            loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
+            loss_mask = compute_assistant_mask_kimi(self.tokenizer, offline_data["input_ids"])
+            # loss_mask = torch.ones_like(offline_data["input_ids"])
+            # raise ValueError(
+            #     f"answer_only_loss=True requires a 'loss_mask' entry in the offline "
+            #     f".pt file, but {self.dumped_files[i]} does not have one. Re-dump "
+            #     f"with --answer-only-loss in compute_hidden_states_*.py."
+            # )
+            # loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
         else:
             loss_mask = torch.ones_like(offline_data["input_ids"])

         ret = {
-            "input_ids": offline_data["input_ids"],
+            "input_ids": offline_data["input_ids"].to(torch.long),
             "base_model_hidden_states": offline_data["hidden_states"],
             "aux_hidden_states": offline_data["aux_hidden_states"],
             "attention_mask": torch.ones_like(offline_data["input_ids"]),
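For reference, the label construction in __getitem__ produces next-token targets: position t is trained to predict token t+1, and the final position is masked out. A standalone sketch, assuming IGNORE_TOKEN_ID is the usual -100 sentinel (the actual constant is defined elsewhere in the module):

```python
import torch

IGNORE_TOKEN_ID = -100  # assumed value; the module defines its own constant

input_ids = torch.tensor([5, 6, 7, 8])
labels = torch.full_like(input_ids, IGNORE_TOKEN_ID)
labels[..., :-1] = input_ids[..., 1:]  # shift left by one: position t predicts token t+1
print(labels.tolist())  # [6, 7, 8, -100]
```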
15 changes: 8 additions & 7 deletions modelopt/torch/speculative/plugins/hf_dflash.py
@@ -160,13 +160,14 @@ def modify(self, config):
         self.dflash_config.block_size = self.dflash_block_size

         # Target layer IDs
-        num_target_layers = (
-            base_config.num_orig_hidden_layers
-            if self.dflash_offline
-            else base_config.num_hidden_layers
-        )
-        num_draft_layers = self.dflash_config.num_hidden_layers
-        self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers)
+        # num_target_layers = (
+        #     base_config.num_orig_hidden_layers
+        #     if self.dflash_offline
+        #     else base_config.num_hidden_layers
+        # )
+        # num_draft_layers = self.dflash_config.num_hidden_layers
+        # self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers)
+        self.target_layer_ids = [1, 12, 24, 35, 47, 58]
         self.dflash_config.target_layer_ids = self.target_layer_ids

         # mask_token_id: validated by DFlashConfig, auto-detected from tokenizer context
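The hf_dflash.py change pins target_layer_ids to a fixed list instead of deriving it via build_target_layer_ids. The hardcoded values are consistent with spacing six layers evenly between target layers 1 and 58; the sketch below is a hypothetical reading of that intent (build_target_layer_ids itself is not shown in this diff), not its actual implementation:

```python
def spaced_layer_ids(first: int, last: int, num_draft_layers: int) -> list[int]:
    """Evenly space num_draft_layers layer indices from first to last, inclusive."""
    step = (last - first) / (num_draft_layers - 1)
    return [round(first + k * step) for k in range(num_draft_layers)]

print(spaced_layer_ids(1, 58, 6))  # [1, 12, 24, 35, 47, 58]
```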