Skip to content
Open
22 changes: 15 additions & 7 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import onnx
import torch
Expand Down Expand Up @@ -42,7 +42,6 @@
hash_dict_params,
load_json,
require_value,
to_named_specializations,
)
from QEfficient.utils.export_utils import export_wrapper

Expand Down Expand Up @@ -487,7 +486,7 @@ def _compile(
specializations: Optional[List[Dict[str, int]]] = None,
custom_io: Optional[Dict[str, str]] = None,
mdp_ts_num_devices: int = 1,
num_speculative_tokens: Optional[int] = None,
num_speculative_tokens: Optional[Union[int, List[int]]] = None,
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
use_onnx_subfunctions: bool = False,
Expand All @@ -509,7 +508,7 @@ def _compile(
:specializations (list): List of specializations to compile for
:custom_io (dict): Custom IO to specify the input and outputs in different formats than default
:mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing.
:num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
:num_speculative_tokens (int | List[int], optional): Number of speculative tokens for TLM decode. A plain int K compiles one decode specialization (seq_len=K+1). A list [K0, K1, ...] compiles one specialization per value, enabling per-step dispatch to the cheapest kernel.
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
:qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. ``Defaults to None.``
:compiler_options: Pass any compiler option as input.
Expand Down Expand Up @@ -635,9 +634,18 @@ def _compile(
# Write specializations.json file
if specializations is not None:
specializations_json = compile_dir / "specializations.json"
specializations_data = {
"specializations": to_named_specializations(specializations, module_name=specialization_module_name)
}
# Strip internal _graph_name tags and write flat format for qaic-compile.
# Named format ({"name": ..., "symbols": {...}}) is only required for the
# QNN path (already branched off above). The qaic-compile binary and its
# MDP (multi-device partition) firmware support only the flat format:
# {"batch_size": "4", "seq_len": "5", ...}
# Using named format for MDP QPCs causes a RuntimeError at ExecObj
# creation time ("Failed to create ExecObj") on 4-device tensor-parallel.
# All values must be strings — qaic-compile rejects integer values.
flat_specs = [
{key: str(val) for key, val in spec.items() if key != "_graph_name"} for spec in specializations
]
specializations_data = {"specializations": flat_specs}
create_json(str(specializations_json), specializations_data)
command.append(f"-network-specialization-config={specializations_json}")

Expand Down
124 changes: 111 additions & 13 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3370,6 +3370,58 @@ def build_decode_specialization(
result["_graph_name"] = "Decode"
return result

def _build_decode_spec_for_k(
self,
k: int,
ctx_len: int = 128,
batch_size: int = 1,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
prefill_seq_len: int = 32,
comp_ctx_lengths: Optional[int] = None,
):
"""
Builds a TLM decode specialization for proposal length *k* (``seq_len = k+1``).

Parameters
----------
k : int
Number of speculative (draft) tokens. ``seq_len`` and ``num_logits_to_keep``
are both set to ``k + 1``.
ctx_len : int, optional
Maximum context length. Default is 128.
batch_size : int, optional
Batch size. Default is 1.
kv_cache_batch_size : int, optional
Batch size for KV cache allocation.
full_batch_size : int, optional
Continuous batching full batch size.
prefill_seq_len : int, optional
Used to detect and skip duplicate specializations (when ``seq_len == prefill_seq_len``
and continuous batching is disabled).

Returns
-------
Optional[Dict[str, Union[int, str]]]
Specialization dict, or ``None`` if it would duplicate the prefill specialization.
"""
seq_len = k + 1
if seq_len == prefill_seq_len and not self.continuous_batching:
return None
spec = {
"seq_len": seq_len,
"ctx_len": ctx_len,
"num_logits_to_keep": seq_len,
}
if comp_ctx_lengths is not None:
spec["comp_ctx_lengths"] = comp_ctx_lengths
if self.continuous_batching:
spec["batch_size"] = full_batch_size
spec["full_batch_size"] = kv_cache_batch_size
else:
spec["batch_size"] = kv_cache_batch_size
return {k_: v for k_, v in spec.items() if v is not None}

def compile(
self,
onnx_path: Optional[str] = None,
Expand All @@ -3386,7 +3438,7 @@ def compile(
num_cores: int = 16, # FIXME: Make this mandatory arg
mxfp6_matmul: bool = False,
mxint8_kv_cache: bool = False,
num_speculative_tokens: Optional[int] = None,
num_speculative_tokens: Optional[List[int]] = None,
prefill_only: Optional[bool] = None,
use_onnx_subfunctions: bool = False,
offload_pt_weights: Optional[bool] = True,
Expand Down Expand Up @@ -3428,9 +3480,12 @@ def compile(
Use MXFP6 compression for weights. Default is False.
mxint8_kv_cache : bool, optional
Use MXINT8 compression for KV cache. Default is False.
num_speculative_tokens : int, optional
Number of speculative tokens for Speculative Decoding Target Language Model.
Required if the model is configured as a Target Language Model (`is_tlm=True`).
num_speculative_tokens : list[int], optional
List of proposal lengths for Speculative Decoding Target Language Model.
Each value K generates a decode specialization with seq_len=K+1 and
num_logits_to_keep=K+1. Include 0 to compile a cheap single-token fallback
(e.g. ``[0, 3]`` for a fallback + full K=3 decode). Required if the model is
configured as a Target Language Model (``is_tlm=True``).
prefill_only : bool, optional
If True, compiles only for the prefill stage. If False, compiles only for
the decode stage. If None, compiles for both stages. Default is None.
Expand Down Expand Up @@ -3465,10 +3520,10 @@ def compile(
TypeError
If `prefill_only` is not a boolean.
If `full_batch_size` is None when `continuous_batching` is True.
If `num_speculative_tokens` is None when the model is a TLM.
If `num_speculative_tokens` is None or empty when the model is a TLM.
ValueError
If KV caching is requested without continuous batching (`full_batch_size`).
If `include_sampler` is True and `num_speculative_tokens` is greater than 0.
If `include_sampler` is True and `num_speculative_tokens` contains a value > 0.
If `num_speculative_tokens` is not an integer greater than 1.
If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models.

Expand Down Expand Up @@ -3533,14 +3588,32 @@ def compile(
if prefill_only is not None and not isinstance(prefill_only, bool):
raise TypeError("`prefill_only` must be a boolean.")

_decode_ks = (
sorted(set(num_speculative_tokens))
if isinstance(num_speculative_tokens, (list, tuple))
else ([num_speculative_tokens] if num_speculative_tokens is not None else None)
)

if self.is_tlm:
num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len)
_max_k = _decode_ks[-1] if _decode_ks else None
validated_k = self.check_and_get_num_speculative_tokens(_max_k, prefill_seq_len)
if validated_k is not None and validated_k != _max_k:
# speculative_config in model.config overrides num_speculative_tokens.
# Warn if the user passed a list — the extra values are discarded.
if _decode_ks is not None and len(_decode_ks) > 1:
discarded = [k for k in _decode_ks if k != validated_k]
logger.warning(
f"speculative_config in model.config fixes num_speculative_tokens={validated_k}. "
f"Ignoring user-supplied values {discarded}. "
f"Pass num_speculative_tokens={validated_k} (or [{validated_k}]) to suppress this warning."
)
_decode_ks = [validated_k]

if (
self.model.qaic_config is not None
and self.model.qaic_config.get("include_sampler", False)
and num_speculative_tokens is not None
and num_speculative_tokens > 0
and _decode_ks is not None
and max(_decode_ks) > 0
):
raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.")

Expand Down Expand Up @@ -3577,8 +3650,33 @@ def compile(
)

if (prefill_only is None or not prefill_only) and prefill_seq_len != 1:
if self.comp_ctx_lengths_decode is not None:
# Adding elements from self.comp_ctx_lengths_decode to decode_specialization
if _decode_ks is not None and self.is_tlm:
# TLM multi-spec path: one decode specialization per K in num_speculative_tokens.
# CCL (comp_ctx_lengths) + multi-spec TLM is not yet supported: the per-K call
# to _build_decode_spec_for_k would need to iterate over CCL values, producing
# len(decode_ks) × len(comp_ctx_lengths_decode) decode specializations whose
# naming and ordering is untested. Reject early so users get a clear error
# instead of a silently wrong QPC.
if self.comp_ctx_lengths_decode is not None:
raise NotImplementedError(
"TLM multi-spec (num_speculative_tokens as a list) combined with "
"comp_ctx_lengths_decode is not yet supported. Pass a plain int for "
"num_speculative_tokens when using CCL."
)
for k in _decode_ks:
spec = self._build_decode_spec_for_k(
k=k,
ctx_len=ctx_len,
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
prefill_seq_len=prefill_seq_len,
)
if spec is not None:
specializations.append(spec)

elif self.comp_ctx_lengths_decode is not None:
# CCL loop (non-TLM)
for i in range(0, len(self.comp_ctx_lengths_decode)):
decode_spec = self.build_decode_specialization(
prefill_seq_len=prefill_seq_len,
Expand All @@ -3587,7 +3685,7 @@ def compile(
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
num_speculative_tokens=None,
)
if decode_spec:
specializations.append(decode_spec)
Expand All @@ -3599,7 +3697,7 @@ def compile(
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
num_speculative_tokens=None,
prefill_only=prefill_only,
)
if decode_spec:
Expand Down
Loading
Loading