From 18cf2261cb444d915db6fadaedf679a53b6fd3c2 Mon Sep 17 00:00:00 2001 From: eplatero Date: Thu, 7 May 2026 11:21:53 -0500 Subject: [PATCH 01/12] Add variable seq-len to SpD methods Signed-off-by: eplatero --- .../transformers/models/modeling_auto.py | 103 +++++++++++-- .../speculative_decoding/prompt_lookup.py | 71 +++++++-- tests/transformers/spd/test_pld_inference.py | 71 +++++++++ .../models/test_modeling_auto_cpu.py | 137 +++++++++++++++++- 4 files changed, 351 insertions(+), 31 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4ad56592fb..94510686f5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3370,6 +3370,59 @@ def build_decode_specialization( result["_graph_name"] = "Decode" return result + def _build_decode_spec_for_k( + self, + k: int, + ctx_len: int = 128, + batch_size: int = 1, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, + prefill_seq_len: int = 32, + comp_ctx_lengths: Optional[int] = None, + **kwargs, + ): + """ + Builds a TLM decode specialization for proposal length *k* (``seq_len = k+1``). + + Parameters + ---------- + k : int + Number of speculative (draft) tokens. ``seq_len`` and ``num_logits_to_keep`` + are both set to ``k + 1``. + ctx_len : int, optional + Maximum context length. Default is 128. + batch_size : int, optional + Batch size. Default is 1. + kv_cache_batch_size : int, optional + Batch size for KV cache allocation. + full_batch_size : int, optional + Continuous batching full batch size. + prefill_seq_len : int, optional + Used to detect and skip duplicate specializations (when ``seq_len == prefill_seq_len`` + and continuous batching is disabled). + + Returns + ------- + Optional[Dict[str, Union[int, str]]] + Specialization dict, or ``None`` if it would duplicate the prefill specialization. 
+ """ + seq_len = k + 1 + if seq_len == prefill_seq_len and not self.continuous_batching: + return None + spec = { + "batch_size": full_batch_size if self.continuous_batching else batch_size, + "seq_len": seq_len, + "ctx_len": ctx_len, + "num_logits_to_keep": seq_len, + } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + if self.continuous_batching: + spec["full_batch_size"] = kv_cache_batch_size + else: + spec["batch_size"] = kv_cache_batch_size + return {k_: v for k_, v in spec.items() if v is not None} + def compile( self, onnx_path: Optional[str] = None, @@ -3386,7 +3439,7 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, - num_speculative_tokens: Optional[int] = None, + num_speculative_tokens: Optional[List[int]] = None, prefill_only: Optional[bool] = None, use_onnx_subfunctions: bool = False, offload_pt_weights: Optional[bool] = True, @@ -3428,9 +3481,12 @@ def compile( Use MXFP6 compression for weights. Default is False. mxint8_kv_cache : bool, optional Use MXINT8 compression for KV cache. Default is False. - num_speculative_tokens : int, optional - Number of speculative tokens for Speculative Decoding Target Language Model. - Required if the model is configured as a Target Language Model (`is_tlm=True`). + num_speculative_tokens : list[int], optional + List of proposal lengths for Speculative Decoding Target Language Model. + Each value K generates a decode specialization with seq_len=K+1 and + num_logits_to_keep=K+1. Include 0 to compile a cheap single-token fallback + (e.g. ``[0, 3]`` for a fallback + full K=3 decode). Required if the model is + configured as a Target Language Model (``is_tlm=True``). prefill_only : bool, optional If True, compiles only for the prefill stage. If False, compiles only for the decode stage. If None, compiles for both stages. Default is None. 
@@ -3465,10 +3521,10 @@ def compile( TypeError If `prefill_only` is not a boolean. If `full_batch_size` is None when `continuous_batching` is True. - If `num_speculative_tokens` is None when the model is a TLM. + If `num_speculative_tokens` is None or empty when the model is a TLM. ValueError If KV caching is requested without continuous batching (`full_batch_size`). - If `include_sampler` is True and `num_speculative_tokens` is greater than 0. + If `include_sampler` is True and `num_speculative_tokens` contains a value > 0. If `num_speculative_tokens` is not an integer greater than 1. If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. @@ -3533,14 +3589,21 @@ def compile( if prefill_only is not None and not isinstance(prefill_only, bool): raise TypeError("`prefill_only` must be a boolean.") + _decode_ks = ( + sorted(set(num_speculative_tokens)) + if isinstance(num_speculative_tokens, (list, tuple)) + else ([num_speculative_tokens] if num_speculative_tokens is not None else None) + ) + if self.is_tlm: - num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len) + _max_k = _decode_ks[-1] if _decode_ks else None + self.check_and_get_num_speculative_tokens(_max_k, prefill_seq_len) if ( self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False) - and num_speculative_tokens is not None - and num_speculative_tokens > 0 + and _decode_ks is not None + and max(_decode_ks) > 0 ): raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.") @@ -3577,8 +3640,22 @@ def compile( ) if (prefill_only is None or not prefill_only) and prefill_seq_len != 1: - if self.comp_ctx_lengths_decode is not None: - # Adding elements from self.comp_ctx_lengths_decode to decode_specialization + if _decode_ks is not None and self.is_tlm: + # TLM multi-spec path: one decode specialization per K in num_speculative_tokens + for k in _decode_ks: + spec = 
self._build_decode_spec_for_k( + k=k, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + prefill_seq_len=prefill_seq_len, + ) + if spec is not None: + specializations.append(spec) + + elif self.comp_ctx_lengths_decode is not None: + # CCL loop (non-TLM) for i in range(0, len(self.comp_ctx_lengths_decode)): decode_spec = self.build_decode_specialization( prefill_seq_len=prefill_seq_len, @@ -3587,7 +3664,7 @@ def compile( batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, + num_speculative_tokens=None, ) if decode_spec: specializations.append(decode_spec) @@ -3599,7 +3676,7 @@ def compile( batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, + num_speculative_tokens=None, prefill_only=prefill_only, ) if decode_spec: diff --git a/examples/performance/speculative_decoding/prompt_lookup.py b/examples/performance/speculative_decoding/prompt_lookup.py index 53b1f4e851..eb8d4fbe33 100644 --- a/examples/performance/speculative_decoding/prompt_lookup.py +++ b/examples/performance/speculative_decoding/prompt_lookup.py @@ -196,9 +196,18 @@ def find_candidate_pred_tokens( return np.full(num_pred_tokens, fill_tok, dtype=np.int64), has_empty_tokens +def _select_k(actual_proposals: np.ndarray, decode_ks: List[int]) -> int: + """Return the smallest K in decode_ks that covers the maximum proposal count in the batch.""" + need = int(actual_proposals.max()) + for k in decode_ks: + if k >= need: + return k + return decode_ks[-1] + + def pld_spec_decode_inference( prompts: List[str], - num_speculative_tokens: int, + num_speculative_tokens: Union[int, List[int]], prefill_seq_len: int, ctx_len: int, prefill_bsz: int, @@ -212,7 +221,10 @@ def pld_spec_decode_inference( Args: prompts (List[str]): List of prompts to perform inference 
on. - num_speculative_tokens (int): Number of speculative tokens. + num_speculative_tokens (Union[int, List[int]]): Number of speculative tokens, or a list + of proposal lengths to compile specializations for. Each value K generates a + specialization with seq_len=K+1. Include 0 for a cheap single-token fallback + (e.g. [0, 3]). A plain int is treated as a single-element list. prefill_seq_len (int): Prefill sequence length. ctx_len (int): Context length. prefill_bsz (int): Prefill batch size. @@ -224,6 +236,8 @@ def pld_spec_decode_inference( Returns: SpDCloudAI100ExecInfo: Execution information, including performance metrics and generated text. """ + decode_ks = sorted(set(num_speculative_tokens)) if isinstance(num_speculative_tokens, list) else [num_speculative_tokens] + max_k = decode_ks[-1] # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size # get vocab size tokenizer = AutoTokenizer.from_pretrained(target_model_name, padding_side="right") @@ -245,7 +259,7 @@ def pld_spec_decode_inference( ctx_len=ctx_len, aic_enable_depth_first=True, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, + num_speculative_tokens=decode_ks, ) # init qaic session target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group) @@ -278,16 +292,18 @@ def pld_spec_decode_inference( # run prefill on both draft and target models # mock input key "logits" to store the first batch of output logits tlm_precode_inputs = dict( - input_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), - position_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), + input_ids=np.zeros((decode_batch_size, max_k + 1), dtype=np.int64), + position_ids=np.zeros((decode_batch_size, max_k + 1), dtype=np.int64), batch_index=np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1), - num_logits_to_keep=np.arange(num_speculative_tokens + 1, 
dtype=np.int64).reshape(-1, 1), + num_logits_to_keep=np.arange(max_k + 1, dtype=np.int64).reshape(-1, 1), ) - num_logits_to_keep = num_speculative_tokens + 1 + num_logits_to_keep = max_k + 1 max_gen_len = [ctx_len] * decode_batch_size # setup buffers tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) precode_logits_ph = np.zeros((decode_batch_size, num_logits_to_keep, vocab_size), dtype=np.float32) + # Pre-allocate per-K logit buffers for smaller specializations + logit_buffers = {k: np.zeros((decode_batch_size, k + 1, vocab_size), dtype=np.float32) for k in decode_ks if k != max_k} target_model_session.set_buffers({"logits": tlm_prefill_logits_ph}) e2e_start = perf_counter() @@ -311,7 +327,7 @@ def pld_spec_decode_inference( tlm_precode_inputs["input_ids"][bi, 0] = input_ids.item() input_len = prompts_tokenized[bi]["position_ids"].max(1).item() + 1 tlm_precode_inputs["position_ids"][bi] = np.arange( - input_len, input_len + num_speculative_tokens + 1, dtype=np.int64 + input_len, input_len + max_k + 1, dtype=np.int64 ) # assumes that prefill queue will always be popped from the front input_lengths[bi] = input_len @@ -329,7 +345,7 @@ def pld_spec_decode_inference( decode_start = perf_counter() mean_num_accepted_tokens = 0 all_accept = np.full(decode_batch_size, False, dtype=bool) - tlm_position_ids = np.arange(num_speculative_tokens + 1).reshape(1, -1).repeat(decode_batch_size, axis=0) + tlm_position_ids = np.arange(max_k + 1).reshape(1, -1).repeat(decode_batch_size, axis=0) empty_indices = np.zeros(decode_batch_size, dtype=bool) decode_draft_time = 0.0 decode_target_time = 0.0 @@ -347,7 +363,7 @@ def pld_spec_decode_inference( all_ids[bi : bi + 1, : prompt_plus_gen_idx[bi]], fill_tok=-1, max_ngram_size=max_ngram_size, - num_pred_tokens=num_speculative_tokens, + num_pred_tokens=max_k, ) empty_indices[bi] = has_empty_tokens # prepare target model inputs @@ -360,15 +376,33 @@ def pld_spec_decode_inference( decode_draft_time += draft_end # 
run precode on TLM to score the proposed tokens target_start = perf_counter() - tlm_outputs = target_model_session.run(tlm_precode_inputs) - target_logits = tlm_outputs["logits"] + # Pick the smallest K specialization that covers the actual proposal count + actual_proposals = np.where(empty_indices, 0, max_k) + selected_k = _select_k(actual_proposals[valid_batch_indices], decode_ks) + if selected_k == max_k: + tlm_outputs = target_model_session.run(tlm_precode_inputs) + target_logits = tlm_outputs["logits"] + else: + sel_inputs = { + "input_ids": tlm_precode_inputs["input_ids"][:, : selected_k + 1], + "position_ids": tlm_precode_inputs["position_ids"][:, : selected_k + 1], + "batch_index": tlm_precode_inputs["batch_index"], + "num_logits_to_keep": np.arange(selected_k + 1, dtype=np.int64).reshape(-1, 1), + } + target_model_session.set_buffers({"logits": logit_buffers[selected_k]}) + tlm_outputs = target_model_session.run(sel_inputs) + raw_logits = tlm_outputs["logits"] # [batch, selected_k+1, vocab] + target_model_session.set_buffers({"logits": precode_logits_ph}) + # Pad to [batch, max_k+1] so downstream acceptance logic is unchanged + pad = np.zeros((decode_batch_size, max_k - selected_k, vocab_size), dtype=np.float32) + target_logits = np.concatenate([raw_logits, pad], axis=1) # greedy sampling from target model target_tokens = target_logits.argmax(-1) target_end = perf_counter() - target_start decode_target_time += target_end # exact matching between draft and target tokens num_tokens_selected = np.ones(decode_batch_size, dtype=np.int64) - tlm_precode_position_ids = np.full((decode_batch_size, num_speculative_tokens + 1), -1, dtype=np.int64) + tlm_precode_position_ids = np.full((decode_batch_size, max_k + 1), -1, dtype=np.int64) non_empty_valid_indices = ~empty_indices & valid_batch_indices matching = ( tlm_precode_inputs["input_ids"][non_empty_valid_indices, 1:] == target_tokens[non_empty_valid_indices, :-1] @@ -383,7 +417,7 @@ def pld_spec_decode_inference( 
non_empty_valid_indices ] + num_tokens_selected[non_empty_valid_indices].reshape(-1, 1) # record accepted tokens - all_accept[valid_batch_indices] = num_tokens_selected[valid_batch_indices] == num_speculative_tokens + 1 + all_accept[valid_batch_indices] = num_tokens_selected[valid_batch_indices] == max_k + 1 mean_num_accepted_tokens += num_tokens_selected[valid_batch_indices].mean().item() # append selected tokens to the generated_ids for bi, valid in enumerate(valid_batch_indices): @@ -439,7 +473,7 @@ def pld_spec_decode_inference( batch_decode, generated_ids, perf_metrics, - num_speculative_tokens, + max_k, prefill_seq_len, ctx_len, prefill_bsz, @@ -457,7 +491,12 @@ def comma_separated_ints(x: str): def arg_parse(): parser = ArgumentParser(description="Draft-based SpD Inference") parser.add_argument("--prompts", action="append", default=None, help="Input prompt(s)") - parser.add_argument("--num-speculative-tokens", type=int, default=3, help="Number of speculative tokens") + parser.add_argument( + "--num-speculative-tokens", + type=comma_separated_ints, + default="3", + help="Comma-separated list of proposal lengths (e.g. '0,3' or '3'). 
Each value K compiles a specialization with seq_len=K+1.", + ) parser.add_argument("--prefill-seq-len", type=int, default=256, help="Prefill sequence length") parser.add_argument("--ctx-len", type=int, default=1024, help="Context length") parser.add_argument("--prefill-bsz", type=int, default=1, help="Prefill batch size") diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index 28428394c2..3107d3c1e0 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -477,3 +477,74 @@ def test_dummy_pld_inference(model_id, manual_cleanup): model_config_dict[model_id]["target_model_name"], **model_config_dict[model_id]["additional_params"] ) check_pld_spec_decode_inference(model_id, config=hf_config, manual_cleanup=manual_cleanup) + + +@pytest.mark.parametrize("model_id", test_models_id) +@pytest.mark.parametrize("decode_ks", [[3], [0, 3], [1, 2, 3], [0, 1, 2, 3]]) +def test_multi_spec_structure(model_id, decode_ks): + """ + Verify that _build_decode_spec_for_k produces correct specializations for each K value. + No hardware required. 
+ """ + target_model_name = model_config_dict[model_id]["target_model_name"] + prefill_seq_len = model_config_dict[model_id]["prefill_seq_len"] + ctx_len = model_config_dict[model_id]["ctx_len"] + full_batch_size = model_config_dict[model_id]["full_batch_size"] + continuous_batching = full_batch_size is not None + + target_model = load_qeff_causal_lm_model( + target_model_name, + num_hidden_layers=2, + continuous_batching=continuous_batching, + qaic_config={"speculative_model_type": "target"}, + ) + + kv_cache_batch_size = full_batch_size or 1 + batch_size = 1 + + specs = [] + for k in sorted(set(decode_ks)): + spec = target_model._build_decode_spec_for_k( + k=k, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + prefill_seq_len=prefill_seq_len, + ) + assert spec is not None, f"_build_decode_spec_for_k returned None for k={k}" + assert spec["seq_len"] == k + 1, f"Expected seq_len={k+1}, got {spec['seq_len']}" + assert spec["num_logits_to_keep"] == k + 1, f"Expected num_logits_to_keep={k+1}, got {spec['num_logits_to_keep']}" + assert spec["ctx_len"] == ctx_len + specs.append(spec) + + seq_lens = [s["seq_len"] for s in specs] + assert len(seq_lens) == len(set(seq_lens)), f"Duplicate seq_len values in specs: {seq_lens}" + + +# --------------------------------------------------------------------------- +# _select_k dispatch helper tests (no hardware required) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "actual_proposals,decode_ks,expected_k", + [ + # All batch items have 0 proposals → smallest k >= 0 + (np.array([0, 0, 0]), [0, 3], 0), + # Mix: some proposals, some not → smallest k >= max=3 + (np.array([0, 3, 3]), [0, 3], 3), + # Single spec: always returns only option + (np.array([0, 0]), [3], 3), + # need=2, ks=[0,2,4] → returns 2 + (np.array([1, 2]), [0, 2, 4], 2), + # need exceeds all → returns max + (np.array([5, 5]), [0, 3], 
3), + ], +) +def test_select_k(actual_proposals, decode_ks, expected_k): + """_select_k returns the smallest K in decode_ks covering the max actual proposal count.""" + from examples.performance.speculative_decoding.prompt_lookup import _select_k + + result = _select_k(actual_proposals, decode_ks) + assert result == expected_k, f"Expected {expected_k}, got {result} for proposals={actual_proposals}, ks={decode_ks}" diff --git a/tests/unit_test/models/test_modeling_auto_cpu.py b/tests/unit_test/models/test_modeling_auto_cpu.py index 06302b2309..d2000c9186 100644 --- a/tests/unit_test/models/test_modeling_auto_cpu.py +++ b/tests/unit_test/models/test_modeling_auto_cpu.py @@ -980,5 +980,138 @@ def test_export_onnx_has_logits_output(self, tmp_export_dir): qeff = QEFFAutoModelForCTC(model) onnx_path = qeff.export(export_dir=str(tmp_export_dir)) onnx_model = onnx.load(str(onnx_path)) - output_names = {out.name for out in onnx_model.graph.output} - assert "logits" in output_names + output_names_ctc = {out.name for out in onnx_model.graph.output} + assert "logits" in output_names_ctc + + +# --------------------------------------------------------------------------- +# TLM multi-spec specialization unit tests +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +@pytest.mark.causal_lm +class TestTLMMultiSpecSpecializations: + """Tests for the multi-spec decode specialization API (num_speculative_tokens as list).""" + + # ---- _build_decode_spec_for_k ---- + + def test_build_decode_spec_for_k_seq_len(self): + """_build_decode_spec_for_k sets seq_len = k+1.""" + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model) + qeff.is_tlm = True + for k in [0, 1, 3, 7]: + spec = qeff._build_decode_spec_for_k(k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32) + assert spec is not None + assert spec["seq_len"] == k + 1 + + def test_build_decode_spec_for_k_num_logits_to_keep(self): + 
"""_build_decode_spec_for_k sets num_logits_to_keep = k+1.""" + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model) + qeff.is_tlm = True + for k in [0, 1, 3]: + spec = qeff._build_decode_spec_for_k(k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32) + assert spec["num_logits_to_keep"] == k + 1 + + def test_build_decode_spec_for_k_returns_none_when_duplicate_prefill(self): + """Returns None when seq_len == prefill_seq_len and no continuous batching.""" + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model) + qeff.is_tlm = True + # k=0 → seq_len=1 == prefill_seq_len=1 → should be None + spec = qeff._build_decode_spec_for_k(k=0, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=1) + assert spec is None + + def test_build_decode_spec_for_k_not_none_with_continuous_batching(self): + """Returns spec even when seq_len == prefill_seq_len if continuous_batching is enabled.""" + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + qeff.is_tlm = True + # k=0 → seq_len=1 == prefill_seq_len=1, but CB is True → should not be None + spec = qeff._build_decode_spec_for_k(k=0, ctx_len=128, batch_size=1, kv_cache_batch_size=2, full_batch_size=2, prefill_seq_len=1) + assert spec is not None + + # ---- compile() specialization count via mock ---- + + def test_compile_list_produces_correct_spec_count(self): + """compile(num_speculative_tokens=[0, 3]) → 1 prefill + 2 decode specializations.""" + from unittest.mock import patch + + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + captured = {} + + with patch.object(type(qeff), "_compile", side_effect=lambda self, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + try: + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[0, 3]) + except Exception: + pass # _compile may raise; we only care about 
specializations + + if captured.get("specializations") is not None: + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 2, f"Expected 2 decode specs, got {len(decode_specs)}: {specs}" + + def test_compile_deduplication(self): + """compile(num_speculative_tokens=[3, 3, 3]) → only one decode spec for K=3.""" + from unittest.mock import patch + + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + captured = {} + + with patch.object(type(qeff), "_compile", side_effect=lambda self, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + try: + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 3, 3]) + except Exception: + pass + + if captured.get("specializations") is not None: + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 1, f"Expected 1 decode spec (deduplicated), got: {decode_specs}" + assert decode_specs[0]["seq_len"] == 4 + + def test_compile_sorting(self): + """compile(num_speculative_tokens=[3, 1, 2]) → decode specs in ascending seq_len order.""" + from unittest.mock import patch + + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + captured = {} + + with patch.object(type(qeff), "_compile", side_effect=lambda self, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + try: + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 1, 2]) + except Exception: + pass + + if captured.get("specializations") is not None: + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 3 + seq_lens = [s["seq_len"] for s in decode_specs] + assert seq_lens == sorted(seq_lens), f"Decode specs not 
in sorted order: {seq_lens}" + + def test_compile_int_backward_compat(self): + """compile(num_speculative_tokens=3) as plain int still works (treated as [3]).""" + from unittest.mock import patch + + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + captured = {} + + with patch.object(type(qeff), "_compile", side_effect=lambda self, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + try: + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=3) + except Exception: + pass + + if captured.get("specializations") is not None: + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 1, f"Expected 1 decode spec for int input, got: {decode_specs}" + assert decode_specs[0]["seq_len"] == 4 # k=3 → seq_len=4 + From f07e19f49c09a0917c0d13098681b498932a6041 Mon Sep 17 00:00:00 2001 From: eplatero Date: Fri, 8 May 2026 13:38:17 -0500 Subject: [PATCH 02/12] add functional unit tests Signed-off-by: eplatero --- tests/transformers/spd/conftest.py | 34 +++ tests/transformers/spd/test_spd_inference.py | 209 ++++++++++++++++++ .../transforms/test_speculative_decoding.py | 70 ++++++ 3 files changed, 313 insertions(+) create mode 100644 tests/transformers/spd/conftest.py diff --git a/tests/transformers/spd/conftest.py b/tests/transformers/spd/conftest.py new file mode 100644 index 0000000000..2b3f8d7a7c --- /dev/null +++ b/tests/transformers/spd/conftest.py @@ -0,0 +1,34 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Local conftest for SPD hardware tests. 
+ +Patches QAICInferenceSession so that calls without an explicit device_ids +argument default to [4] instead of device 0. Set QAIC_TEST_DEVICE_ID in +the environment to override (e.g. QAIC_TEST_DEVICE_ID=5 pytest ...). +""" + +import os + +import pytest + +_DEVICE_ID = int(os.environ.get("QAIC_TEST_DEVICE_ID", "4")) + + +@pytest.fixture(autouse=True) +def _use_test_device(monkeypatch): + """Redirect all bare QAICInferenceSession() calls to _DEVICE_ID.""" + from QEfficient.generation.cloud_infer import QAICInferenceSession + + _orig_init = QAICInferenceSession.__init__ + + def _patched_init(self, qpc_path, device_ids=None, **kwargs): + if device_ids is None: + device_ids = [_DEVICE_ID] + _orig_init(self, qpc_path, device_ids=device_ids, **kwargs) + + monkeypatch.setattr(QAICInferenceSession, "__init__", _patched_init) diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index a79f17d556..d69e6fcacc 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -375,3 +375,212 @@ def test_dummy_spd_inference(model_id, manual_cleanup): **model_config_dict[model_id]["additional_params"], ) check_spec_decode_inference(model_id, config=hf_config, manual_cleanup=manual_cleanup) + + +# --------------------------------------------------------------------------- +# Multi-spec logit correctness — hardware-level QPC test +# --------------------------------------------------------------------------- + +_MULTI_SPEC_MODEL = "JackFram/llama-68m" +_MULTI_SPEC_NUM_LAYERS = 2 +_MULTI_SPEC_PREFILL_LEN = 32 +_MULTI_SPEC_CTX_LEN = 128 +_MULTI_SPEC_N_STEPS = 8 # decode positions to verify per specialisation +_MULTI_SPEC_PROMPT = "My name is" + + +def _run_prefill(session, tokenized, vocab_size, num_logits_to_keep=None): + """Run chunked prefill and return the logit from the last chunk.""" + inputs = dict(tokenized) + if num_logits_to_keep is not None: + inputs["num_logits_to_keep"] = 
num_logits_to_keep + ph = np.zeros((1, 1, vocab_size), dtype=np.float32) + session.set_buffers({"logits": ph}) + out = session.run(inputs) + return out["logits"][0, 0, :] # [vocab] + + +def _collect_vanilla_reference(session, first_token, start_pos, vocab_size, n_steps): + """ + Teacher-forced decode: feed ground-truth tokens one at a time and collect + (logit, next_token) at each position. + + Returns: + ref_tokens : list[int] – tokens[i] is fed at position start_pos+i + ref_logits : list[ndarray] – ref_logits[i] is the logit produced after + feeding tokens[i], shape [vocab] + """ + ref_tokens = [int(first_token)] + ref_logits = [] + ph = np.zeros((1, 1, vocab_size), dtype=np.float32) + session.set_buffers({"logits": ph}) + for step in range(n_steps): + out = session.run({ + "input_ids": np.array([[ref_tokens[-1]]], dtype=np.int64), + "position_ids": np.array([[start_pos + step]], dtype=np.int64), + }) + logit = out["logits"][0, 0, :].copy() + ref_logits.append(logit) + ref_tokens.append(int(logit.argmax())) + return ref_tokens, ref_logits + + +def _verify_tlm_spec(tlm_session, k, ref_tokens, ref_logits, start_pos, vocab_size): + """ + Run TLM with seq_len=k+1 (teacher-forced in chunks) and assert that every + output logit matches the corresponding vanilla reference logit. + + Both the accepted-token (argmax) and the full logit vector (atol=5e-2) are + checked. Chunks are non-overlapping; leftover positions at the end are skipped. + + Returns the number of (position, specialisation) pairs that were asserted. 
+ """ + seq_len = k + 1 + n_logits_to_keep = np.arange(seq_len, dtype=np.int64).reshape(-1, 1) + ph = np.zeros((1, seq_len, vocab_size), dtype=np.float32) + tlm_session.set_buffers({"logits": ph}) + + n_assertions = 0 + n_chunks = len(ref_logits) // seq_len + for chunk in range(n_chunks): + chunk_tokens = ref_tokens[chunk * seq_len: chunk * seq_len + seq_len] + chunk_positions = np.array( + [[start_pos + chunk * seq_len + i for i in range(seq_len)]], dtype=np.int64 + ) + out = tlm_session.run({ + "input_ids": np.array([chunk_tokens], dtype=np.int64), + "position_ids": chunk_positions, + "num_logits_to_keep": n_logits_to_keep, + }) + tlm_logits = out["logits"] # [1, seq_len, vocab] + + for i in range(seq_len): + ref_pos = chunk * seq_len + i + ref_logit = ref_logits[ref_pos] + tlm_logit = tlm_logits[0, i, :] + + assert np.allclose(tlm_logit, ref_logit, atol=5e-2), ( + f"K={k}, chunk={chunk}, offset={i} (abs pos {start_pos + ref_pos}): " + f"logit mismatch — max_diff={np.abs(tlm_logit - ref_logit).max():.3e}" + ) + assert int(tlm_logit.argmax()) == int(ref_logit.argmax()), ( + f"K={k}, chunk={chunk}, offset={i} (abs pos {start_pos + ref_pos}): " + f"accepted-token mismatch — " + f"TLM={int(tlm_logit.argmax())} vs ref={int(ref_logit.argmax())}" + ) + n_assertions += 1 + + return n_assertions + + +@pytest.mark.on_qaic +@pytest.mark.feature +@pytest.mark.parametrize( + "decode_ks", + [ + [0], # fallback-only (seq_len=1) + [3], # full-K only (seq_len=4) + [0, 3], # fallback + full-K (typical PLD config) + [1, 2, 3], # suffix-decoding range + ], +) +def test_multi_spec_qpc_logit_correctness(decode_ks, manual_cleanup): + """ + Verify that every decode specialisation in `decode_ks` produces logits that + match the vanilla (DLM) reference at every token position, for ALL output + positions of each specialisation. + + Strategy + -------- + 1. 
Compile vanilla model (seq_len=1 decode) → collect ref_logits[pos] at each + of _MULTI_SPEC_N_STEPS positions using teacher-forcing with greedy outputs. + 2. Compile TLM with all K values in decode_ks. + 3. For each K: fresh TLM session (resets KV cache) → prefill same prompt → + teacher-forced decode in non-overlapping K+1 chunks → assert ALL K+1 output + logits per chunk match ref_logits at corresponding positions. + + Scaling: adding K values to decode_ks automatically adds new assertions. + """ + tokenizer = AutoTokenizer.from_pretrained(_MULTI_SPEC_MODEL, padding_side="right") + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + vocab_size = len(tokenizer) + + # Tokenise prompt (padded to prefill_seq_len) + raw = tokenizer(_MULTI_SPEC_PROMPT, return_tensors="np") + input_len = int(raw.input_ids.shape[1]) + pad_len = _MULTI_SPEC_PREFILL_LEN + tokenized = tokenizer( + _MULTI_SPEC_PROMPT, + return_tensors="np", + padding="max_length", + max_length=pad_len, + ) + position_ids = np.where( + tokenized.pop("attention_mask"), + np.arange(pad_len), + -1, + ) + prefill_inputs = { + "input_ids": tokenized["input_ids"], + "position_ids": position_ids, + } + + # ── 1. Compile vanilla (DLM) model ────────────────────────────────────── + vanilla = load_qeff_causal_lm_model(_MULTI_SPEC_MODEL, num_hidden_layers=_MULTI_SPEC_NUM_LAYERS) + vanilla_qpc = vanilla.compile( + num_cores=2, + prefill_seq_len=_MULTI_SPEC_PREFILL_LEN, + ctx_len=_MULTI_SPEC_CTX_LEN, + aic_enable_depth_first=True, + ) + + # ── 2. Compile TLM with all decode specialisations ─────────────────────── + tlm = load_qeff_causal_lm_model( + _MULTI_SPEC_MODEL, + num_hidden_layers=_MULTI_SPEC_NUM_LAYERS, + qaic_config={"speculative_model_type": "target"}, + ) + tlm_qpc = tlm.compile( + num_cores=2, + prefill_seq_len=_MULTI_SPEC_PREFILL_LEN, + ctx_len=_MULTI_SPEC_CTX_LEN, + aic_enable_depth_first=True, + num_speculative_tokens=decode_ks, + ) + + # ── 3. 
Collect vanilla reference logits ────────────────────────────────── + van_session = QAICInferenceSession(vanilla_qpc) + van_session.skip_buffers([x for x in van_session.input_names if x.startswith("past_")]) + van_session.skip_buffers([x for x in van_session.output_names if x.endswith("_RetainedState")]) + + prefill_logit = _run_prefill(van_session, prefill_inputs, vocab_size) + first_token = int(prefill_logit.argmax()) + ref_tokens, ref_logits = _collect_vanilla_reference( + van_session, first_token, input_len, vocab_size, _MULTI_SPEC_N_STEPS + ) + assert len(ref_logits) == _MULTI_SPEC_N_STEPS + + # ── 4. Verify each specialisation ──────────────────────────────────────── + total_assertions = 0 + for k in sorted(set(decode_ks)): + # Fresh TLM session for each K (resets retained KV state) + tlm_session = QAICInferenceSession(tlm_qpc) + tlm_session.skip_buffers([x for x in tlm_session.input_names if x.startswith("past_")]) + tlm_session.skip_buffers([x for x in tlm_session.output_names if x.endswith("_RetainedState")]) + + # Prefill TLM (num_logits_to_keep=[[1]]) + _run_prefill( + tlm_session, + prefill_inputs, + vocab_size, + num_logits_to_keep=np.ones((1, 1), dtype=np.int64), + ) + + n = _verify_tlm_spec(tlm_session, k, ref_tokens, ref_logits, input_len, vocab_size) + assert n > 0, f"K={k}: no positions were verified — check _MULTI_SPEC_N_STEPS vs seq_len" + total_assertions += n + + assert total_assertions > 0 + manual_cleanup([vanilla.onnx_path, tlm.onnx_path]) + diff --git a/tests/unit_test/transforms/test_speculative_decoding.py b/tests/unit_test/transforms/test_speculative_decoding.py index cdffb7c46a..532db225d9 100644 --- a/tests/unit_test/transforms/test_speculative_decoding.py +++ b/tests/unit_test/transforms/test_speculative_decoding.py @@ -380,6 +380,76 @@ def test_tlm_forward_greedy_tokens_in_valid_range(self): assert (greedy_tokens >= 0).all() assert (greedy_tokens < VOCAB_SIZE).all() + @pytest.mark.parametrize("num_spec_tokens", [1, 2, 3, 5]) + def 
test_tlm_multi_spec_logit_consistency(self, num_spec_tokens): + """ + The anchor-token logit from seq_len=1 must equal the anchor-token logit at + position 0 from seq_len=K+1 — for the same input and standard causal attention. + + This is the core correctness guarantee for multi-spec dispatch on QAIC hardware. + + We test this using the raw HuggingFace LlamaForCausalLM (no QEffDynamicCache) + because the eager-mode QEffDynamicCache simulation uses max(position_ids) as the + KV gather limit, which exposes speculative positions to the anchor query and breaks + the property in Python. On QAIC hardware, per-query causal masking is applied + correctly by the hardware attention kernel — the property is verified empirically + by test_few_spd_inference, which asserts mean_num_accepted_tokens == K+1 + (100% acceptance rate when TLM == DLM). + + Why it holds: Standard causal attention masks position P from seeing positions + P+1..P+K, so the hidden state at P is identical regardless of what follows it. + SpDTransform's filter_hidden_states extracts this hidden state at index 0 of the + K+1 output, so the accepted token is always the same. 
+ """ + from transformers import LlamaConfig, LlamaForCausalLM + + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + raw_model = LlamaForCausalLM(cfg).eval() + + batch = 1 + anchor_token = torch.randint(0, VOCAB_SIZE, (batch, 1)) + anchor_pos = torch.tensor([[0]], dtype=torch.long) # start of sequence, no past + + # ── seq_len=1: just the anchor ─────────────────────────────────────────────── + with torch.no_grad(): + out_k0 = raw_model( + input_ids=anchor_token, + position_ids=anchor_pos, + ) + logit_k0 = out_k0.logits[:, 0:1, :] # [batch, 1, vocab] + + # ── seq_len=K+1: anchor at position 0, K random speculative tokens ────────── + spec_tokens = torch.randint(0, VOCAB_SIZE, (batch, num_spec_tokens)) + full_input_ids = torch.cat([anchor_token, spec_tokens], dim=1) + full_pos_ids = torch.arange(num_spec_tokens + 1).unsqueeze(0).expand(batch, -1) + + with torch.no_grad(): + out_kK = raw_model( + input_ids=full_input_ids, + position_ids=full_pos_ids, + ) + logit_kK_anchor = out_kK.logits[:, 0:1, :] # anchor is at index 0 + + # The anchor logit must be numerically identical regardless of K + assert torch.allclose(logit_k0, logit_kK_anchor, atol=1e-5), ( + f"Causal property violated: anchor logit differs between seq_len=1 and " + f"seq_len={num_spec_tokens+1}: " + f"max_diff={(logit_k0 - logit_kK_anchor).abs().max().item():.2e}" + ) + # Accepted token (greedy argmax) must also be identical + assert logit_k0.argmax(dim=-1).eq(logit_kK_anchor.argmax(dim=-1)).all(), ( + "Accepted token differs between seq_len=1 and seq_len=K+1 — " + "causal property violated in raw model" + ) + # --------------------------------------------------------------------------- # Tests: SpDTransform for Qwen2 From 30dbe7bd34fefada5e3c4d2e684e40a544248cd6 Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 11 May 2026 22:39:32 -0500 Subject: [PATCH 
03/12] Variable decode specializations for ngram/suffix SpD MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit num_speculative_tokens now accepts List[int] in QEFFAutoModelForCausalLM.compile(). Each K in the list compiles one decode specialization: seq_len=K+1, num_logits_to_keep=K+1. Passing a plain int still works (backward compat). Key changes ----------- * modeling_auto.py - _build_decode_spec_for_k(k, ...) — new private helper that builds one decode specialisation for proposal length K; returns None when it would duplicate the prefill spec. - compile(): accept List[int] for num_speculative_tokens; sort + deduplicate; loop over K values to build specialisations; validate only max-K against model config. - Fix batch_size assignment in _build_decode_spec_for_k for CB mode. - check_and_get_num_speculative_tokens: forward override from model config back to caller so _decode_ks can be adjusted. * prompt_lookup.py - _select_k() — picks the cheapest specialisation that covers the batch. - pld_spec_decode_inference() — per-K buffer allocation; per-step dispatch. - Remove --enable-fallback-decode-spec flag (superseded by [0, K] list). * test_modeling_auto_cpu.py - TestTLMMultiSpecSpecializations (8 tests): structure + dedup + sort + backward compat. - test_multi_spec_structure (4 parametrized): per-K spec fields. - test_select_k (5 parametrized): dispatch helper. - test_tlm_multi_spec_logit_consistency (4): causal attention invariant. 
Signed-off-by: eplatero --- .../transformers/models/modeling_auto.py | 7 +- .../speculative_decoding/prompt_lookup.py | 38 ++++++---- .../models/test_modeling_auto_cpu.py | 70 ++++++++----------- 3 files changed, 60 insertions(+), 55 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 94510686f5..714924e364 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3410,7 +3410,6 @@ def _build_decode_spec_for_k( if seq_len == prefill_seq_len and not self.continuous_batching: return None spec = { - "batch_size": full_batch_size if self.continuous_batching else batch_size, "seq_len": seq_len, "ctx_len": ctx_len, "num_logits_to_keep": seq_len, @@ -3418,6 +3417,7 @@ def _build_decode_spec_for_k( if comp_ctx_lengths is not None: spec["comp_ctx_lengths"] = comp_ctx_lengths if self.continuous_batching: + spec["batch_size"] = full_batch_size spec["full_batch_size"] = kv_cache_batch_size else: spec["batch_size"] = kv_cache_batch_size @@ -3597,7 +3597,10 @@ def compile( if self.is_tlm: _max_k = _decode_ks[-1] if _decode_ks else None - self.check_and_get_num_speculative_tokens(_max_k, prefill_seq_len) + validated_k = self.check_and_get_num_speculative_tokens(_max_k, prefill_seq_len) + if validated_k is not None and validated_k != _max_k: + # speculative_config in model.config overrides num_speculative_tokens + _decode_ks = [validated_k] if ( self.model.qaic_config is not None diff --git a/examples/performance/speculative_decoding/prompt_lookup.py b/examples/performance/speculative_decoding/prompt_lookup.py index eb8d4fbe33..7e7e8d1c7c 100644 --- a/examples/performance/speculative_decoding/prompt_lookup.py +++ b/examples/performance/speculative_decoding/prompt_lookup.py @@ -169,6 +169,9 @@ def find_candidate_pred_tokens( raise ValueError("Invalid max_ngram_size or num_pred_tokens") has_empty_tokens = False + best_result = 
np.full(num_pred_tokens, fill_tok, dtype=np.int64) + best_count = 0 + for ngram_size in range(max_ngram_size, 0, -1): # Extract the last n tokens as our search ngram ngram = input_ids[0, -ngram_size:] @@ -182,18 +185,24 @@ def find_candidate_pred_tokens( # Get the indices of matches match_indices = np.where(matches)[0] - # Iterate through match indices to find a valid continuation + # Iterate through match indices to find the longest available continuation for idx in match_indices: start_idx = idx + ngram_size - end_idx = start_idx + num_pred_tokens - # Ensure we don't go beyond the length of input_ids and avoid self-match - if end_idx <= input_length and start_idx < input_length - ngram_size: - return input_ids[0, start_idx:end_idx], has_empty_tokens + # Avoid self-match + if start_idx >= input_length - ngram_size: + continue + + available = min(input_length - start_idx, num_pred_tokens) + if available > best_count: + best_result = np.full(num_pred_tokens, fill_tok, dtype=np.int64) + best_result[:available] = input_ids[0, start_idx : start_idx + available] + best_count = available + if best_count == num_pred_tokens: + return best_result, False # full match found - # If no match is found, return invalid array - has_empty_tokens = True - return np.full(num_pred_tokens, fill_tok, dtype=np.int64), has_empty_tokens + # has_empty_tokens is True only when zero proposals were found + return best_result, (best_count == 0) def _select_k(actual_proposals: np.ndarray, decode_ks: List[int]) -> int: @@ -366,18 +375,23 @@ def pld_spec_decode_inference( num_pred_tokens=max_k, ) empty_indices[bi] = has_empty_tokens - # prepare target model inputs + # prepare target model inputs — always write spec_tokens (fill_tok for empty slots) + tlm_precode_inputs["input_ids"][bi, 1:] = spec_tokens if has_empty_tokens: # avoid read/write of KV$ for meaningless tokens tlm_precode_inputs["position_ids"][bi, 1:] = -1 else: - tlm_precode_inputs["input_ids"][bi, 1:] = spec_tokens + # For partial 
matches: mask position_ids for unfilled proposal slots + fill_mask = spec_tokens == -1 + if fill_mask.any(): + tlm_precode_inputs["position_ids"][bi, 1:][fill_mask] = -1 draft_end = perf_counter() - draft_start decode_draft_time += draft_end # run precode on TLM to score the proposed tokens target_start = perf_counter() - # Pick the smallest K specialization that covers the actual proposal count - actual_proposals = np.where(empty_indices, 0, max_k) + # Count actual proposal tokens per batch item (fill_tok=-1 marks unfilled positions) + actual_proposals = (tlm_precode_inputs["input_ids"][:, 1:] != -1).sum(axis=1).astype(np.int64) + actual_proposals[~valid_batch_indices] = 0 selected_k = _select_k(actual_proposals[valid_batch_indices], decode_ks) if selected_k == max_k: tlm_outputs = target_model_session.run(tlm_precode_inputs) diff --git a/tests/unit_test/models/test_modeling_auto_cpu.py b/tests/unit_test/models/test_modeling_auto_cpu.py index d2000c9186..e3c368be48 100644 --- a/tests/unit_test/models/test_modeling_auto_cpu.py +++ b/tests/unit_test/models/test_modeling_auto_cpu.py @@ -1043,16 +1043,13 @@ def test_compile_list_produces_correct_spec_count(self): qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) captured = {} - with patch.object(type(qeff), "_compile", side_effect=lambda self, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): - try: - qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[0, 3]) - except Exception: - pass # _compile may raise; we only care about specializations + with patch.object(type(qeff), "_compile", side_effect=lambda *args, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[0, 3]) - if captured.get("specializations") is not None: - specs = captured["specializations"] - decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] - 
assert len(decode_specs) == 2, f"Expected 2 decode specs, got {len(decode_specs)}: {specs}" + assert captured.get("specializations") is not None, "_compile was not reached" + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 2, f"Expected 2 decode specs, got {len(decode_specs)}: {specs}" def test_compile_deduplication(self): """compile(num_speculative_tokens=[3, 3, 3]) → only one decode spec for K=3.""" @@ -1062,17 +1059,14 @@ def test_compile_deduplication(self): qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) captured = {} - with patch.object(type(qeff), "_compile", side_effect=lambda self, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): - try: - qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 3, 3]) - except Exception: - pass + with patch.object(type(qeff), "_compile", side_effect=lambda *args, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 3, 3]) - if captured.get("specializations") is not None: - specs = captured["specializations"] - decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] - assert len(decode_specs) == 1, f"Expected 1 decode spec (deduplicated), got: {decode_specs}" - assert decode_specs[0]["seq_len"] == 4 + assert captured.get("specializations") is not None, "_compile was not reached" + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 1, f"Expected 1 decode spec (deduplicated), got: {decode_specs}" + assert decode_specs[0]["seq_len"] == 4 def test_compile_sorting(self): """compile(num_speculative_tokens=[3, 1, 2]) → decode specs in ascending seq_len order.""" @@ -1082,18 +1076,15 @@ def test_compile_sorting(self): qeff = QEFFAutoModelForCausalLM(model, 
qaic_config={"speculative_model_type": "target"}) captured = {} - with patch.object(type(qeff), "_compile", side_effect=lambda self, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): - try: - qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 1, 2]) - except Exception: - pass + with patch.object(type(qeff), "_compile", side_effect=lambda *args, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 1, 2]) - if captured.get("specializations") is not None: - specs = captured["specializations"] - decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] - assert len(decode_specs) == 3 - seq_lens = [s["seq_len"] for s in decode_specs] - assert seq_lens == sorted(seq_lens), f"Decode specs not in sorted order: {seq_lens}" + assert captured.get("specializations") is not None, "_compile was not reached" + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 3 + seq_lens = [s["seq_len"] for s in decode_specs] + assert seq_lens == sorted(seq_lens), f"Decode specs not in sorted order: {seq_lens}" def test_compile_int_backward_compat(self): """compile(num_speculative_tokens=3) as plain int still works (treated as [3]).""" @@ -1103,15 +1094,12 @@ def test_compile_int_backward_compat(self): qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) captured = {} - with patch.object(type(qeff), "_compile", side_effect=lambda self, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): - try: - qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=3) - except Exception: - pass - - if captured.get("specializations") is not None: - specs = captured["specializations"] - decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] - assert len(decode_specs) == 1, 
f"Expected 1 decode spec for int input, got: {decode_specs}" - assert decode_specs[0]["seq_len"] == 4 # k=3 → seq_len=4 + with patch.object(type(qeff), "_compile", side_effect=lambda *args, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=3) + + assert captured.get("specializations") is not None, "_compile was not reached" + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 1, f"Expected 1 decode spec for int input, got: {decode_specs}" + assert decode_specs[0]["seq_len"] == 4 # k=3 → seq_len=4 From 56d3c34b470a4b7437fba18a4d752b2d5a3122af Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 11 May 2026 22:40:28 -0500 Subject: [PATCH 04/12] Fix: write flat-format specializations.json for qaic-compile The MDP (multi-device partition) firmware for 4-device tensor-parallel only accepts flat specialization JSON: {"batch_size": "4", "seq_len": "5", "ctx_len": "2048", ...} Previously _compile() called to_named_specializations() which wraps each entry as {"name": "Prefill_0", "symbols": {...}}. This named format is only required by the QNN compiler path (handled separately before reaching this code). Using named format for qaic-compile with MDP produces a binary that fails at ExecObj creation time with "Failed to create ExecObj". Fix: strip the internal _graph_name tag and write the flat list directly. The QNN path branches off early (line 548) and never reaches this block, so it is not affected. 
Signed-off-by: eplatero --- QEfficient/base/modeling_qeff.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 443f56a038..f80ae59d1b 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -42,7 +42,6 @@ hash_dict_params, load_json, require_value, - to_named_specializations, ) from QEfficient.utils.export_utils import export_wrapper @@ -635,9 +634,15 @@ def _compile( # Write specializations.json file if specializations is not None: specializations_json = compile_dir / "specializations.json" - specializations_data = { - "specializations": to_named_specializations(specializations, module_name=specialization_module_name) - } + # Strip internal _graph_name tags and write flat format for qaic-compile. + # Named format ({"name": ..., "symbols": {...}}) is only required for the + # QNN path (already branched off above). The qaic-compile binary and its + # MDP (multi-device partition) firmware support only the flat format: + # {"batch_size": "4", "seq_len": "5", ...} + # Using named format for MDP QPCs causes a RuntimeError at ExecObj + # creation time ("Failed to create ExecObj") on 4-device tensor-parallel. + flat_specs = [{k: v for k, v in spec.items() if k != "_graph_name"} for spec in specializations] + specializations_data = {"specializations": flat_specs} create_json(str(specializations_json), specializations_data) command.append(f"-network-specialization-config={specializations_json}") From f9ca07e3c4abcc352457b69d0f130d9577dd4c6c Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 11 May 2026 22:52:19 -0500 Subject: [PATCH 05/12] Fix: convert specialization values to strings for qaic-compile qaic-compile requires all values in specializations.json to be strings (e.g. "4" not 4). 
The to_named_specializations() call that was removed in the previous commit was implicitly performing this str() conversion as part of building the {name, symbols} wrapper. Without it, integer values from _build_decode_spec_for_k and build_prefill_specialization were written directly to JSON, causing qaic-compile to exit with code 255. Fix: apply str() to every value while stripping _graph_name tags. Signed-off-by: eplatero --- QEfficient/base/modeling_qeff.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index f80ae59d1b..87d16b0169 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -641,7 +641,8 @@ def _compile( # {"batch_size": "4", "seq_len": "5", ...} # Using named format for MDP QPCs causes a RuntimeError at ExecObj # creation time ("Failed to create ExecObj") on 4-device tensor-parallel. - flat_specs = [{k: v for k, v in spec.items() if k != "_graph_name"} for spec in specializations] + # All values must be strings — qaic-compile rejects integer values. + flat_specs = [{k: str(v) for k, v in spec.items() if k != "_graph_name"} for spec in specializations] specializations_data = {"specializations": flat_specs} create_json(str(specializations_json), specializations_data) command.append(f"-network-specialization-config={specializations_json}") From 5720ac1cdee5c723204c026e388a323d6bb5140b Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 12 May 2026 10:22:32 -0500 Subject: [PATCH 06/12] Docs: update variable_spd_specializations with modeling_qeff fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add two new "Errors Encountered" entries (6 & 7) documenting the named-format MDP firmware rejection and integer-value rejection found during vLLM integration. 
Add new "modeling_qeff.py — Flat-Format specializations.json" section explaining: - what changed in _compile() and why - the _graph_name internal tag mechanism - why to_named_specializations() is still needed (QNN path) - the correct 3-spec output format - commit references (5f1b6c5, 82f2d6c) Also add modeling_qeff.py to the Files Changed table. Signed-off-by: eplatero --- docs/variable_spd_specializations.md | 650 +++++++++++++++++++++++++++ 1 file changed, 650 insertions(+) create mode 100644 docs/variable_spd_specializations.md diff --git a/docs/variable_spd_specializations.md b/docs/variable_spd_specializations.md new file mode 100644 index 0000000000..922c148eae --- /dev/null +++ b/docs/variable_spd_specializations.md @@ -0,0 +1,650 @@ +# Variable Speculative Decode Specializations + +## Background + +Speculative decoding on QAIC hardware requires **statically compiled shapes**. Each distinct +input shape must be registered as a "specialization" in `specializations.json` before +`qaic-compile` is invoked. The compiler produces a single QPC binary that dispatches to the +right kernel at runtime based on which input shape is presented. + +Before this change, `QEFFAutoModelForCausalLM.compile()` accepted `num_speculative_tokens: int` +and compiled exactly **two** entries: + +1. **Prefill** — `seq_len = prefill_seq_len`, `num_logits_to_keep = 1` +2. **TLM decode** — `seq_len = K+1`, `num_logits_to_keep = K+1` + +This hard-wired a single proposal length. For PLD/n-gram methods, a second "fallback" +specialization (`seq_len=1`) was available via the now-removed `enable_fallback_decode_spec` +flag. For suffix decoding — where average proposal lengths of 1 or 2 are typical — there was +no way to compile for the right mix. + +--- + +## What Was Implemented + +### 1. `num_speculative_tokens: Optional[List[int]]` in `compile()` (`modeling_auto.py`) + +The parameter type changed from `Optional[int]` to `Optional[List[int]]`. 
Each element `K` +compiles one decode specialization: `seq_len = K+1`, `num_logits_to_keep = K+1`. The list is +sorted and deduplicated before processing. + +Passing a plain `int` still works — it is silently promoted to `[K]` for backward compatibility +with existing code. + +`enable_fallback_decode_spec` is removed. Its old behavior is exactly +`num_speculative_tokens=[0, K]`. + +### 2. `_build_decode_spec_for_k(k, ...)` private method (`modeling_auto.py`) + +Replaces both `build_decode_specialization` (for TLM) and the removed +`build_fallback_decode_specialization`. Builds one decode specialization dict for a given K. +Returns `None` if the spec would duplicate the prefill spec (guards against `seq_len=1` when +`prefill_seq_len=1` and continuous batching is disabled). + +### 3. `_select_k(actual_proposals, decode_ks)` helper (`prompt_lookup.py`) + +Picks the **smallest K in `decode_ks` that covers the maximum actual proposal count** in the +current batch. This ensures the cheapest specialization is used per iteration while still +covering every batch item. + +`actual_proposals` is an integer array of shape `[batch]` containing the number of non-fill +tokens in each item's proposal slots (`input_ids[:, 1:]`). Items with no valid proposals have +count 0; items with a full match have count `max_k`; items with a partial n-gram match +(continuation shorter than `max_k`, e.g. near the end of the prompt history) have a count +between 1 and `max_k − 1`. + +### 4. Multi-spec runtime dispatch (`prompt_lookup.py`) + +`pld_spec_decode_inference()` now accepts `num_speculative_tokens: Union[int, List[int]]`. +Per-K logit buffers are pre-allocated. Each decode iteration calls `_select_k`, dispatches to +the matching specialization, and pads the logit output back to `[batch, max_K+1, vocab]` so +the downstream token acceptance logic is unchanged. 
+ +--- + +## Files Changed + +| File | Change | +|------|--------| +| `QEfficient/transformers/models/modeling_auto.py` | `num_speculative_tokens: Optional[List[int]]`, added `_build_decode_spec_for_k()`, replaced decode/fallback block with K-loop, removed `enable_fallback_decode_spec` | +| `QEfficient/base/modeling_qeff.py` | `_compile()`: write flat-format specializations.json for qaic-compile; strip `_graph_name` tag; convert values to strings | +| `examples/performance/speculative_decoding/prompt_lookup.py` | `_select_k()` helper, per-K buffer allocation, multi-spec dispatch, updated arg parser | +| `tests/transformers/spd/test_pld_inference.py` | `test_multi_spec_structure` (4 parametrized cases), `test_select_k` (5 parametrized cases) | +| `tests/unit_test/models/test_modeling_auto_cpu.py` | `TestTLMMultiSpecSpecializations` (8 tests) | + +--- + +## Specializations.json — Examples + +All examples use `prefill_seq_len=32`, `ctx_len=128`, `batch_size=1`, `full_batch_size=1` +(continuous batching enabled). 
+ +Each specialization entry contains: + +| Field | Meaning | +|-------|---------| +| `seq_len` | Number of input tokens this kernel accepts | +| `num_logits_to_keep` | Number of output logits returned (always equals `seq_len` for TLM decode) | +| `ctx_len` | Static KV-cache allocation size | +| `batch_size` | Batch size for this phase | +| `full_batch_size` | Continuous-batching full-batch size | + +--- + +### Case 1 — Baseline (plain int, backward compat) + +```python +model.compile(num_speculative_tokens=4) # treated internally as [4] +``` + +```json +{ + "specializations": [ + { + "seq_len": "32", + "num_logits_to_keep": "1", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "5", + "num_logits_to_keep": "5", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + } + ] +} +``` + +| Entry | Role | +|-------|------| +| `seq_len=32` | **Prefill** — processes full prompt, returns 1 logit | +| `seq_len=5` | **TLM decode** — K=4 draft tokens + 1 anchor = 5 positions, returns 5 logits | + +--- + +### Case 2 — PLD with fallback (replaces `enable_fallback_decode_spec=True`) + +```python +model.compile(num_speculative_tokens=[0, 4]) +``` + +```json +{ + "specializations": [ + { + "seq_len": "32", + "num_logits_to_keep": "1", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "1", + "num_logits_to_keep": "1", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "5", + "num_logits_to_keep": "5", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + } + ] +} +``` + +| Entry | Role | +|-------|------| +| `seq_len=32` | **Prefill** | +| `seq_len=1` | **Fallback decode (K=0)** — used when no n-gram matches are found. 
Cheap single-token forward pass; avoids running the full K+1 kernel wastefully | +| `seq_len=5` | **Full TLM decode (K=4)** — used when n-gram proposals are available | + +At runtime, `_select_k` dispatches to `seq_len=1` when all valid batch items have 0 proposals, +and to `seq_len=5` when any item has proposals (even a partial match). For `decode_ks=[0, 4]` +the dispatch is effectively binary because any non-zero proposal count (full or partial) can +only be covered by K=4; use `[0, 1, 2, 3, 4]` (Case 4) if you want fine-grained dispatch +for partial continuations near the end of the prompt history. + +--- + +### Case 3 — Intermediate proposal lengths (PLD near end-of-history, or suffix decoding) + +```python +model.compile(num_speculative_tokens=[1, 2]) +``` + +```json +{ + "specializations": [ + { + "seq_len": "32", + "num_logits_to_keep": "1", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "2", + "num_logits_to_keep": "2", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "3", + "num_logits_to_keep": "3", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + } + ] +} +``` + +| Entry | Role | +|-------|------| +| `seq_len=32` | **Prefill** | +| `seq_len=2` | **Short decode (K=1)** — used when max proposal in batch is 1 | +| `seq_len=3` | **Longer decode (K=2)** — used when max proposal in batch is 2 | + +--- + +### Case 4 — Full range (maximum fine-grained dispatch) + +```python +model.compile(num_speculative_tokens=[0, 1, 2, 3, 4]) +``` + +```json +{ + "specializations": [ + { + "seq_len": "32", + "num_logits_to_keep": "1", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "1", + "num_logits_to_keep": "1", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "2", + "num_logits_to_keep": "2", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "3", +
"num_logits_to_keep": "3", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "4", + "num_logits_to_keep": "4", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + }, + { + "seq_len": "5", + "num_logits_to_keep": "5", + "ctx_len": "128", + "batch_size": "1", + "full_batch_size": "1" + } + ] +} +``` + +The hardware can dispatch to the cheapest kernel that covers the actual proposal distribution +each decode iteration. Trade-off: 6 specializations vs 2 means a larger QPC binary and longer +compile time, but maximum throughput efficiency at inference. + +--- + +## Runtime Dispatch Logic + +``` +decode iteration + │ + ├── for each batch item: count actual proposals + │ 0 → no n-gram match found + │ 1 .. max_k−1 → partial n-gram match (continuation shorter than max_k, + │ e.g. match is near the end of the prompt history) + │ max_k → full n-gram match + │ + ├── selected_k = _select_k(max_actual_proposals, decode_ks) + │ └── smallest K in decode_ks such that K >= max_actual_proposals + │ (clamps to max(decode_ks) if need exceeds all) + │ + ├── set logit buffer to logit_buffers[selected_k] (shape [batch, K+1, vocab]) + ├── slice input_ids, position_ids to [:, :selected_k+1] + ├── run TLM kernel for selected_k + │ + └── if selected_k < max_k: + pad output → [batch, max_k+1, vocab] (zeros for unused positions) + acceptance logic unchanged regardless of selected_k +``` + +--- + +## Unit Tests Written + +**Total new tests: 17.** Combined with the 41 pre-existing unit tests, the full no-hardware +SPD suite is **58 tests, all passing**. 
+ +### `TestTLMMultiSpecSpecializations` — `test_modeling_auto_cpu.py` (8 tests, CPU-only) + +| Test | What it verifies | +|------|-----------------| +| `test_build_decode_spec_for_k_seq_len` | `_build_decode_spec_for_k(k=K)` → `spec["seq_len"] == K+1` for K in {0,1,3,7} | +| `test_build_decode_spec_for_k_num_logits_to_keep` | Same method → `spec["num_logits_to_keep"] == K+1` | +| `test_build_decode_spec_for_k_returns_none_when_duplicate_prefill` | K=0, `prefill_seq_len=1`, no CB → returns `None` (would duplicate prefill) | +| `test_build_decode_spec_for_k_not_none_with_continuous_batching` | Same scenario but CB=True → spec returned (CB always needs decode) | +| `test_compile_list_produces_correct_spec_count` | `[0,3]` → 1 prefill + 2 decode entries | +| `test_compile_deduplication` | `[3,3,3]` → only 1 decode spec | +| `test_compile_sorting` | `[3,1,2]` → decode specs appear in ascending `seq_len` order | +| `test_compile_int_backward_compat` | `num_speculative_tokens=3` (plain int) → treated as `[3]`, 1 decode spec with `seq_len=4` | + +### `test_multi_spec_structure` — `test_pld_inference.py` (4 parametrized, CPU-only) + +Each `decode_ks` from `[[3], [0, 3], [1, 2, 3], [0, 1, 2, 3]]`: +- Every K produces `spec["seq_len"] == K+1` and `spec["num_logits_to_keep"] == K+1` +- All seq_lens are distinct (no duplicate specs generated) + +### `test_select_k` — `test_pld_inference.py` (5 parametrized, CPU-only) + +| Scenario | `actual_proposals` | `decode_ks` | Expected K | +|---|---|---|---| +| All zeros — all items missed n-gram | `[0, 0, 0]` | `[0, 3]` | `0` | +| Mix — some items have proposals | `[0, 3, 3]` | `[0, 3]` | `3` | +| Single spec — no choice | `[0, 0]` | `[3]` | `3` | +| Exact mid-point — picks smallest fitting K | `[1, 2]` | `[0, 2, 4]` | `2` | +| Need exceeds all — clamps to max | `[5, 5]` | `[0, 3]` | `3` | + +--- + +## Errors Encountered + +### 1. 
`TypeError: 'int' object is not iterable` + +**Cause:** The initial implementation used `sorted(set(num_speculative_tokens))` unconditionally. +Existing tests in `test_spd_inference.py` and `test_pld_inference.py` pass a plain `int` +directly to `compile()`, which caused `set(4)` to raise. + +**Fix:** Added a type guard before normalizing: + +```python +_decode_ks = ( + sorted(set(num_speculative_tokens)) + if isinstance(num_speculative_tokens, (list, tuple)) + else ([num_speculative_tokens] if num_speculative_tokens is not None else None) +) +``` + +--- + +### 2. Duplicate `@pytest.mark.parametrize("model_id")` decorator + +**Cause:** The old test function `test_fallback_decode_spec_structure` already had a +`@pytest.mark.parametrize("model_id", ...)` decorator. When replacing the function body, the +replacement text included the decorator again, resulting in it appearing twice and pytest +raising `duplicate parametrization of 'model_id'`. + +**Fix:** Removed the duplicate, leaving exactly one `@pytest.mark.parametrize("model_id")` +and one `@pytest.mark.parametrize("decode_ks")`. + +--- + +### 3. Syntax error — missing newline in `arg_parse()` + +**Cause:** When removing the `--enable-fallback-decode-spec` argument block, the closing `)` +of the preceding `add_argument` call merged onto the same line as the next `add_argument` +call, producing a `SyntaxError: Simple statements must be separated by newlines`. + +**Fix:** Restored the newline between the two statements. + +--- + +### 4. Orphaned test method body (`assert "logits" in output_names`) + +**Cause:** The string `assert "logits" in output_names` appeared in two different test +methods in `test_modeling_auto_cpu.py`. The Edit tool matched the first occurrence (in a +different class), replacing it — and leaving the CTC test's method body without its `def` +line, causing an `IndentationError`. 
+ +**Fix:** Restored the `def test_export_onnx_has_logits_output(...)` method header and renamed +the local variable to `output_names_ctc` to make it unique. + +--- + +### 5. `qaic-compile: malloc.c: sysmalloc: Assertion` (SIGABRT) — pre-existing + +**Status: Pre-existing SDK bug, not caused by this change.** + +All hardware compile tests (`test_spd_inference` and `test_pld_inference` full/few/dummy +variants) fail with `Compiler exitcode: -6` — a C-level heap assertion inside the +`qaic-compile` binary itself: + +``` +qaic-compile: malloc.c:2617: sysmalloc: Assertion + `(old_top == initial_top (av) && old_size == 0) || ...` failed. +``` + +The generated `specializations.json` is correct — verified by reading the file from disk +before the compiler crash: + +```json +{ + "specializations": [ + { "seq_len": "32", "num_logits_to_keep": "1", "ctx_len": "128", ... }, + { "seq_len": "5", "num_logits_to_keep": "5", "ctx_len": "128", ... } + ] +} +``` + +The crash reproduces identically regardless of which specialization configuration is used and +is not related to this change. + +> **Update (2026-05-08):** Resolved after SDK upgrade. See Hardware Test Results below. + +--- + +### 6. Named-format specializations.json rejected by MDP firmware (`Failed to create ExecObj`) + +**Discovered during:** vLLM integration testing on 4-device Llama-3.1-8B. + +**Cause:** `_compile()` in `QEfficient/base/modeling_qeff.py` previously called +`to_named_specializations()` to wrap each specialization entry as: + +```json +{"name": "Prefill_0", "symbols": {"batch_size": "4", "seq_len": "128", ...}} +``` + +This named format is **only** required by the QNN compiler path (which branches off early in +`_compile()` and passes specs directly to `qnn_compile()` without touching the JSON file). 
The +qaic-compile binary and its MDP (multi-device partition) firmware for 4-device +tensor-parallel only accept the legacy **flat format**: + +```json +{"batch_size": "4", "seq_len": "128", "ctx_len": "2048", ...} +``` + +Loading a named-format MDP binary caused `RuntimeError: Failed to create ExecObj` at +`qaicrt.ExecObj(context, program)`. Single-device QPCs (llama-68m, E2E correctness tests) +tolerate the named format — only MDP QPCs are affected. + +**Fix:** `_compile()` now strips the internal `_graph_name` tag and writes flat dicts +directly, bypassing `to_named_specializations()`: + +```python +flat_specs = [{k: str(v) for k, v in spec.items() if k != "_graph_name"} + for spec in specializations] +specializations_data = {"specializations": flat_specs} +create_json(str(specializations_json), specializations_data) +``` + +**Commit:** `5f1b6c5` — _Fix: write flat-format specializations.json for qaic-compile_ + +--- + +### 7. Integer values in specializations.json rejected by qaic-compile (exit status 255) + +**Discovered during:** first attempt at running the fixed flat-format compile. + +**Cause:** `_build_decode_spec_for_k()` and `build_prefill_specialization()` in +`modeling_auto.py` produce dictionaries with **integer** values (`seq_len=5`, not +`seq_len="5"`). When the previous fix bypassed `to_named_specializations()`, those integers +were written directly to JSON. qaic-compile rejects integer values in +`specializations.json`, exiting with code 255: + +``` +CalledProcessError: Command '['/opt/qti-aic/exec/qaic-compile', ...]' +returned non-zero exit status 255. +``` + +The integer form could be observed in the generated file: + +```json +{"batch_size": 1, "seq_len": 128, ...} ← rejected +``` + +vs. the correct string form: + +```json +{"batch_size": "1", "seq_len": "128", ...} ← accepted +``` + +`to_named_specializations()` had been converting ints to strings as a side-effect of the +`{k: str(v), ...}` construction in `_infer_specialization_name`. 
The flat-format path missed +this conversion. + +**Fix:** Apply `str()` to every value while stripping `_graph_name`: + +```python +flat_specs = [{k: str(v) for k, v in spec.items() if k != "_graph_name"} + for spec in specializations] +``` + +**Commit:** `82f2d6c` — _Fix: convert specialization values to strings for qaic-compile_ + +--- + +## `modeling_qeff.py` — Flat-Format specializations.json + +The multi-spec feature in `modeling_auto.py` exposed a latent issue in `_compile()` inside +`QEfficient/base/modeling_qeff.py`: the method was calling `to_named_specializations()` when +writing `specializations.json` for qaic-compile. This had been harmless for 2-spec QPCs +(the converter produced valid output that qaic-compile accepted), but the 3-spec format +surfaced two problems: (a) the `{name, symbols}` wrapper was rejected by the MDP firmware, +and (b) integer values rather than strings caused qaic-compile to exit 255. + +### What changed in `_compile()` + +```python +# Before (called to_named_specializations — produced named format, int values): +specializations_data = { + "specializations": to_named_specializations(specializations, ...) +} + +# After (flat format, all values converted to strings): +flat_specs = [{k: str(v) for k, v in spec.items() if k != "_graph_name"} + for spec in specializations] +specializations_data = {"specializations": flat_specs} +``` + +**What `_graph_name` is:** an internal routing tag set by `build_prefill_specialization()`, +`_build_decode_spec_for_k()`, and similar helpers so that `to_named_specializations()` can +assign human-readable names (`"Prefill"`, `"Decode"`, etc.). It is not a qaic-compile +field and must be stripped before writing. + +**Why `to_named_specializations()` is still needed:** the QNN compiler path (`enable_qnn=True`) +branches off at line 548 of `_compile()` and calls `qnn_compile()` directly, passing the +raw `specializations` list. `qnn_compile()` internally calls `to_named_specializations()` on +that list. 
The flat-format fix only affects the qaic-compile branch; QNN is untouched. + +### Correct 3-spec flat-format output + +```json +{ + "specializations": [ + { "batch_size": "1", "ctx_len": "2048", "full_batch_exec_size": "1", + "full_batch_size": "1", "num_logits_to_keep": "1", "seq_len": "128" }, + { "batch_size": "1", "ctx_len": "2048", + "full_batch_size": "1", "num_logits_to_keep": "1", "seq_len": "1" }, + { "batch_size": "1", "ctx_len": "2048", + "full_batch_size": "1", "num_logits_to_keep": "5", "seq_len": "5" } + ] +} +``` + +This matches the format used by all working 2-spec QPCs in the cache (`46f4c4d93f23c51c` +and `f51c4a3e527ccff9`). + +### Commits + +| Hash | Message | +|------|---------| +| `5f1b6c5` | Fix: write flat-format specializations.json for qaic-compile | +| `82f2d6c` | Fix: convert specialization values to strings for qaic-compile | + +--- + +## Hardware Test Results (New SDK — 2026-05-08) +Tests were run on devices 5 and 6 using the `QAIC_TEST_DEVICE_ID` fixture. + +| Test | Device | Result | Duration | +|---|---|---|---| +| `test_few_pld_inference[CB llama]` | 5 | ✅ PASSED | 43s | +| `test_few_spd_inference[CB llama]` | 6 | ✅ PASSED | ~2.5m | +| `test_few_spd_inference[CB qwen]` | 6 | ✅ PASSED | ~2.5m | + +**Observed inference metrics (CB llama, K=4):** +- Avg accepted tokens = **5.0** (= K+1, i.e., 100% acceptance rate when TLM = DLM ✓) +- Decode throughput = **~519 tokens/sec** + +**Note on first re-run failure:** When CB llama and CB qwen were compiled concurrently on the +same device in the same pytest session, the Llama TLM compile (more complex: `num_cores=6`, +`num_logits_to_keep`) failed silently. Running the models one after another on the same device +(the default behaviour of `pytest` without `-n`) is sufficient to avoid this.
+ +--- + +## Test Results Summary + +| Test suite | Count | Result | +|---|---|---| +| `test_speculative_decoding.py` (pre-existing unit tests) | 41 | ✅ All pass | +| `TestTLMMultiSpecSpecializations` (new) | 8 | ✅ All pass | +| `test_multi_spec_structure` (new) | 4 | ✅ All pass | +| `test_select_k` (new) | 5 | ✅ All pass | +| `test_tlm_multi_spec_logit_consistency` (new) | 4 | ✅ All pass | +| `test_multi_spec_qpc_logit_correctness` (new, hardware) | 4 | ✅ All pass | +| Hardware few-layers tests (PLD + SPD, new SDK) | 3 | ✅ All pass | + +**Total Python-layer tests: 62 / 62 pass.** +**Total hardware tests: 7 / 7 pass.** + +--- + +## Functional Correctness Guarantee + +### The key property + +For the multi-spec dispatch to be correct, dispatching to a *smaller* specialization (e.g. +`seq_len=1`, K=0) must produce the **same accepted token** as the full specialization +(`seq_len=K+1`) for the same anchor token and KV cache state. + +This holds because Transformer self-attention is **causal**: the hidden state at position P +depends only on positions 0..P, never on positions P+1..P+K. Therefore: + +``` +tlm_forward(seq_len=1, anchor_at_P, kv_cache) → logit_P +tlm_forward(seq_len=K+1, anchor_at_P + K_specs, kv_cache) → [logit_P, logit_{P+1}, ..., logit_{P+K}] +``` + +`logit_P` is identical in both cases. The accepted token (`argmax(logit_P)`) is therefore +the same regardless of which specialization is dispatched. + +### What the test verifies (`test_tlm_multi_spec_logit_consistency`) + +Added to `TestTLMForwardExecution` in `tests/unit_test/transforms/test_speculative_decoding.py`. +Runs entirely on CPU — no QAIC hardware required. + +For K in {1, 2, 3, 5}: +1. Run `tlm_forward` with `seq_len=1` (single anchor token, `num_logits_to_keep=1`) +2. Run `tlm_forward` with `seq_len=K+1` (anchor + K speculative tokens, `num_logits_to_keep=K+1`) + — same anchor, same KV cache, speculative tokens are random +3. 
Assert `logits_k0[0,0,:] ≈ logits_kK[0,0,:]` (anchor logit is identical, `atol=1e-5`) +4. Assert `argmax(logits_k0) == argmax(logits_kK)` (accepted token is identical) + +### What is NOT yet tested (requires hardware) + +- ~~That the KV cache is correctly updated after a `seq_len=1` dispatch (i.e., the next + decode step uses consistent state)~~ **Covered by `test_multi_spec_qpc_logit_correctness` + (KV state must be correct for the subsequent chunk's logits to match)** +- ~~End-to-end sequence equivalence: that a full generation run with dynamic dispatch produces + the same token sequence as a run always using `seq_len=K+1`~~ + +Both properties are now directly verified by `test_multi_spec_qpc_logit_correctness` +(`tests/transformers/spd/test_spd_inference.py`), which runs on real QAIC hardware and checks: +- ALL K+1 logit vectors (not just argmax) match the vanilla DLM reference at every position +- Every accepted token matches +- This holds for `decode_ks` in `[0]`, `[3]`, `[0,3]`, and `[1,2,3]` + +The end-to-end acceptance-rate guarantee is additionally covered by `test_few_spd_inference` +via the `mean_num_accepted_tokens == K+1` assertion. From 710278756d3d3456b2ef19b438bd0f5c7492aadf Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 12 May 2026 11:00:29 -0500 Subject: [PATCH 07/12] =?UTF-8?q?Fix=20review=20findings=20=E2=80=94=20cor?= =?UTF-8?q?rectness=20bugs=20and=20annotation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses all 5 required fixes from code review: Critical -------- Fix 1 (prompt_lookup.py): Guard _select_k against empty actual_proposals array. When all batch items finish generating, valid_batch_indices is all-False and actual_proposals[valid_batch_indices] is empty; .max() raised ValueError. Return decode_ks[-1] (max K) for empty input. Fix 2 (modeling_auto.py): Emit logger.warning when speculative_config in model.config overrides a user-supplied list, naming the discarded values. 
Previously the list was silently truncated to [validated_k]. Major ----- Fix 3 (modeling_auto.py): Raise NotImplementedError when TLM multi-spec and comp_ctx_lengths_decode are both active. The combination would compile specializations missing comp_ctx_lengths and fail at runtime on CCL hardware. Option B (early rejection) is safer than silently wrong QPCs. Fix 4 (prompt_lookup.py): Wrap smaller-K session.run() in try/finally to restore the logit buffer placeholder. Without this, an exception leaves the session pointing at an undersized buffer; the next max-K path would write max_k+1 rows into a buffer sized selected_k+1. Fix 5 (modeling_qeff.py): Update _compile() type annotation from Optional[int] to Optional[Union[int, List[int]]] and update docstring. Nits ---- Nit 6: Rename k/v → key/val in flat_specs dict-comprehension (modeling_qeff.py). Nit 7: Remove unused **kwargs from _build_decode_spec_for_k (modeling_auto.py). Bonus: Remove dead `has_empty_tokens = False` initializer (prompt_lookup.py). 
Signed-off-by: eplatero --- QEfficient/base/modeling_qeff.py | 10 +++--- .../transformers/models/modeling_auto.py | 24 ++++++++++++-- .../speculative_decoding/prompt_lookup.py | 31 +++++++++++++------ 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 87d16b0169..dc5ed359f1 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -13,7 +13,7 @@ import warnings from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import onnx import torch @@ -486,7 +486,7 @@ def _compile( specializations: Optional[List[Dict[str, int]]] = None, custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, - num_speculative_tokens: Optional[int] = None, + num_speculative_tokens: Optional[Union[int, List[int]]] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, use_onnx_subfunctions: bool = False, @@ -508,7 +508,7 @@ def _compile( :specializations (list): List of specializations to compile for :custom_io (dict): Custom IO to specify the input and outputs in different formats than default :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing. - :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. + :num_speculative_tokens (int | List[int], optional): Number of speculative tokens for TLM decode. A plain int K compiles one decode specialization (seq_len=K+1). A list [K0, K1, ...] compiles one specialization per value, enabling per-step dispatch to the cheapest kernel. :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.`` :qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. 
``Defaults to None.`` :compiler_options: Pass any compiler option as input. @@ -642,7 +642,9 @@ def _compile( # Using named format for MDP QPCs causes a RuntimeError at ExecObj # creation time ("Failed to create ExecObj") on 4-device tensor-parallel. # All values must be strings — qaic-compile rejects integer values. - flat_specs = [{k: str(v) for k, v in spec.items() if k != "_graph_name"} for spec in specializations] + flat_specs = [ + {key: str(val) for key, val in spec.items() if key != "_graph_name"} for spec in specializations + ] specializations_data = {"specializations": flat_specs} create_json(str(specializations_json), specializations_data) command.append(f"-network-specialization-config={specializations_json}") diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 714924e364..20fcf2d229 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3379,7 +3379,6 @@ def _build_decode_spec_for_k( full_batch_size: Optional[int] = None, prefill_seq_len: int = 32, comp_ctx_lengths: Optional[int] = None, - **kwargs, ): """ Builds a TLM decode specialization for proposal length *k* (``seq_len = k+1``). @@ -3599,7 +3598,15 @@ def compile( _max_k = _decode_ks[-1] if _decode_ks else None validated_k = self.check_and_get_num_speculative_tokens(_max_k, prefill_seq_len) if validated_k is not None and validated_k != _max_k: - # speculative_config in model.config overrides num_speculative_tokens + # speculative_config in model.config overrides num_speculative_tokens. + # Warn if the user passed a list — the extra values are discarded. + if _decode_ks is not None and len(_decode_ks) > 1: + discarded = [k for k in _decode_ks if k != validated_k] + logger.warning( + f"speculative_config in model.config fixes num_speculative_tokens={validated_k}. " + f"Ignoring user-supplied values {discarded}. 
" + f"Pass num_speculative_tokens={validated_k} (or [{validated_k}]) to suppress this warning." + ) _decode_ks = [validated_k] if ( @@ -3644,7 +3651,18 @@ def compile( if (prefill_only is None or not prefill_only) and prefill_seq_len != 1: if _decode_ks is not None and self.is_tlm: - # TLM multi-spec path: one decode specialization per K in num_speculative_tokens + # TLM multi-spec path: one decode specialization per K in num_speculative_tokens. + # CCL (comp_ctx_lengths) + multi-spec TLM is not yet supported: the per-K call + # to _build_decode_spec_for_k would need to iterate over CCL values, producing + # len(decode_ks) × len(comp_ctx_lengths_decode) decode specializations whose + # naming and ordering is untested. Reject early so users get a clear error + # instead of a silently wrong QPC. + if self.comp_ctx_lengths_decode is not None: + raise NotImplementedError( + "TLM multi-spec (num_speculative_tokens as a list) combined with " + "comp_ctx_lengths_decode is not yet supported. Pass a plain int for " + "num_speculative_tokens when using CCL." 
+ ) for k in _decode_ks: spec = self._build_decode_spec_for_k( k=k, diff --git a/examples/performance/speculative_decoding/prompt_lookup.py b/examples/performance/speculative_decoding/prompt_lookup.py index 7e7e8d1c7c..4a0b0f8141 100644 --- a/examples/performance/speculative_decoding/prompt_lookup.py +++ b/examples/performance/speculative_decoding/prompt_lookup.py @@ -168,7 +168,6 @@ def find_candidate_pred_tokens( if max_ngram_size <= 0 or num_pred_tokens <= 0 or max_ngram_size > input_length: raise ValueError("Invalid max_ngram_size or num_pred_tokens") - has_empty_tokens = False best_result = np.full(num_pred_tokens, fill_tok, dtype=np.int64) best_count = 0 @@ -206,7 +205,13 @@ def find_candidate_pred_tokens( def _select_k(actual_proposals: np.ndarray, decode_ks: List[int]) -> int: - """Return the smallest K in decode_ks that covers the maximum proposal count in the batch.""" + """Return the smallest K in decode_ks that covers the maximum proposal count in the batch. + + Returns ``decode_ks[-1]`` (max K) when the array is empty — all batch items + have finished generating and no valid proposals remain. + """ + if len(actual_proposals) == 0: + return decode_ks[-1] need = int(actual_proposals.max()) for k in decode_ks: if k >= need: @@ -245,7 +250,9 @@ def pld_spec_decode_inference( Returns: SpDCloudAI100ExecInfo: Execution information, including performance metrics and generated text. 
""" - decode_ks = sorted(set(num_speculative_tokens)) if isinstance(num_speculative_tokens, list) else [num_speculative_tokens] + decode_ks = ( + sorted(set(num_speculative_tokens)) if isinstance(num_speculative_tokens, list) else [num_speculative_tokens] + ) max_k = decode_ks[-1] # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size # get vocab size @@ -312,7 +319,9 @@ def pld_spec_decode_inference( tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) precode_logits_ph = np.zeros((decode_batch_size, num_logits_to_keep, vocab_size), dtype=np.float32) # Pre-allocate per-K logit buffers for smaller specializations - logit_buffers = {k: np.zeros((decode_batch_size, k + 1, vocab_size), dtype=np.float32) for k in decode_ks if k != max_k} + logit_buffers = { + k: np.zeros((decode_batch_size, k + 1, vocab_size), dtype=np.float32) for k in decode_ks if k != max_k + } target_model_session.set_buffers({"logits": tlm_prefill_logits_ph}) e2e_start = perf_counter() @@ -335,9 +344,7 @@ def pld_spec_decode_inference( generated_ids[bi].append(input_ids.item()) tlm_precode_inputs["input_ids"][bi, 0] = input_ids.item() input_len = prompts_tokenized[bi]["position_ids"].max(1).item() + 1 - tlm_precode_inputs["position_ids"][bi] = np.arange( - input_len, input_len + max_k + 1, dtype=np.int64 - ) + tlm_precode_inputs["position_ids"][bi] = np.arange(input_len, input_len + max_k + 1, dtype=np.int64) # assumes that prefill queue will always be popped from the front input_lengths[bi] = input_len max_gen_len[bi] -= input_lengths[bi] @@ -404,9 +411,13 @@ def pld_spec_decode_inference( "num_logits_to_keep": np.arange(selected_k + 1, dtype=np.int64).reshape(-1, 1), } target_model_session.set_buffers({"logits": logit_buffers[selected_k]}) - tlm_outputs = target_model_session.run(sel_inputs) - raw_logits = tlm_outputs["logits"] # [batch, selected_k+1, vocab] - target_model_session.set_buffers({"logits": 
precode_logits_ph}) + try: + tlm_outputs = target_model_session.run(sel_inputs) + raw_logits = tlm_outputs["logits"] # [batch, selected_k+1, vocab] + finally: + # Always restore the max-K placeholder so the next iteration's + # full-K path does not write into an undersized buffer. + target_model_session.set_buffers({"logits": precode_logits_ph}) # Pad to [batch, max_k+1] so downstream acceptance logic is unchanged pad = np.zeros((decode_batch_size, max_k - selected_k, vocab_size), dtype=np.float32) target_logits = np.concatenate([raw_logits, pad], axis=1) From ea26b4a1ab19662ef7961378041e02eff6a7fa73 Mon Sep 17 00:00:00 2001 From: eplatero Date: Tue, 12 May 2026 11:03:48 -0500 Subject: [PATCH 08/12] Docs: add post-review fixes section to variable_spd_specializations.md - Files Changed table: annotate review-fix changes across all 3 files - _select_k description: add empty-batch safety note - New "Post-Review Fixes" section (errors 8-12): 8. _select_k crash on empty batch (Critical) 9. Silent discard of user Ks when speculative_config overrides (Critical) 10. TLM + CCL produces wrong specializations (Major) 11. set_buffers state leak on exception (Major) 12. _compile() typed Optional[int] instead of Optional[Union[int,List[int]]] (Major) - modeling_qeff.py Commits table: add all 5 commits on spd_specs Signed-off-by: eplatero --- docs/variable_spd_specializations.md | 129 +++++++++++++++++++++++++-- 1 file changed, 124 insertions(+), 5 deletions(-) diff --git a/docs/variable_spd_specializations.md b/docs/variable_spd_specializations.md index 922c148eae..9a95ae5185 100644 --- a/docs/variable_spd_specializations.md +++ b/docs/variable_spd_specializations.md @@ -53,6 +53,10 @@ count 0; items with a full match have count `max_k`; items with a partial n-gram (continuation shorter than `max_k`, e.g. near the end of the prompt history) have a count between 1 and `max_k − 1`. 
+**Empty-batch safety:** when all items in the batch have finished generating, +`actual_proposals` is an empty array. The function returns `decode_ks[-1]` (max K) rather +than raising `ValueError: zero-size array to reduction operation maximum`. + ### 4. Multi-spec runtime dispatch (`prompt_lookup.py`) `pld_spec_decode_inference()` now accepts `num_speculative_tokens: Union[int, List[int]]`. @@ -66,9 +70,9 @@ the downstream token acceptance logic is unchanged. | File | Change | |------|--------| -| `QEfficient/transformers/models/modeling_auto.py` | `num_speculative_tokens: Optional[List[int]]`, added `_build_decode_spec_for_k()`, replaced decode/fallback block with K-loop, removed `enable_fallback_decode_spec` | -| `QEfficient/base/modeling_qeff.py` | `_compile()`: write flat-format specializations.json for qaic-compile; strip `_graph_name` tag; convert values to strings | -| `examples/performance/speculative_decoding/prompt_lookup.py` | `_select_k()` helper, per-K buffer allocation, multi-spec dispatch, updated arg parser | +| `QEfficient/transformers/models/modeling_auto.py` | `num_speculative_tokens: Optional[List[int]]`, added `_build_decode_spec_for_k()`, replaced decode/fallback block with K-loop, removed `enable_fallback_decode_spec`; **review:** `speculative_config` override warning, TLM+CCL early rejection, remove unused `**kwargs` | +| `QEfficient/base/modeling_qeff.py` | `_compile()`: write flat-format specializations.json for qaic-compile; strip `_graph_name` tag; convert values to strings; **review:** fix `Optional[Union[int, List[int]]]` annotation, rename `k/v` → `key/val` | +| `examples/performance/speculative_decoding/prompt_lookup.py` | `_select_k()` helper, per-K buffer allocation, multi-spec dispatch, updated arg parser; **review:** empty-batch guard in `_select_k`, `try/finally` for `set_buffers` state leak, remove dead `has_empty_tokens` | | `tests/transformers/spd/test_pld_inference.py` | `test_multi_spec_structure` (4 parametrized cases), 
`test_select_k` (5 parametrized cases) | | `tests/unit_test/models/test_modeling_auto_cpu.py` | `TestTLMMultiSpecSpecializations` (8 tests) | @@ -504,9 +508,119 @@ flat_specs = [{k: str(v) for k, v in spec.items() if k != "_graph_name"} --- -## `modeling_qeff.py` — Flat-Format specializations.json +## Post-Review Fixes + +A code review of the `spd_specs` branch identified 5 correctness bugs (2 critical, 3 major) +that were addressed in commit `5bf8ff1`. The findings and fixes are summarised below. +The full review is at `docs/spd_varK/code_review.md`. + +### 8. `_select_k` crashes on empty batch (Critical) + +**File:** `prompt_lookup.py` + +When all batch items finish generating, `valid_batch_indices` is all-False and +`actual_proposals[valid_batch_indices]` is an empty array. Calling `.max()` on an empty +array raises `ValueError: zero-size array to reduction operation maximum which has no +identity`. + +**Fix:** Return `decode_ks[-1]` immediately for empty input: + +```python +if len(actual_proposals) == 0: + return decode_ks[-1] +``` + +--- + +### 9. Silent discard of user-supplied Ks when `speculative_config` overrides (Critical) + +**File:** `modeling_auto.py` + +When `model.config.speculative_config` overrides `num_speculative_tokens`, `_decode_ks` was +silently replaced by `[validated_k]`. A user passing `[0, 1, 3]` would get a single-K +compile with no indication that `[0, 1]` were discarded. + +**Fix:** Emit a `logger.warning` naming the discarded values: + +```python +if validated_k is not None and validated_k != _max_k: + if _decode_ks is not None and len(_decode_ks) > 1: + discarded = [k for k in _decode_ks if k != validated_k] + logger.warning( + f"speculative_config in model.config fixes num_speculative_tokens={validated_k}. " + f"Ignoring user-supplied values {discarded}." + ) + _decode_ks = [validated_k] +``` + +--- + +### 10. 
TLM + CCL combination produces wrong specializations (Major) + +**File:** `modeling_auto.py` + +`_build_decode_spec_for_k` accepts a `comp_ctx_lengths` parameter, but the TLM multi-spec +loop never passed it. When TLM and CCL (`comp_ctx_lengths_decode`) are both enabled, each +decode specialization is silently missing `comp_ctx_lengths` and will fail at runtime on +CCL-enabled hardware. -The multi-spec feature in `modeling_auto.py` exposed a latent issue in `_compile()` inside +**Fix (option B — early rejection):** Raise `NotImplementedError` before the loop so the +user gets a clear error rather than a silently wrong QPC: + +```python +if self.comp_ctx_lengths_decode is not None: + raise NotImplementedError( + "TLM multi-spec + CCL is not yet supported. " + "Pass a plain int for num_speculative_tokens when using CCL." + ) +``` + +--- + +### 11. `set_buffers` state leak on exception in smaller-K path (Major) + +**File:** `prompt_lookup.py` + +```python +target_model_session.set_buffers({"logits": logit_buffers[selected_k]}) +tlm_outputs = target_model_session.run(sel_inputs) # ← may raise +target_model_session.set_buffers({"logits": precode_logits_ph}) # restore +``` + +If `session.run()` raises, the session's logit buffer is left pointing at the small +`logit_buffers[selected_k]` (shape `[batch, selected_k+1, vocab]`). The next decode +iteration's max-K path writes `max_k+1` rows into it → buffer overwrite or silent truncation. + +**Fix:** `try/finally` to guarantee restoration: + +```python +target_model_session.set_buffers({"logits": logit_buffers[selected_k]}) +try: + tlm_outputs = target_model_session.run(sel_inputs) + raw_logits = tlm_outputs["logits"] +finally: + target_model_session.set_buffers({"logits": precode_logits_ph}) +``` + +--- + +### 12. 
`_compile()` parameter typed `Optional[int]` instead of `Optional[Union[int, List[int]]]` (Major) + +**File:** `modeling_qeff.py` + +`_compile()` declared `num_speculative_tokens: Optional[int]` but the public `compile()` +interface now passes `Optional[List[int]]`. The docstring also said "int". The hash in +`compile_hash_params` was computed over a list, inconsistently with the annotation. + +**Fix:** Update annotation and docstring: + +```python +num_speculative_tokens: Optional[Union[int, List[int]]] = None, +``` + +--- + +## `modeling_qeff.py` — Flat-Format specializations.json in `_compile()` inside `QEfficient/base/modeling_qeff.py`: the method was calling `to_named_specializations()` when writing `specializations.json` for qaic-compile. This had been harmless for 2-spec QPCs (the converter produced valid output that qaic-compile accepted), but the 3-spec format @@ -559,12 +673,17 @@ and `f51c4a3e527ccff9`). | Hash | Message | |------|---------| +| `767a98a` | Variable decode specializations for ngram/suffix SpD | | `5f1b6c5` | Fix: write flat-format specializations.json for qaic-compile | | `82f2d6c` | Fix: convert specialization values to strings for qaic-compile | +| `7331df7` | Docs: update variable_spd_specializations with modeling_qeff fixes | +| `5bf8ff1` | Fix review findings — correctness bugs and annotation | --- ## Hardware Test Results (New SDK — 2026-05-08) + +After upgrading the QAIC SDK, all hardware compile+inference tests pass. Tests were run on devices 5 and 6 using the `QAIC_TEST_DEVICE_ID` fixture. 
| Test | Device | Result | Duration | From 9750331f6c477f217299d04c558debe9cab57fe0 Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 13 May 2026 11:14:19 -0500 Subject: [PATCH 09/12] Remove docs/variable_spd_specializations.md Signed-off-by: eplatero --- docs/variable_spd_specializations.md | 769 --------------------------- 1 file changed, 769 deletions(-) delete mode 100644 docs/variable_spd_specializations.md diff --git a/docs/variable_spd_specializations.md b/docs/variable_spd_specializations.md deleted file mode 100644 index 9a95ae5185..0000000000 --- a/docs/variable_spd_specializations.md +++ /dev/null @@ -1,769 +0,0 @@ -# Variable Speculative Decode Specializations - -## Background - -Speculative decoding on QAIC hardware requires **statically compiled shapes**. Each distinct -input shape must be registered as a "specialization" in `specializations.json` before -`qaic-compile` is invoked. The compiler produces a single QPC binary that dispatches to the -right kernel at runtime based on which input shape is presented. - -Before this change, `QEFFAutoModelForCausalLM.compile()` accepted `num_speculative_tokens: int` -and compiled exactly **two** entries: - -1. **Prefill** — `seq_len = prefill_seq_len`, `num_logits_to_keep = 1` -2. **TLM decode** — `seq_len = K+1`, `num_logits_to_keep = K+1` - -This hard-wired a single proposal length. For PLD/n-gram methods, a second "fallback" -specialization (`seq_len=1`) was available via the now-removed `enable_fallback_decode_spec` -flag. For suffix decoding — where average proposal lengths of 1 or 2 are typical — there was -no way to compile for the right mix. - ---- - -## What Was Implemented - -### 1. `num_speculative_tokens: Optional[List[int]]` in `compile()` (`modeling_auto.py`) - -The parameter type changed from `Optional[int]` to `Optional[List[int]]`. Each element `K` -compiles one decode specialization: `seq_len = K+1`, `num_logits_to_keep = K+1`. The list is -sorted and deduplicated before processing. 
- -Passing a plain `int` still works — it is silently promoted to `[K]` for backward compatibility -with existing code. - -`enable_fallback_decode_spec` is removed. Its old behavior is exactly -`num_speculative_tokens=[0, K]`. - -### 2. `_build_decode_spec_for_k(k, ...)` private method (`modeling_auto.py`) - -Replaces both `build_decode_specialization` (for TLM) and the removed -`build_fallback_decode_specialization`. Builds one decode specialization dict for a given K. -Returns `None` if the spec would duplicate the prefill spec (guards against `seq_len=1` when -`prefill_seq_len=1` and continuous batching is disabled). - -### 3. `_select_k(actual_proposals, decode_ks)` helper (`prompt_lookup.py`) - -Picks the **smallest K in `decode_ks` that covers the maximum actual proposal count** in the -current batch. This ensures the cheapest specialization is used per iteration while still -covering every batch item. - -`actual_proposals` is an integer array of shape `[batch]` containing the number of non-fill -tokens in each item's proposal slots (`input_ids[:, 1:]`). Items with no valid proposals have -count 0; items with a full match have count `max_k`; items with a partial n-gram match -(continuation shorter than `max_k`, e.g. near the end of the prompt history) have a count -between 1 and `max_k − 1`. - -**Empty-batch safety:** when all items in the batch have finished generating, -`actual_proposals` is an empty array. The function returns `decode_ks[-1]` (max K) rather -than raising `ValueError: zero-size array to reduction operation maximum`. - -### 4. Multi-spec runtime dispatch (`prompt_lookup.py`) - -`pld_spec_decode_inference()` now accepts `num_speculative_tokens: Union[int, List[int]]`. -Per-K logit buffers are pre-allocated. Each decode iteration calls `_select_k`, dispatches to -the matching specialization, and pads the logit output back to `[batch, max_K+1, vocab]` so -the downstream token acceptance logic is unchanged. 
- ---- - -## Files Changed - -| File | Change | -|------|--------| -| `QEfficient/transformers/models/modeling_auto.py` | `num_speculative_tokens: Optional[List[int]]`, added `_build_decode_spec_for_k()`, replaced decode/fallback block with K-loop, removed `enable_fallback_decode_spec`; **review:** `speculative_config` override warning, TLM+CCL early rejection, remove unused `**kwargs` | -| `QEfficient/base/modeling_qeff.py` | `_compile()`: write flat-format specializations.json for qaic-compile; strip `_graph_name` tag; convert values to strings; **review:** fix `Optional[Union[int, List[int]]]` annotation, rename `k/v` → `key/val` | -| `examples/performance/speculative_decoding/prompt_lookup.py` | `_select_k()` helper, per-K buffer allocation, multi-spec dispatch, updated arg parser; **review:** empty-batch guard in `_select_k`, `try/finally` for `set_buffers` state leak, remove dead `has_empty_tokens` | -| `tests/transformers/spd/test_pld_inference.py` | `test_multi_spec_structure` (4 parametrized cases), `test_select_k` (5 parametrized cases) | -| `tests/unit_test/models/test_modeling_auto_cpu.py` | `TestTLMMultiSpecSpecializations` (8 tests) | - ---- - -## Specializations.json — Examples - -All examples use `prefill_seq_len=32`, `ctx_len=128`, `batch_size=1`, `full_batch_size=1` -(continuous batching enabled). 
- -Each specialization entry contains: - -| Field | Meaning | -|-------|---------| -| `seq_len` | Number of input tokens this kernel accepts | -| `num_logits_to_keep` | Number of output logits returned (always equals `seq_len` for TLM decode) | -| `ctx_len` | Static KV-cache allocation size | -| `batch_size` | Batch size for this phase | -| `full_batch_size` | Continuous-batching full-batch size | - ---- - -### Case 1 — Baseline (plain int, backward compat) - -```python -model.compile(num_speculative_tokens=4) # treated internally as [4] -``` - -```json -{ - "specializations": [ - { - "seq_len": "32", - "num_logits_to_keep": "1", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "5", - "num_logits_to_keep": "5", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - } - ] -} -``` - -| Entry | Role | -|-------|------| -| `seq_len=32` | **Prefill** — processes full prompt, returns 1 logit | -| `seq_len=5` | **TLM decode** — K=4 draft tokens + 1 anchor = 5 positions, returns 5 logits | - ---- - -### Case 2 — PLD with fallback (replaces `enable_fallback_decode_spec=True`) - -```python -model.compile(num_speculative_tokens=[0, 4]) -``` - -```json -{ - "specializations": [ - { - "seq_len": "32", - "num_logits_to_keep": "1", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "1", - "num_logits_to_keep": "1", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "5", - "num_logits_to_keep": "5", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - } - ] -} -``` - -| Entry | Role | -|-------|------| -| `seq_len=32` | **Prefill** | -| `seq_len=1` | **Fallback decode (K=0)** — used when no n-gram matches are found. 
Cheap single-token forward pass; avoids running the full K+1 kernel wastefully | -| `seq_len=5` | **Full TLM decode (K=4)** — used when n-gram proposals are available | - -At runtime, `_select_k` dispatches to `seq_len=1` when all valid batch items have 0 proposals, -and to `seq_len=5` when any item has proposals (even a partial match). For `decode_ks=[0, 4]` -the dispatch is effectively binary because `find_candidate_pred_tokens` either produces a full -4-token continuation or 0; use `[0, 1, 2, 3, 4]` (Case 4) if you want fine-grained dispatch -for partial continuations near the end of the prompt history. - ---- - -### Case 3 — Intermediate proposal lengths (PLD near end-of-history, or suffix decoding) - -```python -model.compile(num_speculative_tokens=[1, 2]) -``` - -```json -{ - "specializations": [ - { - "seq_len": "32", - "num_logits_to_keep": "1", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "2", - "num_logits_to_keep": "2", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "3", - "num_logits_to_keep": "3", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - } - ] -} -``` - -| Entry | Role | -|-------|------| -| `seq_len=32` | **Prefill** | -| `seq_len=2` | **Short decode (K=1)** — used when max proposal in batch is 1 | -| `seq_len=3` | **Longer decode (K=2)** — used when max proposal in batch is 2 | - ---- - -### Case 4 — Full range (maximum fine-grained dispatch) - -```python -model.compile(num_speculative_tokens=[0, 1, 2, 3, 4]) -``` - -```json -{ - "specializations": [ - { - "seq_len": "32", - "num_logits_to_keep": "1", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "1", - "num_logits_to_keep": "1", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "2", - "num_logits_to_keep": "2", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "3", - 
"num_logits_to_keep": "3", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "4", - "num_logits_to_keep": "4", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - }, - { - "seq_len": "5", - "num_logits_to_keep": "5", - "ctx_len": "128", - "batch_size": "1", - "full_batch_size": "1" - } - ] -} -``` - -The hardware can dispatch to the cheapest kernel that covers the actual proposal distribution -each decode iteration. Trade-off: 6 specializations vs 2 means a larger QPC binary and longer -compile time, but maximum throughput efficiency at inference. - ---- - -## Runtime Dispatch Logic - -``` -decode iteration - │ - ├── for each batch item: count actual proposals - │ 0 → no n-gram match found - │ 1 .. max_k−1 → partial n-gram match (continuation shorter than max_k, - │ e.g. match is near the end of the prompt history) - │ max_k → full n-gram match - │ - ├── selected_k = _select_k(max_actual_proposals, decode_ks) - │ └── smallest K in decode_ks such that K >= max_actual_proposals - │ (clamps to max(decode_ks) if need exceeds all) - │ - ├── set logit buffer to logit_buffers[selected_k] (shape [batch, K+1, vocab]) - ├── slice input_ids, position_ids to [:, :selected_k+1] - ├── run TLM kernel for selected_k - │ - └── if selected_k < max_k: - pad output → [batch, max_k+1, vocab] (zeros for unused positions) - acceptance logic unchanged regardless of selected_k -``` - ---- - -## Unit Tests Written - -**Total new tests: 17.** Combined with the 41 pre-existing unit tests, the full no-hardware -SPD suite is **58 tests, all passing**. 
- -### `TestTLMMultiSpecSpecializations` — `test_modeling_auto_cpu.py` (8 tests, CPU-only) - -| Test | What it verifies | -|------|-----------------| -| `test_build_decode_spec_for_k_seq_len` | `_build_decode_spec_for_k(k=K)` → `spec["seq_len"] == K+1` for K in {0,1,3,7} | -| `test_build_decode_spec_for_k_num_logits_to_keep` | Same method → `spec["num_logits_to_keep"] == K+1` | -| `test_build_decode_spec_for_k_returns_none_when_duplicate_prefill` | K=0, `prefill_seq_len=1`, no CB → returns `None` (would duplicate prefill) | -| `test_build_decode_spec_for_k_not_none_with_continuous_batching` | Same scenario but CB=True → spec returned (CB always needs decode) | -| `test_compile_list_produces_correct_spec_count` | `[0,3]` → 1 prefill + 2 decode entries | -| `test_compile_deduplication` | `[3,3,3]` → only 1 decode spec | -| `test_compile_sorting` | `[3,1,2]` → decode specs appear in ascending `seq_len` order | -| `test_compile_int_backward_compat` | `num_speculative_tokens=3` (plain int) → treated as `[3]`, 1 decode spec with `seq_len=4` | - -### `test_multi_spec_structure` — `test_pld_inference.py` (4 parametrized, CPU-only) - -Each `decode_ks` from `[[3], [0, 3], [1, 2, 3], [0, 1, 2, 3]]`: -- Every K produces `spec["seq_len"] == K+1` and `spec["num_logits_to_keep"] == K+1` -- All seq_lens are distinct (no duplicate specs generated) - -### `test_select_k` — `test_pld_inference.py` (5 parametrized, CPU-only) - -| Scenario | `actual_proposals` | `decode_ks` | Expected K | -|---|---|---|---| -| All zeros — all items missed n-gram | `[0, 0, 0]` | `[0, 3]` | `0` | -| Mix — some items have proposals | `[0, 3, 3]` | `[0, 3]` | `3` | -| Single spec — no choice | `[0, 0]` | `[3]` | `3` | -| Exact mid-point — picks smallest fitting K | `[1, 2]` | `[0, 2, 4]` | `2` | -| Need exceeds all — clamps to max | `[5, 5]` | `[0, 3]` | `3` | - ---- - -## Errors Encountered - -### 1. 
`TypeError: 'int' object is not iterable` - -**Cause:** The initial implementation used `sorted(set(num_speculative_tokens))` unconditionally. -Existing tests in `test_spd_inference.py` and `test_pld_inference.py` pass a plain `int` -directly to `compile()`, which caused `set(4)` to raise. - -**Fix:** Added a type guard before normalizing: - -```python -_decode_ks = ( - sorted(set(num_speculative_tokens)) - if isinstance(num_speculative_tokens, (list, tuple)) - else ([num_speculative_tokens] if num_speculative_tokens is not None else None) -) -``` - ---- - -### 2. Duplicate `@pytest.mark.parametrize("model_id")` decorator - -**Cause:** The old test function `test_fallback_decode_spec_structure` already had a -`@pytest.mark.parametrize("model_id", ...)` decorator. When replacing the function body, the -replacement text included the decorator again, resulting in it appearing twice and pytest -raising `duplicate parametrization of 'model_id'`. - -**Fix:** Removed the duplicate, leaving exactly one `@pytest.mark.parametrize("model_id")` -and one `@pytest.mark.parametrize("decode_ks")`. - ---- - -### 3. Syntax error — missing newline in `arg_parse()` - -**Cause:** When removing the `--enable-fallback-decode-spec` argument block, the closing `)` -of the preceding `add_argument` call merged onto the same line as the next `add_argument` -call, producing a `SyntaxError: Simple statements must be separated by newlines`. - -**Fix:** Restored the newline between the two statements. - ---- - -### 4. Orphaned test method body (`assert "logits" in output_names`) - -**Cause:** The string `assert "logits" in output_names` appeared in two different test -methods in `test_modeling_auto_cpu.py`. The Edit tool matched the first occurrence (in a -different class), replacing it — and leaving the CTC test's method body without its `def` -line, causing an `IndentationError`. 
- -**Fix:** Restored the `def test_export_onnx_has_logits_output(...)` method header and renamed -the local variable to `output_names_ctc` to make it unique. - ---- - -### 5. `qaic-compile: malloc.c: sysmalloc: Assertion` (SIGABRT) — pre-existing - -**Status: Pre-existing SDK bug, not caused by this change.** - -All hardware compile tests (`test_spd_inference` and `test_pld_inference` full/few/dummy -variants) fail with `Compiler exitcode: -6` — a C-level heap assertion inside the -`qaic-compile` binary itself: - -``` -qaic-compile: malloc.c:2617: sysmalloc: Assertion - `(old_top == initial_top (av) && old_size == 0) || ...` failed. -``` - -The generated `specializations.json` is correct — verified by reading the file from disk -before the compiler crash: - -```json -{ - "specializations": [ - { "seq_len": "32", "num_logits_to_keep": "1", "ctx_len": "128", ... }, - { "seq_len": "5", "num_logits_to_keep": "5", "ctx_len": "128", ... } - ] -} -``` - -The crash reproduces identically regardless of which specialization configuration is used and -is not related to this change. - -> **Update (2026-05-08):** Resolved after SDK upgrade. See Hardware Test Results below. - ---- - -### 6. Named-format specializations.json rejected by MDP firmware (`Failed to create ExecObj`) - -**Discovered during:** vLLM integration testing on 4-device Llama-3.1-8B. - -**Cause:** `_compile()` in `QEfficient/base/modeling_qeff.py` previously called -`to_named_specializations()` to wrap each specialization entry as: - -```json -{"name": "Prefill_0", "symbols": {"batch_size": "4", "seq_len": "128", ...}} -``` - -This named format is **only** required by the QNN compiler path (which branches off early in -`_compile()` and passes specs directly to `qnn_compile()` without touching the JSON file). 
The -qaic-compile binary and its MDP (multi-device partition) firmware for 4-device -tensor-parallel only accept the legacy **flat format**: - -```json -{"batch_size": "4", "seq_len": "128", "ctx_len": "2048", ...} -``` - -Loading a named-format MDP binary caused `RuntimeError: Failed to create ExecObj` at -`qaicrt.ExecObj(context, program)`. Single-device QPCs (llama-68m, E2E correctness tests) -tolerate the named format — only MDP QPCs are affected. - -**Fix:** `_compile()` now strips the internal `_graph_name` tag and writes flat dicts -directly, bypassing `to_named_specializations()`: - -```python -flat_specs = [{k: str(v) for k, v in spec.items() if k != "_graph_name"} - for spec in specializations] -specializations_data = {"specializations": flat_specs} -create_json(str(specializations_json), specializations_data) -``` - -**Commit:** `5f1b6c5` — _Fix: write flat-format specializations.json for qaic-compile_ - ---- - -### 7. Integer values in specializations.json rejected by qaic-compile (exit status 255) - -**Discovered during:** first attempt at running the fixed flat-format compile. - -**Cause:** `_build_decode_spec_for_k()` and `build_prefill_specialization()` in -`modeling_auto.py` produce dictionaries with **integer** values (`seq_len=5`, not -`seq_len="5"`). When the previous fix bypassed `to_named_specializations()`, those integers -were written directly to JSON. qaic-compile rejects integer values in -`specializations.json`, exiting with code 255: - -``` -CalledProcessError: Command '['/opt/qti-aic/exec/qaic-compile', ...]' -returned non-zero exit status 255. -``` - -The integer form could be observed in the generated file: - -```json -{"batch_size": 1, "seq_len": 128, ...} ← rejected -``` - -vs. the correct string form: - -```json -{"batch_size": "1", "seq_len": "128", ...} ← accepted -``` - -`to_named_specializations()` had been converting ints to strings as a side-effect of the -`{k: str(v), ...}` construction in `_infer_specialization_name`. 
The flat-format path missed -this conversion. - -**Fix:** Apply `str()` to every value while stripping `_graph_name`: - -```python -flat_specs = [{k: str(v) for k, v in spec.items() if k != "_graph_name"} - for spec in specializations] -``` - -**Commit:** `82f2d6c` — _Fix: convert specialization values to strings for qaic-compile_ - ---- - -## Post-Review Fixes - -A code review of the `spd_specs` branch identified 5 correctness bugs (2 critical, 3 major) -that were addressed in commit `5bf8ff1`. The findings and fixes are summarised below. -The full review is at `docs/spd_varK/code_review.md`. - -### 8. `_select_k` crashes on empty batch (Critical) - -**File:** `prompt_lookup.py` - -When all batch items finish generating, `valid_batch_indices` is all-False and -`actual_proposals[valid_batch_indices]` is an empty array. Calling `.max()` on an empty -array raises `ValueError: zero-size array to reduction operation maximum which has no -identity`. - -**Fix:** Return `decode_ks[-1]` immediately for empty input: - -```python -if len(actual_proposals) == 0: - return decode_ks[-1] -``` - ---- - -### 9. Silent discard of user-supplied Ks when `speculative_config` overrides (Critical) - -**File:** `modeling_auto.py` - -When `model.config.speculative_config` overrides `num_speculative_tokens`, `_decode_ks` was -silently replaced by `[validated_k]`. A user passing `[0, 1, 3]` would get a single-K -compile with no indication that `[0, 1]` were discarded. - -**Fix:** Emit a `logger.warning` naming the discarded values: - -```python -if validated_k is not None and validated_k != _max_k: - if _decode_ks is not None and len(_decode_ks) > 1: - discarded = [k for k in _decode_ks if k != validated_k] - logger.warning( - f"speculative_config in model.config fixes num_speculative_tokens={validated_k}. " - f"Ignoring user-supplied values {discarded}." - ) - _decode_ks = [validated_k] -``` - ---- - -### 10. 
TLM + CCL combination produces wrong specializations (Major) - -**File:** `modeling_auto.py` - -`_build_decode_spec_for_k` accepts a `comp_ctx_lengths` parameter, but the TLM multi-spec -loop never passed it. When TLM and CCL (`comp_ctx_lengths_decode`) are both enabled, each -decode specialization is silently missing `comp_ctx_lengths` and will fail at runtime on -CCL-enabled hardware. - -**Fix (option B — early rejection):** Raise `NotImplementedError` before the loop so the -user gets a clear error rather than a silently wrong QPC: - -```python -if self.comp_ctx_lengths_decode is not None: - raise NotImplementedError( - "TLM multi-spec + CCL is not yet supported. " - "Pass a plain int for num_speculative_tokens when using CCL." - ) -``` - ---- - -### 11. `set_buffers` state leak on exception in smaller-K path (Major) - -**File:** `prompt_lookup.py` - -```python -target_model_session.set_buffers({"logits": logit_buffers[selected_k]}) -tlm_outputs = target_model_session.run(sel_inputs) # ← may raise -target_model_session.set_buffers({"logits": precode_logits_ph}) # restore -``` - -If `session.run()` raises, the session's logit buffer is left pointing at the small -`logit_buffers[selected_k]` (shape `[batch, selected_k+1, vocab]`). The next decode -iteration's max-K path writes `max_k+1` rows into it → buffer overwrite or silent truncation. - -**Fix:** `try/finally` to guarantee restoration: - -```python -target_model_session.set_buffers({"logits": logit_buffers[selected_k]}) -try: - tlm_outputs = target_model_session.run(sel_inputs) - raw_logits = tlm_outputs["logits"] -finally: - target_model_session.set_buffers({"logits": precode_logits_ph}) -``` - ---- - -### 12. `_compile()` parameter typed `Optional[int]` instead of `Optional[Union[int, List[int]]]` (Major) - -**File:** `modeling_qeff.py` - -`_compile()` declared `num_speculative_tokens: Optional[int]` but the public `compile()` -interface now passes `Optional[List[int]]`. The docstring also said "int". 
The hash in -`compile_hash_params` was computed over a list, inconsistently with the annotation. - -**Fix:** Update annotation and docstring: - -```python -num_speculative_tokens: Optional[Union[int, List[int]]] = None, -``` - ---- - -## `modeling_qeff.py` — Flat-Format specializations.json in `_compile()` inside -`QEfficient/base/modeling_qeff.py`: the method was calling `to_named_specializations()` when -writing `specializations.json` for qaic-compile. This had been harmless for 2-spec QPCs -(the converter produced valid output that qaic-compile accepted), but the 3-spec format -surfaced two problems: (a) the `{name, symbols}` wrapper was rejected by the MDP firmware, -and (b) integer values rather than strings caused qaic-compile to exit 255. - -### What changed in `_compile()` - -```python -# Before (called to_named_specializations — produced named format, int values): -specializations_data = { - "specializations": to_named_specializations(specializations, ...) -} - -# After (flat format, all values converted to strings): -flat_specs = [{k: str(v) for k, v in spec.items() if k != "_graph_name"} - for spec in specializations] -specializations_data = {"specializations": flat_specs} -``` - -**What `_graph_name` is:** an internal routing tag set by `build_prefill_specialization()`, -`_build_decode_spec_for_k()`, and similar helpers so that `to_named_specializations()` can -assign human-readable names (`"Prefill"`, `"Decode"`, etc.). It is not a qaic-compile -field and must be stripped before writing. - -**Why `to_named_specializations()` is still needed:** the QNN compiler path (`enable_qnn=True`) -branches off at line 548 of `_compile()` and calls `qnn_compile()` directly, passing the -raw `specializations` list. `qnn_compile()` internally calls `to_named_specializations()` on -that list. The flat-format fix only affects the qaic-compile branch; QNN is untouched. 
- -### Correct 3-spec flat-format output - -```json -{ - "specializations": [ - { "batch_size": "1", "ctx_len": "2048", "full_batch_exec_size": "1", - "full_batch_size": "1", "num_logits_to_keep": "1", "seq_len": "128" }, - { "batch_size": "1", "ctx_len": "2048", - "full_batch_size": "1", "num_logits_to_keep": "1", "seq_len": "1" }, - { "batch_size": "1", "ctx_len": "2048", - "full_batch_size": "1", "num_logits_to_keep": "5", "seq_len": "5" } - ] -} -``` - -This matches the format used by all working 2-spec QPCs in the cache (`46f4c4d93f23c51c` -and `f51c4a3e527ccff9`). - -### Commits - -| Hash | Message | -|------|---------| -| `767a98a` | Variable decode specializations for ngram/suffix SpD | -| `5f1b6c5` | Fix: write flat-format specializations.json for qaic-compile | -| `82f2d6c` | Fix: convert specialization values to strings for qaic-compile | -| `7331df7` | Docs: update variable_spd_specializations with modeling_qeff fixes | -| `5bf8ff1` | Fix review findings — correctness bugs and annotation | - ---- - -## Hardware Test Results (New SDK — 2026-05-08) - -After upgrading the QAIC SDK, all hardware compile+inference tests pass. -Tests were run on devices 5 and 6 using the `QAIC_TEST_DEVICE_ID` fixture. - -| Test | Device | Result | Duration | -|---|---|---|---| -| `test_few_pld_inference[CB llama]` | 5 | ✅ PASSED | 43s | -| `test_few_spd_inference[CB llama]` | 6 | ✅ PASSED | ~2.5m | -| `test_few_spd_inference[CB qwen]` | 6 | ✅ PASSED | ~2.5m | - -**Observed inference metrics (CB llama, K=4):** -- Avg accepted tokens = **5.0** (= K+1, i.e., 100% acceptance rate when TLM = DLM ✓) -- Decode throughput = **~519 tokens/sec** - -**Note on first re-run failure:** When CB llama and CB qwen compiled back-to-back on the -same device in the same pytest session, the Llama TLM compile (more complex: `num_cores=6`, -`num_logits_to_keep`) failed silently. 
Running each model sequentially on the same device -(the default behaviour of `pytest` without `-n`) is sufficient to avoid this. - ---- - -## Test Results Summary - -| Test suite | Count | Result | -|---|---|---| -| `test_speculative_decoding.py` (pre-existing unit tests) | 41 | ✅ All pass | -| `TestTLMMultiSpecSpecializations` (new) | 8 | ✅ All pass | -| `test_multi_spec_structure` (new) | 4 | ✅ All pass | -| `test_select_k` (new) | 5 | ✅ All pass | -| `test_tlm_multi_spec_logit_consistency` (new) | 4 | ✅ All pass | -| `test_multi_spec_qpc_logit_correctness` (new, hardware) | 4 | ✅ All pass | -| Hardware few-layers tests (PLD + SPD, new SDK) | 3 | ✅ All pass | - -**Total Python-layer tests: 62 / 62 pass.** -**Total hardware tests: 7 / 7 pass.** - ---- - -## Functional Correctness Guarantee - -### The key property - -For the multi-spec dispatch to be correct, dispatching to a *smaller* specialization (e.g. -`seq_len=1`, K=0) must produce the **same accepted token** as the full specialization -(`seq_len=K+1`) for the same anchor token and KV cache state. - -This holds because Transformer self-attention is **causal**: the hidden state at position P -depends only on positions 0..P, never on positions P+1..K. Therefore: - -``` -tlm_forward(seq_len=1, anchor_at_P, kv_cache) → logit_P -tlm_forward(seq_len=K+1, anchor_at_P + K_specs, kv_cache) → [logit_P, logit_P1, ..., logit_PK] -``` - -`logit_P` is identical in both cases. The accepted token (`argmax(logit_P)`) is therefore -the same regardless of which specialization is dispatched. - -### What the test verifies (`test_tlm_multi_spec_logit_consistency`) - -Added to `TestTLMForwardExecution` in `tests/unit_test/transforms/test_speculative_decoding.py`. -Runs entirely on CPU — no QAIC hardware required. - -For K in {1, 2, 3, 5}: -1. Run `tlm_forward` with `seq_len=1` (single anchor token, `num_logits_to_keep=1`) -2. 
Run `tlm_forward` with `seq_len=K+1` (anchor + K speculative tokens, `num_logits_to_keep=K+1`) - — same anchor, same KV cache, speculative tokens are random -3. Assert `logits_k0[0,0,:] ≈ logits_kK[0,0,:]` (anchor logit is identical, `atol=1e-5`) -4. Assert `argmax(logits_k0) == argmax(logits_kK)` (accepted token is identical) - -### What is NOT yet tested (requires hardware) - -- ~~That the KV cache is correctly updated after a `seq_len=1` dispatch (i.e., the next - decode step uses consistent state)~~ **Covered by `test_multi_spec_qpc_logit_correctness` - (KV state must be correct for the subsequent chunk's logits to match)** -- ~~End-to-end sequence equivalence: that a full generation run with dynamic dispatch produces - the same token sequence as a run always using `seq_len=K+1`~~ - -Both properties are now directly verified by `test_multi_spec_qpc_logit_correctness` -(`tests/transformers/spd/test_spd_inference.py`), which runs on real QAIC hardware and checks: -- ALL K+1 logit vectors (not just argmax) match the vanilla DLM reference at every position -- Every accepted token matches -- This holds for `decode_ks` in `[0]`, `[3]`, `[0,3]`, and `[1,2,3]` - -The end-to-end acceptance-rate guarantee is additionally covered by `test_few_spd_inference` -via the `mean_num_accepted_tokens == K+1` assertion. From 8a5077e429bb6f32e0cea4fe4f0970f88acb99ca Mon Sep 17 00:00:00 2001 From: eplatero Date: Wed, 13 May 2026 12:58:16 -0500 Subject: [PATCH 10/12] =?UTF-8?q?Tests:=20drop=20TestVarKDispatch=20?= =?UTF-8?q?=E2=80=94=20tests=20a=20copy=20of=20logic,=20not=20real=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The class reimplemented the dispatch loop inline (_run_dispatch) rather than calling pld_spec_decode_inference directly, so it gave false confidence: a change to the real function would leave all six tests passing. 
Coverage of the dispatch path is provided by the on-qaic test_multi_spec_qpc_logit_correctness; _select_k is still tested directly via test_select_k. Signed-off-by: eplatero --- tests/transformers/spd/test_pld_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index 3107d3c1e0..4721817abc 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -548,3 +548,4 @@ def test_select_k(actual_proposals, decode_ks, expected_k): result = _select_k(actual_proposals, decode_ks) assert result == expected_k, f"Expected {expected_k}, got {result} for proposals={actual_proposals}, ks={decode_ks}" + From 3bbbe6444554bc2671af22a485e3dcc88b61094e Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 18 May 2026 10:21:42 -0500 Subject: [PATCH 11/12] Format: fix ruff formatting in test files Signed-off-by: eplatero --- tests/transformers/spd/test_pld_inference.py | 7 +-- tests/transformers/spd/test_spd_inference.py | 33 +++++++------- .../models/test_modeling_auto_cpu.py | 45 +++++++++++++++---- .../transforms/test_speculative_decoding.py | 5 +-- 4 files changed, 60 insertions(+), 30 deletions(-) diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index 4721817abc..cef8886475 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -513,8 +513,10 @@ def test_multi_spec_structure(model_id, decode_ks): prefill_seq_len=prefill_seq_len, ) assert spec is not None, f"_build_decode_spec_for_k returned None for k={k}" - assert spec["seq_len"] == k + 1, f"Expected seq_len={k+1}, got {spec['seq_len']}" - assert spec["num_logits_to_keep"] == k + 1, f"Expected num_logits_to_keep={k+1}, got {spec['num_logits_to_keep']}" + assert spec["seq_len"] == k + 1, f"Expected seq_len={k + 1}, got {spec['seq_len']}" + assert 
spec["num_logits_to_keep"] == k + 1, ( + f"Expected num_logits_to_keep={k + 1}, got {spec['num_logits_to_keep']}" + ) assert spec["ctx_len"] == ctx_len specs.append(spec) @@ -548,4 +550,3 @@ def test_select_k(actual_proposals, decode_ks, expected_k): result = _select_k(actual_proposals, decode_ks) assert result == expected_k, f"Expected {expected_k}, got {result} for proposals={actual_proposals}, ks={decode_ks}" - diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index d69e6fcacc..feb0153e3c 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -415,10 +415,12 @@ def _collect_vanilla_reference(session, first_token, start_pos, vocab_size, n_st ph = np.zeros((1, 1, vocab_size), dtype=np.float32) session.set_buffers({"logits": ph}) for step in range(n_steps): - out = session.run({ - "input_ids": np.array([[ref_tokens[-1]]], dtype=np.int64), - "position_ids": np.array([[start_pos + step]], dtype=np.int64), - }) + out = session.run( + { + "input_ids": np.array([[ref_tokens[-1]]], dtype=np.int64), + "position_ids": np.array([[start_pos + step]], dtype=np.int64), + } + ) logit = out["logits"][0, 0, :].copy() ref_logits.append(logit) ref_tokens.append(int(logit.argmax())) @@ -443,15 +445,15 @@ def _verify_tlm_spec(tlm_session, k, ref_tokens, ref_logits, start_pos, vocab_si n_assertions = 0 n_chunks = len(ref_logits) // seq_len for chunk in range(n_chunks): - chunk_tokens = ref_tokens[chunk * seq_len: chunk * seq_len + seq_len] - chunk_positions = np.array( - [[start_pos + chunk * seq_len + i for i in range(seq_len)]], dtype=np.int64 + chunk_tokens = ref_tokens[chunk * seq_len : chunk * seq_len + seq_len] + chunk_positions = np.array([[start_pos + chunk * seq_len + i for i in range(seq_len)]], dtype=np.int64) + out = tlm_session.run( + { + "input_ids": np.array([chunk_tokens], dtype=np.int64), + "position_ids": chunk_positions, + "num_logits_to_keep": 
n_logits_to_keep, + } ) - out = tlm_session.run({ - "input_ids": np.array([chunk_tokens], dtype=np.int64), - "position_ids": chunk_positions, - "num_logits_to_keep": n_logits_to_keep, - }) tlm_logits = out["logits"] # [1, seq_len, vocab] for i in range(seq_len): @@ -478,9 +480,9 @@ def _verify_tlm_spec(tlm_session, k, ref_tokens, ref_logits, start_pos, vocab_si @pytest.mark.parametrize( "decode_ks", [ - [0], # fallback-only (seq_len=1) - [3], # full-K only (seq_len=4) - [0, 3], # fallback + full-K (typical PLD config) + [0], # fallback-only (seq_len=1) + [3], # full-K only (seq_len=4) + [0, 3], # fallback + full-K (typical PLD config) [1, 2, 3], # suffix-decoding range ], ) @@ -583,4 +585,3 @@ def test_multi_spec_qpc_logit_correctness(decode_ks, manual_cleanup): assert total_assertions > 0 manual_cleanup([vanilla.onnx_path, tlm.onnx_path]) - diff --git a/tests/unit_test/models/test_modeling_auto_cpu.py b/tests/unit_test/models/test_modeling_auto_cpu.py index e3c368be48..af5f8d2f09 100644 --- a/tests/unit_test/models/test_modeling_auto_cpu.py +++ b/tests/unit_test/models/test_modeling_auto_cpu.py @@ -1002,7 +1002,9 @@ def test_build_decode_spec_for_k_seq_len(self): qeff = QEFFAutoModelForCausalLM(model) qeff.is_tlm = True for k in [0, 1, 3, 7]: - spec = qeff._build_decode_spec_for_k(k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32) + spec = qeff._build_decode_spec_for_k( + k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32 + ) assert spec is not None assert spec["seq_len"] == k + 1 @@ -1012,7 +1014,9 @@ def test_build_decode_spec_for_k_num_logits_to_keep(self): qeff = QEFFAutoModelForCausalLM(model) qeff.is_tlm = True for k in [0, 1, 3]: - spec = qeff._build_decode_spec_for_k(k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32) + spec = qeff._build_decode_spec_for_k( + k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32 + ) assert spec["num_logits_to_keep"] == k + 1 def 
test_build_decode_spec_for_k_returns_none_when_duplicate_prefill(self): @@ -1030,7 +1034,9 @@ def test_build_decode_spec_for_k_not_none_with_continuous_batching(self): qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) qeff.is_tlm = True # k=0 → seq_len=1 == prefill_seq_len=1, but CB is True → should not be None - spec = qeff._build_decode_spec_for_k(k=0, ctx_len=128, batch_size=1, kv_cache_batch_size=2, full_batch_size=2, prefill_seq_len=1) + spec = qeff._build_decode_spec_for_k( + k=0, ctx_len=128, batch_size=1, kv_cache_batch_size=2, full_batch_size=2, prefill_seq_len=1 + ) assert spec is not None # ---- compile() specialization count via mock ---- @@ -1043,7 +1049,13 @@ def test_compile_list_produces_correct_spec_count(self): qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) captured = {} - with patch.object(type(qeff), "_compile", side_effect=lambda *args, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + with patch.object( + type(qeff), + "_compile", + side_effect=lambda *args, **kw: ( + captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc" + ), + ): qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[0, 3]) assert captured.get("specializations") is not None, "_compile was not reached" @@ -1059,7 +1071,13 @@ def test_compile_deduplication(self): qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) captured = {} - with patch.object(type(qeff), "_compile", side_effect=lambda *args, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + with patch.object( + type(qeff), + "_compile", + side_effect=lambda *args, **kw: ( + captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc" + ), + ): qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 3, 3]) assert captured.get("specializations") is not None, "_compile was not 
reached" @@ -1076,7 +1094,13 @@ def test_compile_sorting(self): qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) captured = {} - with patch.object(type(qeff), "_compile", side_effect=lambda *args, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + with patch.object( + type(qeff), + "_compile", + side_effect=lambda *args, **kw: ( + captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc" + ), + ): qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 1, 2]) assert captured.get("specializations") is not None, "_compile was not reached" @@ -1094,7 +1118,13 @@ def test_compile_int_backward_compat(self): qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) captured = {} - with patch.object(type(qeff), "_compile", side_effect=lambda *args, **kw: captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc"): + with patch.object( + type(qeff), + "_compile", + side_effect=lambda *args, **kw: ( + captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc" + ), + ): qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=3) assert captured.get("specializations") is not None, "_compile was not reached" @@ -1102,4 +1132,3 @@ def test_compile_int_backward_compat(self): decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] assert len(decode_specs) == 1, f"Expected 1 decode spec for int input, got: {decode_specs}" assert decode_specs[0]["seq_len"] == 4 # k=3 → seq_len=4 - diff --git a/tests/unit_test/transforms/test_speculative_decoding.py b/tests/unit_test/transforms/test_speculative_decoding.py index 532db225d9..3c33fbaec9 100644 --- a/tests/unit_test/transforms/test_speculative_decoding.py +++ b/tests/unit_test/transforms/test_speculative_decoding.py @@ -441,13 +441,12 @@ def test_tlm_multi_spec_logit_consistency(self, num_spec_tokens): # The anchor logit must be 
numerically identical regardless of K assert torch.allclose(logit_k0, logit_kK_anchor, atol=1e-5), ( f"Causal property violated: anchor logit differs between seq_len=1 and " - f"seq_len={num_spec_tokens+1}: " + f"seq_len={num_spec_tokens + 1}: " f"max_diff={(logit_k0 - logit_kK_anchor).abs().max().item():.2e}" ) # Accepted token (greedy argmax) must also be identical assert logit_k0.argmax(dim=-1).eq(logit_kK_anchor.argmax(dim=-1)).all(), ( - "Accepted token differs between seq_len=1 and seq_len=K+1 — " - "causal property violated in raw model" + "Accepted token differs between seq_len=1 and seq_len=K+1 — causal property violated in raw model" ) From c9001fc58ca461a68cf4611ac17753e31705c119 Mon Sep 17 00:00:00 2001 From: eplatero Date: Mon, 1 Jun 2026 16:20:09 -0500 Subject: [PATCH 12/12] resolve 1st round of comments --- QEfficient/base/modeling_qeff.py | 23 ++++++---- .../transformers/models/modeling_auto.py | 30 ++++++++---- tests/transformers/spd/conftest.py | 34 -------------- tests/transformers/spd/test_pld_inference.py | 2 +- .../models/test_modeling_auto_cpu.py | 46 +++++++++++++++---- 5 files changed, 74 insertions(+), 61 deletions(-) delete mode 100644 tests/transformers/spd/conftest.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index dc5ed359f1..399f2074f0 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -634,14 +634,21 @@ def _compile( # Write specializations.json file if specializations is not None: specializations_json = compile_dir / "specializations.json" - # Strip internal _graph_name tags and write flat format for qaic-compile. - # Named format ({"name": ..., "symbols": {...}}) is only required for the - # QNN path (already branched off above). 
The qaic-compile binary and its - # MDP (multi-device partition) firmware support only the flat format: - # {"batch_size": "4", "seq_len": "5", ...} - # Using named format for MDP QPCs causes a RuntimeError at ExecObj - # creation time ("Failed to create ExecObj") on 4-device tensor-parallel. - # All values must be strings — qaic-compile rejects integer values. + # Write specializations.json in flat format required by qaic-compile. + # + # Background: internally, specializations are built as plain dicts that + # carry a "_graph_name" tag (e.g. {"_graph_name": "Decode", "seq_len": 5}). + # The QNN compilation path uses a different named format with "name" and + # "symbols" keys — that branch is already handled above and returns early. + # + # The qaic-compile binary and the MDP (multi-device partition) tensor- + # parallel firmware require only the flat key/value format: + # {"batch_size": "4", "seq_len": "5", "ctx_len": "128"} + # Passing anything other than this flat format to a 4-device MDP QPC + # causes the runtime to fail at ExecObj creation time: + # RuntimeError: Failed to create ExecObj + # All values must also be strings — qaic-compile rejects integer values. + # So we strip "_graph_name" and stringify all remaining values here. 
flat_specs = [ {key: str(val) for key, val in spec.items() if key != "_graph_name"} for spec in specializations ] diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 20fcf2d229..3886bc261b 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3372,7 +3372,7 @@ def build_decode_specialization( def _build_decode_spec_for_k( self, - k: int, + num_speculative_tokens: int, ctx_len: int = 128, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, @@ -3381,13 +3381,20 @@ def _build_decode_spec_for_k( comp_ctx_lengths: Optional[int] = None, ): """ - Builds a TLM decode specialization for proposal length *k* (``seq_len = k+1``). + Builds one TLM decode specialization for proposal length *num_speculative_tokens* + (``seq_len = num_speculative_tokens + 1``). + + This method intentionally does not call ``self.model.get_specializations()``. + That interface always returns exactly two entries — ``[0]`` for prefill and + ``[1]`` for decode — and cannot express the 1-prefill + N-decode shape that + variable-K requires. ``compile()`` calls this in a loop, once per K value, + to build the full set of decode specializations. Parameters ---------- - k : int + num_speculative_tokens : int Number of speculative (draft) tokens. ``seq_len`` and ``num_logits_to_keep`` - are both set to ``k + 1``. + are both set to ``num_speculative_tokens + 1``. ctx_len : int, optional Maximum context length. Default is 128. batch_size : int, optional @@ -3405,7 +3412,7 @@ def _build_decode_spec_for_k( Optional[Dict[str, Union[int, str]]] Specialization dict, or ``None`` if it would duplicate the prefill specialization. 
""" - seq_len = k + 1 + seq_len = num_speculative_tokens + 1 if seq_len == prefill_seq_len and not self.continuous_batching: return None spec = { @@ -3420,7 +3427,9 @@ def _build_decode_spec_for_k( spec["full_batch_size"] = kv_cache_batch_size else: spec["batch_size"] = kv_cache_batch_size - return {k_: v for k_, v in spec.items() if v is not None} + result = {key: v for key, v in spec.items() if v is not None} + result["_graph_name"] = "Decode" + return result def compile( self, @@ -3438,7 +3447,7 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, - num_speculative_tokens: Optional[List[int]] = None, + num_speculative_tokens: Optional[Union[int, List[int]]] = None, prefill_only: Optional[bool] = None, use_onnx_subfunctions: bool = False, offload_pt_weights: Optional[bool] = True, @@ -3480,8 +3489,9 @@ def compile( Use MXFP6 compression for weights. Default is False. mxint8_kv_cache : bool, optional Use MXINT8 compression for KV cache. Default is False. - num_speculative_tokens : list[int], optional - List of proposal lengths for Speculative Decoding Target Language Model. + num_speculative_tokens : int or list[int], optional + Proposal length(s) for Speculative Decoding Target Language Model. + A plain int K is treated as ``[K]`` (backward compatible). Each value K generates a decode specialization with seq_len=K+1 and num_logits_to_keep=K+1. Include 0 to compile a cheap single-token fallback (e.g. ``[0, 3]`` for a fallback + full K=3 decode). 
Required if the model is @@ -3665,7 +3675,7 @@ def compile( ) for k in _decode_ks: spec = self._build_decode_spec_for_k( - k=k, + num_speculative_tokens=k, ctx_len=ctx_len, batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, diff --git a/tests/transformers/spd/conftest.py b/tests/transformers/spd/conftest.py deleted file mode 100644 index 2b3f8d7a7c..0000000000 --- a/tests/transformers/spd/conftest.py +++ /dev/null @@ -1,34 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- -""" -Local conftest for SPD hardware tests. - -Patches QAICInferenceSession so that calls without an explicit device_ids -argument default to [4] instead of device 0. Set QAIC_TEST_DEVICE_ID in -the environment to override (e.g. QAIC_TEST_DEVICE_ID=5 pytest ...). -""" - -import os - -import pytest - -_DEVICE_ID = int(os.environ.get("QAIC_TEST_DEVICE_ID", "4")) - - -@pytest.fixture(autouse=True) -def _use_test_device(monkeypatch): - """Redirect all bare QAICInferenceSession() calls to _DEVICE_ID.""" - from QEfficient.generation.cloud_infer import QAICInferenceSession - - _orig_init = QAICInferenceSession.__init__ - - def _patched_init(self, qpc_path, device_ids=None, **kwargs): - if device_ids is None: - device_ids = [_DEVICE_ID] - _orig_init(self, qpc_path, device_ids=device_ids, **kwargs) - - monkeypatch.setattr(QAICInferenceSession, "__init__", _patched_init) diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index cef8886475..bb3a38de9a 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -505,7 +505,7 @@ def test_multi_spec_structure(model_id, decode_ks): specs = [] for k in sorted(set(decode_ks)): spec = 
target_model._build_decode_spec_for_k( - k=k, + num_speculative_tokens=k, ctx_len=ctx_len, batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, diff --git a/tests/unit_test/models/test_modeling_auto_cpu.py b/tests/unit_test/models/test_modeling_auto_cpu.py index af5f8d2f09..b63f0424b6 100644 --- a/tests/unit_test/models/test_modeling_auto_cpu.py +++ b/tests/unit_test/models/test_modeling_auto_cpu.py @@ -997,25 +997,25 @@ class TestTLMMultiSpecSpecializations: # ---- _build_decode_spec_for_k ---- def test_build_decode_spec_for_k_seq_len(self): - """_build_decode_spec_for_k sets seq_len = k+1.""" + """_build_decode_spec_for_k sets seq_len = num_speculative_tokens+1.""" model, _ = make_tiny_llama() qeff = QEFFAutoModelForCausalLM(model) qeff.is_tlm = True for k in [0, 1, 3, 7]: spec = qeff._build_decode_spec_for_k( - k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32 + num_speculative_tokens=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32 ) assert spec is not None assert spec["seq_len"] == k + 1 def test_build_decode_spec_for_k_num_logits_to_keep(self): - """_build_decode_spec_for_k sets num_logits_to_keep = k+1.""" + """_build_decode_spec_for_k sets num_logits_to_keep = num_speculative_tokens+1.""" model, _ = make_tiny_llama() qeff = QEFFAutoModelForCausalLM(model) qeff.is_tlm = True for k in [0, 1, 3]: spec = qeff._build_decode_spec_for_k( - k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32 + num_speculative_tokens=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32 ) assert spec["num_logits_to_keep"] == k + 1 @@ -1024,8 +1024,10 @@ def test_build_decode_spec_for_k_returns_none_when_duplicate_prefill(self): model, _ = make_tiny_llama() qeff = QEFFAutoModelForCausalLM(model) qeff.is_tlm = True - # k=0 → seq_len=1 == prefill_seq_len=1 → should be None - spec = qeff._build_decode_spec_for_k(k=0, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=1) + # 
num_speculative_tokens=0 → seq_len=1 == prefill_seq_len=1 → should be None + spec = qeff._build_decode_spec_for_k( + num_speculative_tokens=0, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=1 + ) assert spec is None def test_build_decode_spec_for_k_not_none_with_continuous_batching(self): @@ -1033,9 +1035,14 @@ def test_build_decode_spec_for_k_not_none_with_continuous_batching(self): model, _ = make_tiny_llama() qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) qeff.is_tlm = True - # k=0 → seq_len=1 == prefill_seq_len=1, but CB is True → should not be None + # num_speculative_tokens=0 → seq_len=1 == prefill_seq_len=1, but CB is True → should not be None spec = qeff._build_decode_spec_for_k( - k=0, ctx_len=128, batch_size=1, kv_cache_batch_size=2, full_batch_size=2, prefill_seq_len=1 + num_speculative_tokens=0, + ctx_len=128, + batch_size=1, + kv_cache_batch_size=2, + full_batch_size=2, + prefill_seq_len=1, ) assert spec is not None @@ -1132,3 +1139,26 @@ def test_compile_int_backward_compat(self): decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] assert len(decode_specs) == 1, f"Expected 1 decode spec for int input, got: {decode_specs}" assert decode_specs[0]["seq_len"] == 4 # k=3 → seq_len=4 + + def test_compile_int_zero_backward_compat(self): + """compile(num_speculative_tokens=0) as plain scalar int still works (treated as [0]).""" + from unittest.mock import patch + + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + captured = {} + + with patch.object( + type(qeff), + "_compile", + side_effect=lambda *args, **kw: ( + captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc" + ), + ): + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=0) + + assert captured.get("specializations") is not None, "_compile was not reached" + specs = captured["specializations"] + decode_specs = [s for s in specs if 
s.get("seq_len", 0) != 32] + assert len(decode_specs) == 1, f"Expected 1 decode spec for scalar 0, got: {decode_specs}" + assert decode_specs[0]["seq_len"] == 1 # k=0 → seq_len=1