Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions examples/speculative/eagle1/qwen3_eagle1_perfectblend.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
recipe: TrainEagle1Recipe

dist_env:
backend: nccl
timeout_minutes: 30

recipe_args:
target_model_name_or_path: Qwen/Qwen3-8B
train_data_path: /path/to/train.jsonl
val_data_path: null
train_split: null
val_split: null
output_dir: ./outputs/eagle1_qwen3_mvp
seq_length: 1024
micro_batch_size: 1
grad_accumulation_steps: 1
num_workers: 0
num_epochs: 1
draft_num_hidden_layers: 1
hidden_loss_weight: 1.0
token_loss_weight: 0.1
freeze_embeddings: true
trust_remote_code: false
shuffle_seed: 42
log_every_steps: 10
max_grad_norm: 1.0

optimizer:
lr: 1.0e-4
betas: [0.9, 0.95]
weight_decay: 0.0

checkpoint:
enabled: true
checkpoint_dir: ./outputs/eagle1_qwen3_mvp/checkpoints
model_save_format: safetensors
save_consolidated: true
37 changes: 37 additions & 0 deletions examples/speculative/eagle2/qwen3_eagle2_perfectblend.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
recipe: TrainEagle2Recipe

dist_env:
backend: nccl
timeout_minutes: 30

recipe_args:
target_model_name_or_path: Qwen/Qwen3-8B
train_data_path: /path/to/train.jsonl
val_data_path: null
train_split: null
val_split: null
output_dir: ./outputs/eagle2_qwen3_mvp
seq_length: 1024
micro_batch_size: 1
grad_accumulation_steps: 1
num_workers: 0
num_epochs: 1
draft_num_hidden_layers: 1
hidden_loss_weight: 1.0
token_loss_weight: 0.1
freeze_embeddings: true
trust_remote_code: false
shuffle_seed: 42
log_every_steps: 10
max_grad_norm: 1.0

optimizer:
lr: 1.0e-4
betas: [0.9, 0.95]
weight_decay: 0.0

checkpoint:
enabled: true
checkpoint_dir: ./outputs/eagle2_qwen3_mvp/checkpoints
model_save_format: safetensors
save_consolidated: true
36 changes: 36 additions & 0 deletions examples/speculative/eagle3/qwen3_eagle3_perfectblend.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
recipe: TrainEagle3Recipe

dist_env:
backend: nccl
timeout_minutes: 30

recipe_args:
target_model_name_or_path: Qwen/Qwen3-8B
train_data_path: /path/to/train.jsonl
val_data_path: null
train_split: null
val_split: null
output_dir: ./outputs/eagle3_qwen3_mvp
seq_length: 1024
micro_batch_size: 1
grad_accumulation_steps: 1
num_workers: 0
num_epochs: 1
ttt_steps: 4
draft_vocab_size: 8192
freeze_embeddings: true
trust_remote_code: false
shuffle_seed: 42
log_every_steps: 10
max_grad_norm: 1.0

optimizer:
lr: 1.0e-4
betas: [0.9, 0.95]
weight_decay: 0.0

checkpoint:
enabled: true
checkpoint_dir: ./outputs/eagle3_qwen3_mvp/checkpoints
model_save_format: safetensors
save_consolidated: true
12 changes: 7 additions & 5 deletions nemo_automodel/components/speculative/eagle/draft_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
decoder-only architecture whose layout matches Llama: GQA attention with
optional Q/K/V/O bias (`config.attention_bias`), SwiGLU MLP with optional
bias (`config.mlp_bias`), RMSNorm, and rotary position embeddings parameterized
by `config.rope_theta` / `config.rope_scaling`. This currently covers Llama
and Phi-3 dense (Phi-3 omits `attention_bias` / `mlp_bias`, which the
attention and MLP layers already read via
`getattr(config, "<field>", False)`).
by `config.rope_theta` / `config.rope_scaling`. This currently covers Llama,
Phi-3, and Qwen3 dense (Phi-3 omits `attention_bias` / `mlp_bias`, which
the attention and MLP layers already read via
`getattr(config, "<field>", False)`; Qwen3 decouples `head_dim` from
`hidden_size / num_attention_heads`, which the attention layer reads via
`getattr(config, "head_dim", ...)`).

Class names and the public `architectures` string remain ``LlamaEagle3*`` for
backward compatibility with already-trained checkpoints and with SGLang's
Expand Down Expand Up @@ -442,7 +444,7 @@ def __init__(self, config: PretrainedConfig):


class LlamaEagle3DraftModel(PreTrainedModel):
"""Llama-style dense EAGLE-3 draft model (Llama, Phi-3).
"""Llama-style dense EAGLE-3 draft model (Llama, Phi-3, Qwen3).

State dict keys match SGLang's ``LlamaForCausalLMEagle3`` so the saved
checkpoint can be loaded by SGLang's inference engine without any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Llama-style dense LLM draft model for EAGLE-1 / EAGLE-2 training.

Config-driven; supports Llama and Phi-3 dense via standard HF config
Config-driven; supports Llama, Phi-3, and Qwen3 dense via standard HF config
fields (``attention_bias``, ``mlp_bias``, ``rope_theta``/``rope_scaling``,
``rms_norm_eps``). Class names are retained for checkpoint-architectures
compatibility.
Expand Down Expand Up @@ -122,7 +122,9 @@ def __init__(self, config: PretrainedConfig):
self.down_proj = nn.Linear(
config.intermediate_size, config.hidden_size, bias=getattr(config, "mlp_bias", False)
)
self.act_fn = nn.SiLU()
from transformers.activations import ACT2FN

self.act_fn = ACT2FN[config.hidden_act]

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
Expand Down Expand Up @@ -158,7 +160,7 @@ def forward(
class LlamaEagleDraftModel(PreTrainedModel):
"""Llama-style dense draft that predicts next-step hidden states.

Works with Llama and Phi-3 dense configs. The class name is
Works with Llama, Phi-3, and Qwen3 dense configs. The class name is
retained for backward compatibility with already-trained checkpoints.
"""

Expand Down
1 change: 1 addition & 0 deletions nemo_automodel/components/speculative/eagle/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class DraftSpec:
_DENSE_ARCHITECTURES: tuple[str, ...] = (
"LlamaForCausalLM",
"Phi3ForCausalLM",
"Qwen3ForCausalLM",
)


Expand Down
4 changes: 2 additions & 2 deletions nemo_automodel/recipes/llm/train_eagle1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""EAGLE-1 / EAGLE-2 training recipe for Llama-style dense LLMs (Llama, Phi-3)."""
"""EAGLE-1 / EAGLE-2 training recipe for Llama-style dense LLMs (Llama, Phi-3, Qwen3)."""

from __future__ import annotations

Expand Down Expand Up @@ -62,7 +62,7 @@ def _all_reduce_mean(value: torch.Tensor) -> torch.Tensor:


class TrainEagle1Recipe(BaseRecipe):
"""Recipe for EAGLE-1 training on Llama-style dense LLMs (Llama, Phi-3)."""
"""Recipe for EAGLE-1 training on Llama-style dense LLMs (Llama, Phi-3, Qwen3)."""

def __init__(self, cfg):
self.cfg = cfg
Expand Down
4 changes: 2 additions & 2 deletions nemo_automodel/recipes/llm/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""EAGLE-3 training recipe for Llama-style dense LLMs (Llama, Phi-3)."""
"""EAGLE-3 training recipe for Llama-style dense LLMs (Llama, Phi-3, Qwen3)."""

from __future__ import annotations

Expand Down Expand Up @@ -81,7 +81,7 @@ def _all_reduce_mean(value: torch.Tensor) -> torch.Tensor:


class TrainEagle3Recipe(BaseRecipe):
"""Recipe for EAGLE-3 training on Llama-style dense LLMs (Llama, Phi-3)."""
"""Recipe for EAGLE-3 training on Llama-style dense LLMs (Llama, Phi-3, Qwen3)."""

def __init__(self, cfg):
self.cfg = cfg
Expand Down
Loading