support multiple TLM decode specializations via num_speculative_tokens list #984
Open
eplatero97 wants to merge 11 commits into
Open
support multiple TLM decode specializations via num_speculative_tokens list #984eplatero97 wants to merge 11 commits into
eplatero97 wants to merge 11 commits into
Conversation
added 10 commits
May 13, 2026 12:59
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
num_speculative_tokens now accepts List[int] in QEFFAutoModelForCausalLM.compile().
Each K in the list compiles one decode specialization: seq_len=K+1,
num_logits_to_keep=K+1. Passing a plain int still works (backward compat).
Key changes
-----------
* modeling_auto.py
- _build_decode_spec_for_k(k, ...) — new private helper that builds one
decode specialisation for proposal length K; returns None when it would
duplicate the prefill spec.
- compile(): accept List[int] for num_speculative_tokens; sort + deduplicate;
loop over K values to build specialisations; validate only max-K against
model config.
- Fix batch_size assignment in _build_decode_spec_for_k for CB mode.
- check_and_get_num_speculative_tokens: forward override from model config
back to caller so _decode_ks can be adjusted.
* prompt_lookup.py
- _select_k() — picks the cheapest specialisation that covers the batch.
- pld_spec_decode_inference() — per-K buffer allocation; per-step dispatch.
- Remove --enable-fallback-decode-spec flag (superseded by [0, K] list).
* test_modeling_auto_cpu.py
- TestTLMMultiSpecSpecializations (8 tests): structure + dedup + sort +
backward compat.
- test_multi_spec_structure (4 parametrized): per-K spec fields.
- test_select_k (5 parametrized): dispatch helper.
- test_tlm_multi_spec_logit_consistency (4): causal attention invariant.
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
The MDP (multi-device partition) firmware for 4-device tensor-parallel only
accepts flat specialization JSON:
{"batch_size": "4", "seq_len": "5", "ctx_len": "2048", ...}
Previously _compile() called to_named_specializations() which wraps each entry
as {"name": "Prefill_0", "symbols": {...}}. This named format is only required
by the QNN compiler path (handled separately before reaching this code). Using
named format for qaic-compile with MDP produces a binary that fails at ExecObj
creation time with "Failed to create ExecObj".
Fix: strip the internal _graph_name tag and write the flat list directly.
The QNN path branches off early (line 548) and never reaches this block, so it
is not affected.
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
qaic-compile requires all values in specializations.json to be strings
(e.g. "4" not 4). The to_named_specializations() call that was removed in
the previous commit was implicitly performing this str() conversion as part
of building the {name, symbols} wrapper. Without it, integer values from
_build_decode_spec_for_k and build_prefill_specialization were written
directly to JSON, causing qaic-compile to exit with code 255.
Fix: apply str() to every value while stripping _graph_name tags.
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Add two new "Errors Encountered" entries (6 & 7) documenting the named-format MDP firmware rejection and integer-value rejection found during vLLM integration. Add new "modeling_qeff.py — Flat-Format specializations.json" section explaining: - what changed in _compile() and why - the _graph_name internal tag mechanism - why to_named_specializations() is still needed (QNN path) - the correct 3-spec output format - commit references (5f1b6c5, 82f2d6c) Also add modeling_qeff.py to the Files Changed table. Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Addresses all 5 required fixes from code review: Critical -------- Fix 1 (prompt_lookup.py): Guard _select_k against empty actual_proposals array. When all batch items finish generating, valid_batch_indices is all-False and actual_proposals[valid_batch_indices] is empty; .max() raised ValueError. Return decode_ks[-1] (max K) for empty input. Fix 2 (modeling_auto.py): Emit logger.warning when speculative_config in model.config overrides a user-supplied list, naming the discarded values. Previously the list was silently truncated to [validated_k]. Major ----- Fix 3 (modeling_auto.py): Raise NotImplementedError when TLM multi-spec and comp_ctx_lengths_decode are both active. The combination would compile specializations missing comp_ctx_lengths and fail at runtime on CCL hardware. Option B (early rejection) is safer than silently wrong QPCs. Fix 4 (prompt_lookup.py): Wrap smaller-K session.run() in try/finally to restore the logit buffer placeholder. Without this, an exception leaves the session pointing at an undersized buffer; the next max-K path would write max_k+1 rows into a buffer sized selected_k+1. Fix 5 (modeling_qeff.py): Update _compile() type annotation from Optional[int] to Optional[Union[int, List[int]]] and update docstring. Nits ---- Nit 6: Rename k/v → key/val in flat_specs dict-comprehension (modeling_qeff.py). Nit 7: Remove unused **kwargs from _build_decode_spec_for_k (modeling_auto.py). Bonus: Remove dead `has_empty_tokens = False` initializer (prompt_lookup.py). Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
- Files Changed table: annotate review-fix changes across all 3 files - _select_k description: add empty-batch safety note - New "Post-Review Fixes" section (errors 8-12): 8. _select_k crash on empty batch (Critical) 9. Silent discard of user Ks when speculative_config overrides (Critical) 10. TLM + CCL produces wrong specializations (Major) 11. set_buffers state leak on exception (Major) 12. _compile() typed Optional[int] instead of Optional[Union[int,List[int]]] (Major) - modeling_qeff.py Commits table: add all 5 commits on spd_specs Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
The class reimplemented the dispatch loop inline (_run_dispatch) rather than calling pld_spec_decode_inference directly, so it gave false confidence: a change to the real function would leave all six tests passing. Coverage of the dispatch path is provided by the on-qaic test_multi_spec_qpc_logit_correctness; _select_k is still tested directly via test_select_k. Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
Signed-off-by: eplatero <eplatero@qti.qualcomm.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Speculative decoding on QAIC requires statically compiled shapes. Previously, a TLM
could only be compiled for a single proposal length K, forcing every decode step to
run the full
seq_len=K+1kernel regardless of how many tokens the draft modelactually proposed. This PR allows compiling multiple decode specializations in one
QPC so the runtime can dispatch to the smallest kernel that covers the actual proposal
count, reducing unnecessary compute on short proposals without changing correctness.
Summary
QEFFAutoModelForCausalLM.compile()now acceptsnum_speculative_tokensas aList[int]. Each value K compiles one TLM decode specialization (seq_len=K+1,num_logits_to_keep=K+1), enabling per-step dispatch to the cheapest kernel thatcovers the actual proposal count.
intinput still works (backward compatible — treated as[K]).enable_fallback_decode_spec; equivalent behavior is[0, K].specializations.jsonwrite in_compile(named format causedRuntimeError: Failed to create ExecObjon 4-device MDP QPCs).find_candidate_pred_tokensto return the longest n-gram continuationacross all candidates instead of early-returning on the first match.
Results
Measured on Llama-3.1-8B-Instruct (mxfp6/mxint8, 4 SOCs), MT-bench, 80 prompts.
num_speculative_tokens=[0, 4](K=4) vs the prior fixed-K baseline.Request throughput vs nospec (req/s):
varK vs fixed-K improvement:
Key takeaways:
Test plan
pytest tests/unit_test/models/test_modeling_auto_cpu.py::TestTLMMultiSpecSpecializations -vpytest tests/unit_test/transforms/test_speculative_decoding.py::TestTLMForwardExecution::test_tlm_multi_spec_logit_consistency -vpytest tests/transformers/spd/test_pld_inference.py::test_multi_spec_structure tests/transformers/spd/test_pld_inference.py::test_select_k -vpytest tests/transformers/spd/test_spd_inference.py -m on_qaic -k "pld"(on QAIC device)Notes
comp_ctx_lengths_decode) combination raisesNotImplementedError— not yet supported.speculative_configin model config overrides user-supplied Ks; alogger.warningis emitted when values are discarded.