support multiple TLM decode specializations via num_speculative_tokens list #984

Open
eplatero97 wants to merge 11 commits into quic:release/v1.22.0_tmp from eplatero97:spd_specs
@eplatero97 eplatero97 commented May 13, 2026

Description

Speculative decoding on QAIC requires statically compiled shapes. Previously, a TLM
could only be compiled for a single proposal length K, forcing every decode step to
run the full seq_len=K+1 kernel regardless of how many tokens the draft model
actually proposed. This PR allows compiling multiple decode specializations in one
QPC so the runtime can dispatch to the smallest kernel that covers the actual proposal
count, reducing unnecessary compute on short proposals without changing correctness.

Summary

  • QEFFAutoModelForCausalLM.compile() now accepts num_speculative_tokens as a
    List[int]. Each value K compiles one TLM decode specialization (seq_len=K+1,
    num_logits_to_keep=K+1), enabling per-step dispatch to the cheapest kernel that
    covers the actual proposal count.
  • Plain int input still works (backward compatible — treated as [K]).
  • Removes enable_fallback_decode_spec; equivalent behavior is [0, K].
  • Fixes flat-format specializations.json write in _compile (named format caused
    RuntimeError: Failed to create ExecObj on 4-device MDP QPCs).
  • Improves find_candidate_pred_tokens to return the longest n-gram continuation
    across all candidates instead of early-returning on the first match.
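
The n-gram lookup change in the last bullet can be illustrated with a minimal sketch. This is not the PR's `find_candidate_pred_tokens`; the function name `longest_ngram_continuation`, its parameters, and the choice to prefer the largest matching n-gram size before comparing continuation lengths are assumptions made for illustration only.

```python
from typing import List, Sequence


def longest_ngram_continuation(
    tokens: Sequence[int],
    max_ngram_size: int = 3,
    num_pred_tokens: int = 4,
) -> List[int]:
    """Hypothetical sketch: propose tokens by matching the trailing n-gram.

    For the largest n-gram size that has any earlier match, every match is
    inspected and the longest continuation is returned, rather than
    returning as soon as the first match is found.
    """
    n_tokens = len(tokens)
    for ngram_size in range(min(max_ngram_size, n_tokens - 1), 0, -1):
        suffix = list(tokens[n_tokens - ngram_size:])
        best: List[int] = []
        # Scan every earlier window; the trailing n-gram itself is excluded.
        for start in range(n_tokens - ngram_size):
            if list(tokens[start:start + ngram_size]) != suffix:
                continue
            follow = start + ngram_size
            candidate = list(tokens[follow:follow + num_pred_tokens])
            if len(candidate) > len(best):
                best = candidate
        if best:
            return best
    return []  # no earlier occurrence of any trailing n-gram


# The 2-gram [1, 2] occurs at positions 0 and 6. The later match is followed
# by only three tokens, the earlier one by four, so the earlier one wins.
history = [1, 2, 3, 4, 5, 6, 1, 2, 7, 1, 2]
proposal = longest_ngram_continuation(history)
```

A first-match strategy that scans from the most recent occurrence would propose three tokens here; considering all candidates yields a full four-token proposal.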

Results

Measured on Llama-3.1-8B-Instruct (mxfp6/mxint8, 4 SOCs), MT-bench, 80 prompts.
num_speculative_tokens=[0, 4] (K=4) vs the prior fixed-K baseline.

Request throughput vs nospec (req/s):

| max new seqs | nospec | ngram K=4 (fixed) | ngram varK | suffix K=4 (fixed) | suffix varK |
|---|---|---|---|---|---|
| 1 | 0.137 | 0.108 (−21%) | 0.282 (+106%) | 0.131 (−4%) | 0.297 (+116%) |
| 2 | 0.257 | 0.197 (−23%) | 0.467 (+82%) | 0.237 (−8%) | 0.495 (+93%) |
| 4 | 0.448 | 0.340 (−24%) | 0.640 (+43%) | 0.405 (−9%) | 0.758 (+69%) |
| 8 | 0.674 | 0.580 (−14%) | 1.224 (+82%) | 0.692 (+3%) | 1.142 (+69%) |

varK vs fixed-K improvement:

| max new seqs | ngram Δ | suffix Δ |
|---|---|---|
| 1 | +162% | +126% |
| 2 | +137% | +109% |
| 4 | +88% | +87% |
| 8 | +111% | +65% |

Key takeaways:

  • Fixed-K ngram SpD is 14–24% slower than nospec on QAIC (each step runs the full K+1 kernel even when no tokens are proposed); fixed-K suffix SpD ranges from 9% slower to 3% faster.
  • Variable-K reverses this regression: +43–116% throughput vs nospec across both methods.
  • TPOT at mns=1 drops from ~34 ms/token (nospec) to ~16 ms/token with varK (~53% reduction).

Test plan

  • pytest tests/unit_test/models/test_modeling_auto_cpu.py::TestTLMMultiSpecSpecializations -v
  • pytest tests/unit_test/transforms/test_speculative_decoding.py::TestTLMForwardExecution::test_tlm_multi_spec_logit_consistency -v
  • pytest tests/transformers/spd/test_pld_inference.py::test_multi_spec_structure tests/transformers/spd/test_pld_inference.py::test_select_k -v
  • Hardware: pytest tests/transformers/spd/test_spd_inference.py -m on_qaic -k "pld" (on QAIC device)

Notes

  • TLM + CCL (comp_ctx_lengths_decode) combination raises NotImplementedError — not yet supported.
  • speculative_config in model config overrides user-supplied Ks; a logger.warning is emitted when values are discarded.

eplatero added 10 commits May 13, 2026 12:59
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
num_speculative_tokens now accepts List[int] in QEFFAutoModelForCausalLM.compile().
Each K in the list compiles one decode specialization: seq_len=K+1,
num_logits_to_keep=K+1.  Passing a plain int still works (backward compat).
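
As a rough illustration of the mapping described above, the sketch below normalizes the argument and builds one decode specialization per K. The helper names `normalize_decode_ks` and `build_decode_specs`, and the exact set of keys in each dict, are assumptions for illustration; they are not the PR's `_build_decode_spec_for_k`, which also handles CB mode and prefill deduplication.

```python
from typing import Dict, List, Union


def normalize_decode_ks(num_speculative_tokens: Union[int, List[int]]) -> List[int]:
    """Hypothetical helper: accept an int or a list, return sorted unique Ks."""
    if isinstance(num_speculative_tokens, int):
        # Backward-compatible path: a plain int K is treated as [K].
        num_speculative_tokens = [num_speculative_tokens]
    ks = sorted(set(num_speculative_tokens))
    if not ks or ks[0] < 0:
        raise ValueError("num_speculative_tokens must contain non-negative ints")
    return ks


def build_decode_specs(ks: List[int], batch_size: int, ctx_len: int) -> List[Dict[str, int]]:
    """Hypothetical helper: one decode specialization per proposal length K."""
    return [
        {
            "batch_size": batch_size,
            "seq_len": k + 1,             # K proposed tokens plus one
            "ctx_len": ctx_len,
            "num_logits_to_keep": k + 1,  # one logit row per scored position
        }
        for k in ks
    ]


decode_ks = normalize_decode_ks([4, 0, 4])
specs = build_decode_specs(decode_ks, batch_size=1, ctx_len=2048)
```

With `[4, 0, 4]` as input, the duplicate is dropped and two specializations are produced, with `seq_len` 1 and 5 respectively.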

Key changes
-----------
* modeling_auto.py
  - _build_decode_spec_for_k(k, ...) — new private helper that builds one
    decode specialisation for proposal length K; returns None when it would
    duplicate the prefill spec.
  - compile(): accept List[int] for num_speculative_tokens; sort + deduplicate;
    loop over K values to build specialisations; validate only max-K against
    model config.
  - Fix batch_size assignment in _build_decode_spec_for_k for CB mode.
  - check_and_get_num_speculative_tokens: forward override from model config
    back to caller so _decode_ks can be adjusted.

* prompt_lookup.py
  - _select_k() — picks the cheapest specialisation that covers the batch.
  - pld_spec_decode_inference() — per-K buffer allocation; per-step dispatch.
  - Remove --enable-fallback-decode-spec flag (superseded by [0, K] list).

* test_modeling_auto_cpu.py
  - TestTLMMultiSpecSpecializations (8 tests): structure + dedup + sort +
    backward compat.
  - test_multi_spec_structure (4 parametrized): per-K spec fields.
  - test_select_k (5 parametrized): dispatch helper.
  - test_tlm_multi_spec_logit_consistency (4): causal attention invariant.
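
The dispatch rule attributed to `_select_k` (pick the cheapest specialization that covers the batch) might look roughly like the sketch below. The function name `select_k`, its signature, and the choice to raise when no compiled K is large enough are assumptions; the real helper operates on the runtime's own arrays.

```python
from bisect import bisect_left
from typing import Sequence


def select_k(decode_ks: Sequence[int], actual_proposals: Sequence[int]) -> int:
    """Hypothetical sketch of per-step K dispatch.

    decode_ks must be sorted ascending. actual_proposals holds the number of
    tokens proposed for each still-active batch item.
    """
    if len(actual_proposals) == 0:
        # No active items: fall back to the largest compiled K.
        return decode_ks[-1]
    needed = max(actual_proposals)
    idx = bisect_left(decode_ks, needed)  # first compiled K >= needed
    if idx == len(decode_ks):
        raise ValueError(f"no compiled K covers {needed} proposed tokens")
    return decode_ks[idx]


# With specializations for K in [0, 2, 4]:
k_none = select_k([0, 2, 4], [0, 0])  # nothing proposed
k_mid = select_k([0, 2, 4], [1, 2])   # largest proposal is 2
k_max = select_k([0, 2, 4], [0, 3])   # 3 needs the K=4 kernel
```

The whole batch runs on a single kernel per step, so the selection is driven by the largest proposal among active items rather than by each item individually.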

Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
The MDP (multi-device partition) firmware for 4-device tensor-parallel only
accepts flat specialization JSON:
  {"batch_size": "4", "seq_len": "5", "ctx_len": "2048", ...}

Previously _compile() called to_named_specializations() which wraps each entry
as {"name": "Prefill_0", "symbols": {...}}.  This named format is only required
by the QNN compiler path (handled separately before reaching this code).  Using
named format for qaic-compile with MDP produces a binary that fails at ExecObj
creation time with "Failed to create ExecObj".

Fix: strip the internal _graph_name tag and write the flat list directly.
The QNN path branches off early (line 548) and never reaches this block, so it
is not affected.

Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
qaic-compile requires all values in specializations.json to be strings
(e.g. "4" not 4).  The to_named_specializations() call that was removed in
the previous commit was implicitly performing this str() conversion as part
of building the {name, symbols} wrapper.  Without it, integer values from
_build_decode_spec_for_k and build_prefill_specialization were written
directly to JSON, causing qaic-compile to exit with code 255.

Fix: apply str() to every value while stripping _graph_name tags.
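
The combined effect of this fix and the previous one can be sketched as follows. The function name `to_flat_specializations` is hypothetical, and the top-level layout of `specializations.json` expected by qaic-compile is not reproduced here; only the per-entry transformation (drop `_graph_name`, stringify values) is taken from the commit messages.

```python
import json
from typing import Any, Dict, List


def to_flat_specializations(specs: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Hypothetical sketch: strip the internal tag and stringify every value."""
    return [
        {key: str(val) for key, val in spec.items() if key != "_graph_name"}
        for spec in specs
    ]


internal_specs = [
    {"_graph_name": "Prefill_0", "batch_size": 4, "seq_len": 128, "ctx_len": 2048},
    {"_graph_name": "Decode_0", "batch_size": 4, "seq_len": 5, "ctx_len": 2048},
]
flat_specs = to_flat_specializations(internal_specs)
print(json.dumps(flat_specs, indent=2))
```

The input dicts are left unmodified, so a caller that still needs the named format (for example, on the QNN path) can reuse them.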

Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Add two new "Errors Encountered" entries (6 & 7) documenting the named-format
MDP firmware rejection and integer-value rejection found during vLLM integration.

Add new "modeling_qeff.py — Flat-Format specializations.json" section explaining:
- what changed in _compile() and why
- the _graph_name internal tag mechanism
- why to_named_specializations() is still needed (QNN path)
- the correct 3-spec output format
- commit references (5f1b6c5, 82f2d6c)

Also add modeling_qeff.py to the Files Changed table.

Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Addresses all 5 required fixes from code review:

Critical
--------
Fix 1 (prompt_lookup.py): Guard _select_k against empty actual_proposals array.
  When all batch items finish generating, valid_batch_indices is all-False and
  actual_proposals[valid_batch_indices] is empty; .max() raised ValueError.
  Return decode_ks[-1] (max K) for empty input.

Fix 2 (modeling_auto.py): Emit logger.warning when speculative_config in
  model.config overrides a user-supplied list, naming the discarded values.
  Previously the list was silently truncated to [validated_k].

Major
-----
Fix 3 (modeling_auto.py): Raise NotImplementedError when TLM multi-spec and
  comp_ctx_lengths_decode are both active. The combination would compile
  specializations missing comp_ctx_lengths and fail at runtime on CCL hardware.
  Option B (early rejection) is safer than silently wrong QPCs.

Fix 4 (prompt_lookup.py): Wrap smaller-K session.run() in try/finally to
  restore the logit buffer placeholder. Without this, an exception leaves the
  session pointing at an undersized buffer; the next max-K path would write
  max_k+1 rows into a buffer sized selected_k+1.

Fix 5 (modeling_qeff.py): Update _compile() type annotation from
  Optional[int] to Optional[Union[int, List[int]]] and update docstring.

Nits
----
Nit 6: Rename k/v → key/val in flat_specs dict-comprehension (modeling_qeff.py).
Nit 7: Remove unused **kwargs from _build_decode_spec_for_k (modeling_auto.py).
Bonus: Remove dead `has_empty_tokens = False` initializer (prompt_lookup.py).
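
The buffer-restore pattern from Fix 4 can be sketched with a stand-in session. `FakeSession` is a hypothetical stub, not the QAIC inference session; only the method names `set_buffers` and `run` come from this PR's text, and their signatures here are assumed.

```python
from typing import Any, Dict, List


class FakeSession:
    """Hypothetical stand-in for an inference session with output buffers."""

    def __init__(self) -> None:
        self.buffers: Dict[str, Any] = {}

    def set_buffers(self, buffers: Dict[str, Any]) -> None:
        self.buffers.update(buffers)

    def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if inputs.get("fail"):
            raise RuntimeError("simulated device error")
        return {"logits": self.buffers["logits"]}


def run_smaller_k(
    session: FakeSession,
    inputs: Dict[str, Any],
    small_logits: List[float],
    max_logits: List[float],
) -> Dict[str, Any]:
    """Run with a buffer sized for the selected K, then always restore."""
    session.set_buffers({"logits": small_logits})
    try:
        return session.run(inputs)
    finally:
        # Restore the max-K placeholder even if run() raised, so a later
        # max-K step never writes into the undersized buffer.
        session.set_buffers({"logits": max_logits})


session = FakeSession()
max_buf = [0.0] * 5    # sized for max_k + 1 rows
small_buf = [0.0] * 3  # sized for selected_k + 1 rows
try:
    run_smaller_k(session, {"fail": True}, small_buf, max_buf)
except RuntimeError:
    pass
restored = session.buffers["logits"] is max_buf
```

Placing the restore in `finally` rather than after the call is what closes the state leak: the session ends every step pointing at the max-K buffer regardless of how `run()` exits.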

Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
- Files Changed table: annotate review-fix changes across all 3 files
- _select_k description: add empty-batch safety note
- New "Post-Review Fixes" section (errors 8-12):
  8. _select_k crash on empty batch (Critical)
  9. Silent discard of user Ks when speculative_config overrides (Critical)
  10. TLM + CCL produces wrong specializations (Major)
  11. set_buffers state leak on exception (Major)
  12. _compile() typed Optional[int] instead of Optional[Union[int,List[int]]] (Major)
- modeling_qeff.py Commits table: add all 5 commits on spd_specs

Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
The class reimplemented the dispatch loop inline (_run_dispatch) rather
than calling pld_spec_decode_inference directly, so it gave false
confidence: a change to the real function would leave all six tests
passing.  Coverage of the dispatch path is provided by the on-qaic
test_multi_spec_qpc_logit_correctness; _select_k is still tested
directly via test_select_k.

Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
@quic-rishinr quic-rishinr added the 1.22 Release 1.22 candidate label May 25, 2026
@quic-rishinr quic-rishinr changed the base branch from main to release/v1.22.0_tmp May 25, 2026 16:45