10 changes: 10 additions & 0 deletions examples/specdec_bench/dflash_kimi.yaml
@@ -0,0 +1,10 @@
chat_template_args:
thinking: true
engine_args:
mem_fraction_static: 0.9
speculative_num_draft_tokens: 8
# cuda_graph_max_bs: 128
speculative_dflash_draft_window_size: 4096
disable_cuda_graph: true
sampling_kwargs:
temperature: 0
7 changes: 7 additions & 0 deletions examples/specdec_bench/dflash_qwen.yaml
@@ -0,0 +1,7 @@
engine_args:
mem_fraction_static: 0.9
speculative_num_draft_tokens: 8
speculative_dflash_draft_window_size: 4096
mamba_scheduler_strategy: extra_buffer
sampling_kwargs:
temperature: 0
2 changes: 1 addition & 1 deletion examples/specdec_bench/run.py
@@ -265,7 +265,7 @@ def run_simple(args):
type=str,
required=False,
default="EAGLE3",
choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"],
choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "DFLASH", "NONE"],
help="Speculative algorithm to use",
)
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory")
93 changes: 57 additions & 36 deletions examples/specdec_bench/specdec_bench/models/sglang.py
@@ -41,46 +41,67 @@ def __init__(
speculative_algorithm = "STANDALONE"
elif speculative_algorithm == "NGRAM":
speculative_algorithm = "LOOKAHEAD"
elif speculative_algorithm == "DFLASH":
pass # SGLang native name, pass through
elif speculative_algorithm == "NONE":
speculative_algorithm = None

engine_kwargs = dict(
model_path=model_dir,
skip_tokenizer_init=True,
trust_remote_code=kwargs.get("trust_remote_code", False),
mem_fraction_static=kwargs.get("mem_fraction_static", 0.8),
disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False),
tp_size=kwargs.get("tensor_parallel_size", 1),
ep_size=kwargs.get("moe_expert_parallel_size", 1),
torch_compile_max_bs=max_concurrent_requests,
max_running_requests=max_concurrent_requests,
attention_backend=kwargs.get("attention_backend"),
enable_torch_compile=kwargs.get("enable_torch_compile", False),
cuda_graph_max_bs=max_concurrent_requests,
disable_cuda_graph=False,
)
if speculative_algorithm is not None:
# https://github.com/sgl-project/sglang/pull/3582
self.model = sgl.Engine(
model_path=model_dir,
skip_tokenizer_init=True,
trust_remote_code=kwargs.get("trust_remote_code", False),
mem_fraction_static=0.8,
disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False),
tp_size=kwargs.get("tensor_parallel_size", 1),
ep_size=kwargs.get("moe_expert_parallel_size", 1),
speculative_algorithm=speculative_algorithm,
speculative_num_steps=kwargs.get("speculative_num_steps", 3),
speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1),
speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4),
speculative_draft_model_path=kwargs.get("draft_model_dir"),
torch_compile_max_bs=max_concurrent_requests,
max_running_requests=max_concurrent_requests,
attention_backend=kwargs.get("attention_backend"),
enable_torch_compile=kwargs.get("enable_torch_compile", False),
cuda_graph_max_bs=max_concurrent_requests,
disable_cuda_graph=False,
)
else:
self.model = sgl.Engine(
model_path=model_dir,
skip_tokenizer_init=True,
trust_remote_code=kwargs.get("trust_remote_code", False),
mem_fraction_static=0.8,
disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False),
tp_size=kwargs.get("tensor_parallel_size", 1),
ep_size=kwargs.get("moe_expert_parallel_size", 1),
torch_compile_max_bs=max_concurrent_requests,
max_running_requests=max_concurrent_requests,
attention_backend=kwargs.get("attention_backend"),
enable_torch_compile=kwargs.get("enable_torch_compile", False),
cuda_graph_max_bs=max_concurrent_requests,
disable_cuda_graph=False,
)
engine_kwargs["speculative_algorithm"] = speculative_algorithm
num_draft_tokens = kwargs.get("speculative_num_draft_tokens", 4)
engine_kwargs["speculative_num_draft_tokens"] = num_draft_tokens
engine_kwargs["speculative_draft_model_path"] = kwargs.get("draft_model_dir")
if speculative_algorithm == "DFLASH":
if "speculative_dflash_draft_window_size" in kwargs:
engine_kwargs["speculative_dflash_draft_window_size"] = kwargs[
"speculative_dflash_draft_window_size"
]
print(
f"[specdec_bench] DFLASH ignores --draft_length / speculative_num_steps / "
f"speculative_eagle_topk; effective draft block = "
f"speculative_num_draft_tokens={num_draft_tokens}"
)
else:
engine_kwargs["speculative_num_steps"] = kwargs.get("speculative_num_steps", 3)
engine_kwargs["speculative_eagle_topk"] = kwargs.get("speculative_eagle_topk", 1)

# Forward any other kwargs (e.g. from runtime_params.engine_args) to
# sgl.Engine, letting yaml override the defaults set above. Skip only
# specdec_bench-internal routing keys that should never reach SGLang.
_internal_keys = frozenset({
"speculative_algorithm",
"draft_model_dir",
"speculative_num_steps",
"speculative_eagle_topk",
"speculative_num_draft_tokens",
"speculative_dflash_draft_window_size",
"tensor_parallel_size",
"moe_expert_parallel_size",
"tokenizer_path",
"use_draft_logits",
})
for _k, _v in kwargs.items():
if _k in _internal_keys:
continue
engine_kwargs[_k] = _v

self.model = sgl.Engine(**engine_kwargs)
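        # Worked example (illustrative): with the engine_args from
        # examples/specdec_bench/dflash_qwen.yaml, the routing above ends up with
        #   speculative_num_draft_tokens = 8            (mapped explicitly)
        #   speculative_dflash_draft_window_size = 4096 (only set when the algorithm is DFLASH)
        #   mem_fraction_static = 0.9                   (YAML value overrides the 0.8 default)
        #   mamba_scheduler_strategy = "extra_buffer"   (unrecognized key, forwarded verbatim)
        # in engine_kwargs before sgl.Engine(**engine_kwargs) is constructed.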

self.sampling_config = sampling_kwargs

24 changes: 23 additions & 1 deletion examples/speculative_decoding/eagle_utils.py
@@ -108,7 +108,7 @@ def make_speculative_data_module(
raise ValueError("sample_size must be -1 (use all samples) or a positive integer")
if data_args.sample_size > 0:
dumped_files = dumped_files[: data_args.sample_size]
train_dataset = OfflineSupervisedDataset(dumped_files, answer_only_loss=answer_only_loss)
    train_dataset = OfflineSupervisedDataset(
        dumped_files, answer_only_loss=answer_only_loss, tokenizer=tokenizer
    )
data_collator = EagleOfflineDataCollator(train_len=train_len)

return {
@@ -159,10 +159,14 @@ def compute_loss(self, *args, **kwargs):
self.state.training_accs = []
if not hasattr(self.state, "component_losses"):
self.state.component_losses = {"eagle": [], "preservation": []}
if not hasattr(self.state, "training_stats"):
self.state.training_stats = []
kwargs.pop("num_items_in_batch", None)
loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs)
if hasattr(outputs, "train_acc") and any(outputs.train_acc):
self.state.training_accs.append(outputs.train_acc)
if getattr(outputs, "train_stats", None):
self.state.training_stats.append(outputs.train_stats)
# Track per-component losses
for key, attr in [
("eagle", "eagle_loss"),
@@ -261,6 +265,22 @@ def on_log(self, args, state, control, **kwargs):
print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")
logs["estimated_training_ar"] = est_ar

# Aggregate dflash debug stats over the log window.
if getattr(state, "training_stats", None):
keys = set()
for s in state.training_stats:
keys.update(s.keys())
for k in keys:
vals = [s[k] for s in state.training_stats if k in s]
if not vals:
continue
if isinstance(vals[0], list):
arr = np.array(vals) # [N_steps, P]
for j, m in enumerate(arr.mean(axis=0).tolist()):
logs[f"train_stats/{k}_pos_{j}"] = float(m)
else:
logs[f"train_stats/{k}"] = float(np.mean(vals))

# log to wandb
if wandb is not None and wandb.run is not None and is_master():
if logs:
@@ -276,6 +296,8 @@ def on_log(self, args, state, control, **kwargs):
state.training_accs = []
if hasattr(state, "component_losses"):
state.component_losses = {"eagle": [], "preservation": []}
if hasattr(state, "training_stats"):
state.training_stats = []
return control

def on_step_end(self, args, state, control, **kwargs):
15 changes: 13 additions & 2 deletions examples/speculative_decoding/main.py
@@ -53,7 +53,7 @@
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
from modelopt.torch.utils import print_rank_0

torch.manual_seed(0)
torch.manual_seed(3)
mto.enable_huggingface_checkpointing()


@@ -250,6 +250,9 @@ def train():

checkpoint = training_args.resume_from_checkpoint or last_checkpoint

    # NOTE: patch for k25 dflash
    # checkpoint = None

use_offline_training = data_args.offline_data_path is not None

if checkpoint:
@@ -370,7 +373,15 @@ def train():
)

print_rank_0("Start training...")
trainer.train(resume_from_checkpoint=checkpoint)
    # trainer.train(resume_from_checkpoint=checkpoint)
    # NOTE: patch for k25 dflash: reload only the optimizer state from the checkpoint
    # and restart training at the configured learning rate.
    trainer.create_optimizer_and_scheduler(num_training_steps=training_args.max_steps)
    optimizer_path = os.path.join(checkpoint, "optimizer.pt")
    trainer.optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
    for param_group in trainer.optimizer.param_groups:
        param_group["lr"] = training_args.learning_rate
    print_rank_0(f"Loaded optimizer from {optimizer_path}")
    trainer.train()  # resume_from_checkpoint intentionally not passed (k25 dflash patch)
trainer.save_state()
trainer.save_model(training_args.output_dir)

96 changes: 81 additions & 15 deletions modelopt/torch/speculative/eagle/utils.py
@@ -35,12 +35,12 @@

"""Eagle model utils."""

from typing import Any

import torch
from torch.utils.data import Dataset
from transformers.trainer_pt_utils import LabelSmoother

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@@ -78,6 +78,51 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def compute_assistant_mask_kimi(tokenizer, input_ids):
"""Recover the assistant mask from already-tokenized Kimi chat IDs.

For every <|im_assistant|> token, locate the following <|im_middle|> and
matching <|im_end|>, and mark only the inner content span (exclusive of
both markers). This matches HF's generation-tag mask semantics: only the
assistant's actual reply tokens count, not role/separator markers.

An unmatched assistant span (interrupted by a new role marker, or a
trailing generation prompt at end of sequence) is marked from
<|im_middle|>+1 up to but not including the next role marker / EOS. If
<|im_middle|> is absent within the span, nothing is marked for it.
"""
ids_list = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids)

role_to_id = {
role: tokenizer.convert_tokens_to_ids(role)
for role in ("<|im_user|>", "<|im_assistant|>", "<|im_system|>")
}
assistant_id = role_to_id["<|im_assistant|>"]
other_role_ids = {tid for r, tid in role_to_id.items() if r != "<|im_assistant|>"}
end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
middle_id = tokenizer.convert_tokens_to_ids("<|im_middle|>")

mask = [0] * len(ids_list)
i = 0
n = len(ids_list)
while i < n:
if ids_list[i] != assistant_id:
i += 1
continue
j = i + 1
m = -1
while j < n and ids_list[j] != end_id and ids_list[j] not in other_role_ids:
if m < 0 and ids_list[j] == middle_id:
m = j
j += 1
if m >= 0:
for k in range(m + 1, j):
mask[k] = 1
i = j + 1 if (j < n and ids_list[j] == end_id) else j

return torch.tensor(mask, dtype=torch.long)
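
# Illustrative usage sketch: exercise compute_assistant_mask_kimi with a stub
# tokenizer. The integer token ids below are made up; only the special-token
# names match the Kimi chat template handled above.
class _StubKimiTokenizer:
    _ids = {
        "<|im_user|>": 1,
        "<|im_assistant|>": 2,
        "<|im_system|>": 3,
        "<|im_middle|>": 4,
        "<|im_end|>": 5,
    }

    def convert_tokens_to_ids(self, token):
        return self._ids[token]


# <|im_assistant|> <|im_middle|> tok tok <|im_end|> <|im_user|> <|im_middle|> tok
_example_ids = torch.tensor([2, 4, 100, 101, 5, 1, 4, 102])
_example_mask = compute_assistant_mask_kimi(_StubKimiTokenizer(), _example_ids)
# Only the assistant's reply tokens (ids 100 and 101) are marked.
assert _example_mask.tolist() == [0, 0, 1, 1, 0, 0, 0, 0]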


class OfflineSupervisedDataset(Dataset):
"""Offline dataset for supervised fine-tuning with pre-dumped hidden states.

@@ -105,34 +150,49 @@ def __init__(
self,
dumped_files,
answer_only_loss: bool = False,
        tokenizer=None,
):
"""Initialize with a list of .pt file paths."""
super().__init__()
self.dumped_files = dumped_files
self.answer_only_loss = answer_only_loss
self.tokenizer = tokenizer

def __len__(self):
return len(self.dumped_files)

def __getitem__(self, i) -> dict[str, torch.Tensor]:
offline_data = torch.load(self.dumped_files[i], weights_only=True)
        try:
            offline_data = torch.load(self.dumped_files[i], weights_only=True)
        except Exception as e:
            # Corrupt or unreadable dump: fall back to the previous sample.
            print(f"Error loading {self.dumped_files[i]}: {e}, trying to load previous file")
            return self.__getitem__(i - 1)

labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID)
labels[..., :-1] = offline_data["input_ids"][..., 1:]

if self.answer_only_loss:
if "loss_mask" not in offline_data:
raise ValueError(
f"answer_only_loss=True requires a 'loss_mask' entry in the offline "
f".pt file, but {self.dumped_files[i]} does not have one. Re-dump "
f"with --answer-only-loss in compute_hidden_states_*.py."
)
loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype)
            # Recover the assistant-only mask directly from the token ids using the Kimi
            # chat template; the pre-dumped loss_mask from the offline .pt file is not
            # used for these experiments.
            loss_mask = compute_assistant_mask_kimi(self.tokenizer, offline_data["input_ids"])
            ratio = loss_mask.float().mean().item()
            if ratio < 0.3:
                # Too few assistant tokens to learn from; fall back to the previous sample.
                return self.__getitem__(i - 1)
        else:
            loss_mask = torch.ones_like(offline_data["input_ids"])

ret = {
"input_ids": offline_data["input_ids"],
"input_ids": offline_data["input_ids"].to(torch.long),
"base_model_hidden_states": offline_data["hidden_states"],
"aux_hidden_states": offline_data["aux_hidden_states"],
"attention_mask": torch.ones_like(offline_data["input_ids"]),
@@ -149,14 +209,17 @@ def __init__(self, train_len):
"""Initialize with the target sequence length for truncation/padding."""
self.train_len = train_len

def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0):
"""Pad or truncate a tensor to length along a given dimension."""
    def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0, padding_token_id: int = 0):
        """Pad or truncate a tensor to length along a given dimension.

        Padded positions are filled with padding_token_id; the caller passes 163839
        for input_ids and keeps the default of 0 for everything else.
        """
dim = dim % x.ndim # support negative dimension

# allocate output tensor
        # allocate output tensor pre-filled with the padding token
        # (the caller chooses 163839 for input_ids, 0 otherwise)
out_shape = list(x.shape)
out_shape[dim] = length
out = x.new_zeros(out_shape)
out = x.new_full(out_shape, padding_token_id)

# construct copy slice
slc = [slice(None)] * x.ndim
@@ -168,9 +231,12 @@ def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0):

def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
"""Collate a list of feature dicts into a single padded/truncated batch."""
# For input_ids, use 163839 as padding_token_id
base_batch = {
k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features])
for k in ["input_ids", "attention_mask", "loss_mask", "labels"]
"input_ids": torch.stack([self._pad_or_truncate(item["input_ids"], self.train_len, padding_token_id=163839) for item in features]),
"attention_mask": torch.stack([self._pad_or_truncate(item["attention_mask"], self.train_len) for item in features]),
"loss_mask": torch.stack([self._pad_or_truncate(item["loss_mask"], self.train_len) for item in features]),
"labels": torch.stack([self._pad_or_truncate(item["labels"], self.train_len) for item in features]),
}
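        # Worked example (illustrative): with train_len=5, an input_ids tensor
        # [10, 11, 12] pads to [10, 11, 12, 163839, 163839] (163839 is assumed to be
        # the Kimi pad token id), while attention_mask / loss_mask / labels pad with 0,
        # so padded positions never contribute to attention or to the loss.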

base_model_outputs = {