diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 443f56a038..dc5ed359f1 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -13,7 +13,7 @@ import warnings from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import onnx import torch @@ -42,7 +42,6 @@ hash_dict_params, load_json, require_value, - to_named_specializations, ) from QEfficient.utils.export_utils import export_wrapper @@ -487,7 +486,7 @@ def _compile( specializations: Optional[List[Dict[str, int]]] = None, custom_io: Optional[Dict[str, str]] = None, mdp_ts_num_devices: int = 1, - num_speculative_tokens: Optional[int] = None, + num_speculative_tokens: Optional[Union[int, List[int]]] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, use_onnx_subfunctions: bool = False, @@ -509,7 +508,7 @@ def _compile( :specializations (list): List of specializations to compile for :custom_io (dict): Custom IO to specify the input and outputs in different formats than default :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing. - :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model. + :num_speculative_tokens (int | List[int], optional): Number of speculative tokens for TLM decode. A plain int K compiles one decode specialization (seq_len=K+1). A list [K0, K1, ...] compiles one specialization per value, enabling per-step dispatch to the cheapest kernel. :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.`` :qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. ``Defaults to None.`` :compiler_options: Pass any compiler option as input. @@ -635,9 +634,18 @@ def _compile( # Write specializations.json file if specializations is not None: specializations_json = compile_dir / "specializations.json" - specializations_data = { - "specializations": to_named_specializations(specializations, module_name=specialization_module_name) - } + # Strip internal _graph_name tags and write flat format for qaic-compile. + # Named format ({"name": ..., "symbols": {...}}) is only required for the + # QNN path (already branched off above). The qaic-compile binary and its + # MDP (multi-device partition) firmware support only the flat format: + # {"batch_size": "4", "seq_len": "5", ...} + # Using named format for MDP QPCs causes a RuntimeError at ExecObj + # creation time ("Failed to create ExecObj") on 4-device tensor-parallel. + # All values must be strings — qaic-compile rejects integer values. + flat_specs = [ + {key: str(val) for key, val in spec.items() if key != "_graph_name"} for spec in specializations + ] + specializations_data = {"specializations": flat_specs} create_json(str(specializations_json), specializations_data) command.append(f"-network-specialization-config={specializations_json}") diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4ad56592fb..20fcf2d229 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3370,6 +3370,58 @@ def build_decode_specialization( result["_graph_name"] = "Decode" return result + def _build_decode_spec_for_k( + self, + k: int, + ctx_len: int = 128, + batch_size: int = 1, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, + prefill_seq_len: int = 32, + comp_ctx_lengths: Optional[int] = None, + ): + """ + Builds a TLM decode specialization for proposal length *k* (``seq_len = k+1``). + + Parameters + ---------- + k : int + Number of speculative (draft) tokens. ``seq_len`` and ``num_logits_to_keep`` + are both set to ``k + 1``. + ctx_len : int, optional + Maximum context length. Default is 128. + batch_size : int, optional + Batch size. Default is 1. + kv_cache_batch_size : int, optional + Batch size for KV cache allocation. + full_batch_size : int, optional + Continuous batching full batch size. + prefill_seq_len : int, optional + Used to detect and skip duplicate specializations (when ``seq_len == prefill_seq_len`` + and continuous batching is disabled). + + Returns + ------- + Optional[Dict[str, Union[int, str]]] + Specialization dict, or ``None`` if it would duplicate the prefill specialization. + """ + seq_len = k + 1 + if seq_len == prefill_seq_len and not self.continuous_batching: + return None + spec = { + "seq_len": seq_len, + "ctx_len": ctx_len, + "num_logits_to_keep": seq_len, + } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + if self.continuous_batching: + spec["batch_size"] = full_batch_size + spec["full_batch_size"] = kv_cache_batch_size + else: + spec["batch_size"] = kv_cache_batch_size + return {k_: v for k_, v in spec.items() if v is not None} + def compile( self, onnx_path: Optional[str] = None, @@ -3386,7 +3438,7 @@ def compile( num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, - num_speculative_tokens: Optional[int] = None, + num_speculative_tokens: Optional[List[int]] = None, prefill_only: Optional[bool] = None, use_onnx_subfunctions: bool = False, offload_pt_weights: Optional[bool] = True, @@ -3428,9 +3480,12 @@ def compile( Use MXFP6 compression for weights. Default is False. mxint8_kv_cache : bool, optional Use MXINT8 compression for KV cache. Default is False. - num_speculative_tokens : int, optional - Number of speculative tokens for Speculative Decoding Target Language Model. - Required if the model is configured as a Target Language Model (`is_tlm=True`). + num_speculative_tokens : list[int], optional + List of proposal lengths for Speculative Decoding Target Language Model. + Each value K generates a decode specialization with seq_len=K+1 and + num_logits_to_keep=K+1. Include 0 to compile a cheap single-token fallback + (e.g. ``[0, 3]`` for a fallback + full K=3 decode). Required if the model is + configured as a Target Language Model (``is_tlm=True``). prefill_only : bool, optional If True, compiles only for the prefill stage. If False, compiles only for the decode stage. If None, compiles for both stages. Default is None. @@ -3465,10 +3520,10 @@ def compile( TypeError If `prefill_only` is not a boolean. If `full_batch_size` is None when `continuous_batching` is True. - If `num_speculative_tokens` is None when the model is a TLM. + If `num_speculative_tokens` is None or empty when the model is a TLM. ValueError If KV caching is requested without continuous batching (`full_batch_size`). - If `include_sampler` is True and `num_speculative_tokens` is greater than 0. + If `include_sampler` is True and `num_speculative_tokens` contains a value > 0. If `num_speculative_tokens` is not an integer greater than 1. If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. @@ -3533,14 +3588,32 @@ def compile( if prefill_only is not None and not isinstance(prefill_only, bool): raise TypeError("`prefill_only` must be a boolean.") + _decode_ks = ( + sorted(set(num_speculative_tokens)) + if isinstance(num_speculative_tokens, (list, tuple)) + else ([num_speculative_tokens] if num_speculative_tokens is not None else None) + ) + if self.is_tlm: - num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len) + _max_k = _decode_ks[-1] if _decode_ks else None + validated_k = self.check_and_get_num_speculative_tokens(_max_k, prefill_seq_len) + if validated_k is not None and validated_k != _max_k: + # speculative_config in model.config overrides num_speculative_tokens. + # Warn if the user passed a list — the extra values are discarded. + if _decode_ks is not None and len(_decode_ks) > 1: + discarded = [k for k in _decode_ks if k != validated_k] + logger.warning( + f"speculative_config in model.config fixes num_speculative_tokens={validated_k}. " + f"Ignoring user-supplied values {discarded}. " + f"Pass num_speculative_tokens={validated_k} (or [{validated_k}]) to suppress this warning." + ) + _decode_ks = [validated_k] if ( self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False) - and num_speculative_tokens is not None - and num_speculative_tokens > 0 + and _decode_ks is not None + and max(_decode_ks) > 0 ): raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.") @@ -3577,8 +3650,33 @@ def compile( ) if (prefill_only is None or not prefill_only) and prefill_seq_len != 1: - if self.comp_ctx_lengths_decode is not None: - # Adding elements from self.comp_ctx_lengths_decode to decode_specialization + if _decode_ks is not None and self.is_tlm: + # TLM multi-spec path: one decode specialization per K in num_speculative_tokens. + # CCL (comp_ctx_lengths) + multi-spec TLM is not yet supported: the per-K call + # to _build_decode_spec_for_k would need to iterate over CCL values, producing + # len(decode_ks) × len(comp_ctx_lengths_decode) decode specializations whose + # naming and ordering is untested. Reject early so users get a clear error + # instead of a silently wrong QPC. + if self.comp_ctx_lengths_decode is not None: + raise NotImplementedError( + "TLM multi-spec (num_speculative_tokens as a list) combined with " + "comp_ctx_lengths_decode is not yet supported. Pass a plain int for " + "num_speculative_tokens when using CCL." + ) + for k in _decode_ks: + spec = self._build_decode_spec_for_k( + k=k, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + prefill_seq_len=prefill_seq_len, + ) + if spec is not None: + specializations.append(spec) + + elif self.comp_ctx_lengths_decode is not None: + # CCL loop (non-TLM) for i in range(0, len(self.comp_ctx_lengths_decode)): decode_spec = self.build_decode_specialization( prefill_seq_len=prefill_seq_len, @@ -3587,7 +3685,7 @@ def compile( batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, + num_speculative_tokens=None, ) if decode_spec: specializations.append(decode_spec) @@ -3599,7 +3697,7 @@ def compile( batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, + num_speculative_tokens=None, prefill_only=prefill_only, ) if decode_spec: diff --git a/examples/performance/speculative_decoding/prompt_lookup.py b/examples/performance/speculative_decoding/prompt_lookup.py index 53b1f4e851..4a0b0f8141 100644 --- a/examples/performance/speculative_decoding/prompt_lookup.py +++ b/examples/performance/speculative_decoding/prompt_lookup.py @@ -168,7 +168,9 @@ def find_candidate_pred_tokens( if max_ngram_size <= 0 or num_pred_tokens <= 0 or max_ngram_size > input_length: raise ValueError("Invalid max_ngram_size or num_pred_tokens") - has_empty_tokens = False + best_result = np.full(num_pred_tokens, fill_tok, dtype=np.int64) + best_count = 0 + for ngram_size in range(max_ngram_size, 0, -1): # Extract the last n tokens as our search ngram ngram = input_ids[0, -ngram_size:] @@ -182,23 +184,44 @@ def find_candidate_pred_tokens( # Get the indices of matches match_indices = np.where(matches)[0] - # Iterate through match indices to find a valid continuation + # Iterate through match indices to find the longest available continuation for idx in match_indices: start_idx = idx + ngram_size - end_idx = start_idx + num_pred_tokens - # Ensure we don't go beyond the length of input_ids and avoid self-match - if end_idx <= input_length and start_idx < input_length - ngram_size: - return input_ids[0, start_idx:end_idx], has_empty_tokens + # Avoid self-match + if start_idx >= input_length - ngram_size: + continue + + available = min(input_length - start_idx, num_pred_tokens) + if available > best_count: + best_result = np.full(num_pred_tokens, fill_tok, dtype=np.int64) + best_result[:available] = input_ids[0, start_idx : start_idx + available] + best_count = available + if best_count == num_pred_tokens: + return best_result, False # full match found + + # has_empty_tokens is True only when zero proposals were found + return best_result, (best_count == 0) + + +def _select_k(actual_proposals: np.ndarray, decode_ks: List[int]) -> int: + """Return the smallest K in decode_ks that covers the maximum proposal count in the batch. - # If no match is found, return invalid array - has_empty_tokens = True - return np.full(num_pred_tokens, fill_tok, dtype=np.int64), has_empty_tokens + Returns ``decode_ks[-1]`` (max K) when the array is empty — all batch items + have finished generating and no valid proposals remain. + """ + if len(actual_proposals) == 0: + return decode_ks[-1] + need = int(actual_proposals.max()) + for k in decode_ks: + if k >= need: + return k + return decode_ks[-1] def pld_spec_decode_inference( prompts: List[str], - num_speculative_tokens: int, + num_speculative_tokens: Union[int, List[int]], prefill_seq_len: int, ctx_len: int, prefill_bsz: int, @@ -212,7 +235,10 @@ def pld_spec_decode_inference( Args: prompts (List[str]): List of prompts to perform inference on. - num_speculative_tokens (int): Number of speculative tokens. + num_speculative_tokens (Union[int, List[int]]): Number of speculative tokens, or a list + of proposal lengths to compile specializations for. Each value K generates a + specialization with seq_len=K+1. Include 0 for a cheap single-token fallback + (e.g. [0, 3]). A plain int is treated as a single-element list. prefill_seq_len (int): Prefill sequence length. ctx_len (int): Context length. prefill_bsz (int): Prefill batch size. @@ -224,6 +250,10 @@ def pld_spec_decode_inference( Returns: SpDCloudAI100ExecInfo: Execution information, including performance metrics and generated text. """ + decode_ks = ( + sorted(set(num_speculative_tokens)) if isinstance(num_speculative_tokens, list) else [num_speculative_tokens] + ) + max_k = decode_ks[-1] # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size # get vocab size tokenizer = AutoTokenizer.from_pretrained(target_model_name, padding_side="right") @@ -245,7 +275,7 @@ def pld_spec_decode_inference( ctx_len=ctx_len, aic_enable_depth_first=True, full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, + num_speculative_tokens=decode_ks, ) # init qaic session target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group) @@ -278,16 +308,20 @@ def pld_spec_decode_inference( # run prefill on both draft and target models # mock input key "logits" to store the first batch of output logits tlm_precode_inputs = dict( - input_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), - position_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64), + input_ids=np.zeros((decode_batch_size, max_k + 1), dtype=np.int64), + position_ids=np.zeros((decode_batch_size, max_k + 1), dtype=np.int64), batch_index=np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1), - num_logits_to_keep=np.arange(num_speculative_tokens + 1, dtype=np.int64).reshape(-1, 1), + num_logits_to_keep=np.arange(max_k + 1, dtype=np.int64).reshape(-1, 1), ) - num_logits_to_keep = num_speculative_tokens + 1 + num_logits_to_keep = max_k + 1 max_gen_len = [ctx_len] * decode_batch_size # setup buffers tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32) precode_logits_ph = np.zeros((decode_batch_size, num_logits_to_keep, vocab_size), dtype=np.float32) + # Pre-allocate per-K logit buffers for smaller specializations + logit_buffers = { + k: np.zeros((decode_batch_size, k + 1, vocab_size), dtype=np.float32) for k in decode_ks if k != max_k + } target_model_session.set_buffers({"logits": tlm_prefill_logits_ph}) e2e_start = perf_counter() @@ -310,9 +344,7 @@ def pld_spec_decode_inference( generated_ids[bi].append(input_ids.item()) tlm_precode_inputs["input_ids"][bi, 0] = input_ids.item() input_len = prompts_tokenized[bi]["position_ids"].max(1).item() + 1 - tlm_precode_inputs["position_ids"][bi] = np.arange( - input_len, input_len + num_speculative_tokens + 1, dtype=np.int64 - ) + tlm_precode_inputs["position_ids"][bi] = np.arange(input_len, input_len + max_k + 1, dtype=np.int64) # assumes that prefill queue will always be popped from the front input_lengths[bi] = input_len max_gen_len[bi] -= input_lengths[bi] @@ -329,7 +361,7 @@ def pld_spec_decode_inference( decode_start = perf_counter() mean_num_accepted_tokens = 0 all_accept = np.full(decode_batch_size, False, dtype=bool) - tlm_position_ids = np.arange(num_speculative_tokens + 1).reshape(1, -1).repeat(decode_batch_size, axis=0) + tlm_position_ids = np.arange(max_k + 1).reshape(1, -1).repeat(decode_batch_size, axis=0) empty_indices = np.zeros(decode_batch_size, dtype=bool) decode_draft_time = 0.0 decode_target_time = 0.0 @@ -347,28 +379,55 @@ def pld_spec_decode_inference( all_ids[bi : bi + 1, : prompt_plus_gen_idx[bi]], fill_tok=-1, max_ngram_size=max_ngram_size, - num_pred_tokens=num_speculative_tokens, + num_pred_tokens=max_k, ) empty_indices[bi] = has_empty_tokens - # prepare target model inputs + # prepare target model inputs — always write spec_tokens (fill_tok for empty slots) + tlm_precode_inputs["input_ids"][bi, 1:] = spec_tokens if has_empty_tokens: # avoid read/write of KV$ for meaningless tokens tlm_precode_inputs["position_ids"][bi, 1:] = -1 else: - tlm_precode_inputs["input_ids"][bi, 1:] = spec_tokens + # For partial matches: mask position_ids for unfilled proposal slots + fill_mask = spec_tokens == -1 + if fill_mask.any(): + tlm_precode_inputs["position_ids"][bi, 1:][fill_mask] = -1 draft_end = perf_counter() - draft_start decode_draft_time += draft_end # run precode on TLM to score the proposed tokens target_start = perf_counter() - tlm_outputs = target_model_session.run(tlm_precode_inputs) - target_logits = tlm_outputs["logits"] + # Count actual proposal tokens per batch item (fill_tok=-1 marks unfilled positions) + actual_proposals = (tlm_precode_inputs["input_ids"][:, 1:] != -1).sum(axis=1).astype(np.int64) + actual_proposals[~valid_batch_indices] = 0 + selected_k = _select_k(actual_proposals[valid_batch_indices], decode_ks) + if selected_k == max_k: + tlm_outputs = target_model_session.run(tlm_precode_inputs) + target_logits = tlm_outputs["logits"] + else: + sel_inputs = { + "input_ids": tlm_precode_inputs["input_ids"][:, : selected_k + 1], + "position_ids": tlm_precode_inputs["position_ids"][:, : selected_k + 1], + "batch_index": tlm_precode_inputs["batch_index"], + "num_logits_to_keep": np.arange(selected_k + 1, dtype=np.int64).reshape(-1, 1), + } + target_model_session.set_buffers({"logits": logit_buffers[selected_k]}) + try: + tlm_outputs = target_model_session.run(sel_inputs) + raw_logits = tlm_outputs["logits"] # [batch, selected_k+1, vocab] + finally: + # Always restore the max-K placeholder so the next iteration's + # full-K path does not write into an undersized buffer. + target_model_session.set_buffers({"logits": precode_logits_ph}) + # Pad to [batch, max_k+1] so downstream acceptance logic is unchanged + pad = np.zeros((decode_batch_size, max_k - selected_k, vocab_size), dtype=np.float32) + target_logits = np.concatenate([raw_logits, pad], axis=1) # greedy sampling from target model target_tokens = target_logits.argmax(-1) target_end = perf_counter() - target_start decode_target_time += target_end # exact matching between draft and target tokens num_tokens_selected = np.ones(decode_batch_size, dtype=np.int64) - tlm_precode_position_ids = np.full((decode_batch_size, num_speculative_tokens + 1), -1, dtype=np.int64) + tlm_precode_position_ids = np.full((decode_batch_size, max_k + 1), -1, dtype=np.int64) non_empty_valid_indices = ~empty_indices & valid_batch_indices matching = ( tlm_precode_inputs["input_ids"][non_empty_valid_indices, 1:] == target_tokens[non_empty_valid_indices, :-1] @@ -383,7 +442,7 @@ def pld_spec_decode_inference( non_empty_valid_indices ] + num_tokens_selected[non_empty_valid_indices].reshape(-1, 1) # record accepted tokens - all_accept[valid_batch_indices] = num_tokens_selected[valid_batch_indices] == num_speculative_tokens + 1 + all_accept[valid_batch_indices] = num_tokens_selected[valid_batch_indices] == max_k + 1 mean_num_accepted_tokens += num_tokens_selected[valid_batch_indices].mean().item() # append selected tokens to the generated_ids for bi, valid in enumerate(valid_batch_indices): @@ -439,7 +498,7 @@ def pld_spec_decode_inference( batch_decode, generated_ids, perf_metrics, - num_speculative_tokens, + max_k, prefill_seq_len, ctx_len, prefill_bsz, @@ -457,7 +516,12 @@ def comma_separated_ints(x: str): def arg_parse(): parser = ArgumentParser(description="Draft-based SpD Inference") parser.add_argument("--prompts", action="append", default=None, help="Input prompt(s)") - parser.add_argument("--num-speculative-tokens", type=int, default=3, help="Number of speculative tokens") + parser.add_argument( + "--num-speculative-tokens", + type=comma_separated_ints, + default="3", + help="Comma-separated list of proposal lengths (e.g. '0,3' or '3'). Each value K compiles a specialization with seq_len=K+1.", + ) parser.add_argument("--prefill-seq-len", type=int, default=256, help="Prefill sequence length") parser.add_argument("--ctx-len", type=int, default=1024, help="Context length") parser.add_argument("--prefill-bsz", type=int, default=1, help="Prefill batch size") diff --git a/tests/transformers/spd/conftest.py b/tests/transformers/spd/conftest.py new file mode 100644 index 0000000000..2b3f8d7a7c --- /dev/null +++ b/tests/transformers/spd/conftest.py @@ -0,0 +1,34 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Local conftest for SPD hardware tests. + +Patches QAICInferenceSession so that calls without an explicit device_ids +argument default to [4] instead of device 0. Set QAIC_TEST_DEVICE_ID in +the environment to override (e.g. QAIC_TEST_DEVICE_ID=5 pytest ...). +""" + +import os + +import pytest + +_DEVICE_ID = int(os.environ.get("QAIC_TEST_DEVICE_ID", "4")) + + +@pytest.fixture(autouse=True) +def _use_test_device(monkeypatch): + """Redirect all bare QAICInferenceSession() calls to _DEVICE_ID.""" + from QEfficient.generation.cloud_infer import QAICInferenceSession + + _orig_init = QAICInferenceSession.__init__ + + def _patched_init(self, qpc_path, device_ids=None, **kwargs): + if device_ids is None: + device_ids = [_DEVICE_ID] + _orig_init(self, qpc_path, device_ids=device_ids, **kwargs) + + monkeypatch.setattr(QAICInferenceSession, "__init__", _patched_init) diff --git a/tests/transformers/spd/test_pld_inference.py b/tests/transformers/spd/test_pld_inference.py index 28428394c2..cef8886475 100644 --- a/tests/transformers/spd/test_pld_inference.py +++ b/tests/transformers/spd/test_pld_inference.py @@ -477,3 +477,76 @@ def test_dummy_pld_inference(model_id, manual_cleanup): model_config_dict[model_id]["target_model_name"], **model_config_dict[model_id]["additional_params"] ) check_pld_spec_decode_inference(model_id, config=hf_config, manual_cleanup=manual_cleanup) + + +@pytest.mark.parametrize("model_id", test_models_id) +@pytest.mark.parametrize("decode_ks", [[3], [0, 3], [1, 2, 3], [0, 1, 2, 3]]) +def test_multi_spec_structure(model_id, decode_ks): + """ + Verify that _build_decode_spec_for_k produces correct specializations for each K value. + No hardware required. + """ + target_model_name = model_config_dict[model_id]["target_model_name"] + prefill_seq_len = model_config_dict[model_id]["prefill_seq_len"] + ctx_len = model_config_dict[model_id]["ctx_len"] + full_batch_size = model_config_dict[model_id]["full_batch_size"] + continuous_batching = full_batch_size is not None + + target_model = load_qeff_causal_lm_model( + target_model_name, + num_hidden_layers=2, + continuous_batching=continuous_batching, + qaic_config={"speculative_model_type": "target"}, + ) + + kv_cache_batch_size = full_batch_size or 1 + batch_size = 1 + + specs = [] + for k in sorted(set(decode_ks)): + spec = target_model._build_decode_spec_for_k( + k=k, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + prefill_seq_len=prefill_seq_len, + ) + assert spec is not None, f"_build_decode_spec_for_k returned None for k={k}" + assert spec["seq_len"] == k + 1, f"Expected seq_len={k + 1}, got {spec['seq_len']}" + assert spec["num_logits_to_keep"] == k + 1, ( + f"Expected num_logits_to_keep={k + 1}, got {spec['num_logits_to_keep']}" + ) + assert spec["ctx_len"] == ctx_len + specs.append(spec) + + seq_lens = [s["seq_len"] for s in specs] + assert len(seq_lens) == len(set(seq_lens)), f"Duplicate seq_len values in specs: {seq_lens}" + + +# --------------------------------------------------------------------------- +# _select_k dispatch helper tests (no hardware required) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "actual_proposals,decode_ks,expected_k", + [ + # All batch items have 0 proposals → smallest k >= 0 + (np.array([0, 0, 0]), [0, 3], 0), + # Mix: some proposals, some not → smallest k >= max=3 + (np.array([0, 3, 3]), [0, 3], 3), + # Single spec: always returns only option + (np.array([0, 0]), [3], 3), + # need=2, ks=[0,2,4] → returns 2 + (np.array([1, 2]), [0, 2, 4], 2), + # need exceeds all → returns max + (np.array([5, 5]), [0, 3], 3), + ], +) +def test_select_k(actual_proposals, decode_ks, expected_k): + """_select_k returns the smallest K in decode_ks covering the max actual proposal count.""" + from examples.performance.speculative_decoding.prompt_lookup import _select_k + + result = _select_k(actual_proposals, decode_ks) + assert result == expected_k, f"Expected {expected_k}, got {result} for proposals={actual_proposals}, ks={decode_ks}" diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py index a79f17d556..feb0153e3c 100644 --- a/tests/transformers/spd/test_spd_inference.py +++ b/tests/transformers/spd/test_spd_inference.py @@ -375,3 +375,213 @@ def test_dummy_spd_inference(model_id, manual_cleanup): **model_config_dict[model_id]["additional_params"], ) check_spec_decode_inference(model_id, config=hf_config, manual_cleanup=manual_cleanup) + + +# --------------------------------------------------------------------------- +# Multi-spec logit correctness — hardware-level QPC test +# --------------------------------------------------------------------------- + +_MULTI_SPEC_MODEL = "JackFram/llama-68m" +_MULTI_SPEC_NUM_LAYERS = 2 +_MULTI_SPEC_PREFILL_LEN = 32 +_MULTI_SPEC_CTX_LEN = 128 +_MULTI_SPEC_N_STEPS = 8 # decode positions to verify per specialisation +_MULTI_SPEC_PROMPT = "My name is" + + +def _run_prefill(session, tokenized, vocab_size, num_logits_to_keep=None): + """Run chunked prefill and return the logit from the last chunk.""" + inputs = dict(tokenized) + if num_logits_to_keep is not None: + inputs["num_logits_to_keep"] = num_logits_to_keep + ph = np.zeros((1, 1, vocab_size), dtype=np.float32) + session.set_buffers({"logits": ph}) + out = session.run(inputs) + return out["logits"][0, 0, :] # [vocab] + + +def _collect_vanilla_reference(session, first_token, start_pos, vocab_size, n_steps): + """ + Teacher-forced decode: feed ground-truth tokens one at a time and collect + (logit, next_token) at each position. + + Returns: + ref_tokens : list[int] – tokens[i] is fed at position start_pos+i + ref_logits : list[ndarray] – ref_logits[i] is the logit produced after + feeding tokens[i], shape [vocab] + """ + ref_tokens = [int(first_token)] + ref_logits = [] + ph = np.zeros((1, 1, vocab_size), dtype=np.float32) + session.set_buffers({"logits": ph}) + for step in range(n_steps): + out = session.run( + { + "input_ids": np.array([[ref_tokens[-1]]], dtype=np.int64), + "position_ids": np.array([[start_pos + step]], dtype=np.int64), + } + ) + logit = out["logits"][0, 0, :].copy() + ref_logits.append(logit) + ref_tokens.append(int(logit.argmax())) + return ref_tokens, ref_logits + + +def _verify_tlm_spec(tlm_session, k, ref_tokens, ref_logits, start_pos, vocab_size): + """ + Run TLM with seq_len=k+1 (teacher-forced in chunks) and assert that every + output logit matches the corresponding vanilla reference logit. + + Both the accepted-token (argmax) and the full logit vector (atol=5e-2) are + checked. Chunks are non-overlapping; leftover positions at the end are skipped. + + Returns the number of (position, specialisation) pairs that were asserted. + """ + seq_len = k + 1 + n_logits_to_keep = np.arange(seq_len, dtype=np.int64).reshape(-1, 1) + ph = np.zeros((1, seq_len, vocab_size), dtype=np.float32) + tlm_session.set_buffers({"logits": ph}) + + n_assertions = 0 + n_chunks = len(ref_logits) // seq_len + for chunk in range(n_chunks): + chunk_tokens = ref_tokens[chunk * seq_len : chunk * seq_len + seq_len] + chunk_positions = np.array([[start_pos + chunk * seq_len + i for i in range(seq_len)]], dtype=np.int64) + out = tlm_session.run( + { + "input_ids": np.array([chunk_tokens], dtype=np.int64), + "position_ids": chunk_positions, + "num_logits_to_keep": n_logits_to_keep, + } + ) + tlm_logits = out["logits"] # [1, seq_len, vocab] + + for i in range(seq_len): + ref_pos = chunk * seq_len + i + ref_logit = ref_logits[ref_pos] + tlm_logit = tlm_logits[0, i, :] + + assert np.allclose(tlm_logit, ref_logit, atol=5e-2), ( + f"K={k}, chunk={chunk}, offset={i} (abs pos {start_pos + ref_pos}): " + f"logit mismatch — max_diff={np.abs(tlm_logit - ref_logit).max():.3e}" + ) + assert int(tlm_logit.argmax()) == int(ref_logit.argmax()), ( + f"K={k}, chunk={chunk}, offset={i} (abs pos {start_pos + ref_pos}): " + f"accepted-token mismatch — " + f"TLM={int(tlm_logit.argmax())} vs ref={int(ref_logit.argmax())}" + ) + n_assertions += 1 + + return n_assertions + + +@pytest.mark.on_qaic +@pytest.mark.feature +@pytest.mark.parametrize( + "decode_ks", + [ + [0], # fallback-only (seq_len=1) + [3], # full-K only (seq_len=4) + [0, 3], # fallback + full-K (typical PLD config) + [1, 2, 3], # suffix-decoding range + ], +) +def test_multi_spec_qpc_logit_correctness(decode_ks, manual_cleanup): + """ + Verify that every decode specialisation in `decode_ks` produces logits that + match the vanilla (DLM) reference at every token position, for ALL output + positions of each specialisation. + + Strategy + -------- + 1. Compile vanilla model (seq_len=1 decode) → collect ref_logits[pos] at each + of _MULTI_SPEC_N_STEPS positions using teacher-forcing with greedy outputs. + 2. Compile TLM with all K values in decode_ks. + 3. For each K: fresh TLM session (resets KV cache) → prefill same prompt → + teacher-forced decode in non-overlapping K+1 chunks → assert ALL K+1 output + logits per chunk match ref_logits at corresponding positions. + + Scaling: adding K values to decode_ks automatically adds new assertions. + """ + tokenizer = AutoTokenizer.from_pretrained(_MULTI_SPEC_MODEL, padding_side="right") + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + vocab_size = len(tokenizer) + + # Tokenise prompt (padded to prefill_seq_len) + raw = tokenizer(_MULTI_SPEC_PROMPT, return_tensors="np") + input_len = int(raw.input_ids.shape[1]) + pad_len = _MULTI_SPEC_PREFILL_LEN + tokenized = tokenizer( + _MULTI_SPEC_PROMPT, + return_tensors="np", + padding="max_length", + max_length=pad_len, + ) + position_ids = np.where( + tokenized.pop("attention_mask"), + np.arange(pad_len), + -1, + ) + prefill_inputs = { + "input_ids": tokenized["input_ids"], + "position_ids": position_ids, + } + + # ── 1. Compile vanilla (DLM) model ────────────────────────────────────── + vanilla = load_qeff_causal_lm_model(_MULTI_SPEC_MODEL, num_hidden_layers=_MULTI_SPEC_NUM_LAYERS) + vanilla_qpc = vanilla.compile( + num_cores=2, + prefill_seq_len=_MULTI_SPEC_PREFILL_LEN, + ctx_len=_MULTI_SPEC_CTX_LEN, + aic_enable_depth_first=True, + ) + + # ── 2. Compile TLM with all decode specialisations ─────────────────────── + tlm = load_qeff_causal_lm_model( + _MULTI_SPEC_MODEL, + num_hidden_layers=_MULTI_SPEC_NUM_LAYERS, + qaic_config={"speculative_model_type": "target"}, + ) + tlm_qpc = tlm.compile( + num_cores=2, + prefill_seq_len=_MULTI_SPEC_PREFILL_LEN, + ctx_len=_MULTI_SPEC_CTX_LEN, + aic_enable_depth_first=True, + num_speculative_tokens=decode_ks, + ) + + # ── 3. Collect vanilla reference logits ────────────────────────────────── + van_session = QAICInferenceSession(vanilla_qpc) + van_session.skip_buffers([x for x in van_session.input_names if x.startswith("past_")]) + van_session.skip_buffers([x for x in van_session.output_names if x.endswith("_RetainedState")]) + + prefill_logit = _run_prefill(van_session, prefill_inputs, vocab_size) + first_token = int(prefill_logit.argmax()) + ref_tokens, ref_logits = _collect_vanilla_reference( + van_session, first_token, input_len, vocab_size, _MULTI_SPEC_N_STEPS + ) + assert len(ref_logits) == _MULTI_SPEC_N_STEPS + + # ── 4. Verify each specialisation ──────────────────────────────────────── + total_assertions = 0 + for k in sorted(set(decode_ks)): + # Fresh TLM session for each K (resets retained KV state) + tlm_session = QAICInferenceSession(tlm_qpc) + tlm_session.skip_buffers([x for x in tlm_session.input_names if x.startswith("past_")]) + tlm_session.skip_buffers([x for x in tlm_session.output_names if x.endswith("_RetainedState")]) + + # Prefill TLM (num_logits_to_keep=[[1]]) + _run_prefill( + tlm_session, + prefill_inputs, + vocab_size, + num_logits_to_keep=np.ones((1, 1), dtype=np.int64), + ) + + n = _verify_tlm_spec(tlm_session, k, ref_tokens, ref_logits, input_len, vocab_size) + assert n > 0, f"K={k}: no positions were verified — check _MULTI_SPEC_N_STEPS vs seq_len" + total_assertions += n + + assert total_assertions > 0 + manual_cleanup([vanilla.onnx_path, tlm.onnx_path]) diff --git a/tests/unit_test/models/test_modeling_auto_cpu.py b/tests/unit_test/models/test_modeling_auto_cpu.py index 06302b2309..af5f8d2f09 100644 --- a/tests/unit_test/models/test_modeling_auto_cpu.py +++ b/tests/unit_test/models/test_modeling_auto_cpu.py @@ -980,5 +980,155 @@ def test_export_onnx_has_logits_output(self, tmp_export_dir): qeff = QEFFAutoModelForCTC(model) onnx_path = qeff.export(export_dir=str(tmp_export_dir)) onnx_model = onnx.load(str(onnx_path)) - output_names = {out.name for out in onnx_model.graph.output} - assert "logits" in output_names + output_names_ctc = {out.name for out in onnx_model.graph.output} + assert "logits" in output_names_ctc + + +# --------------------------------------------------------------------------- +# TLM multi-spec specialization unit tests +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +@pytest.mark.causal_lm +class TestTLMMultiSpecSpecializations: + """Tests for the multi-spec decode specialization API (num_speculative_tokens as list).""" + + # ---- _build_decode_spec_for_k ---- + + def test_build_decode_spec_for_k_seq_len(self): + """_build_decode_spec_for_k sets seq_len = k+1.""" + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model) + qeff.is_tlm = True + for k in [0, 1, 3, 7]: + spec = qeff._build_decode_spec_for_k( + k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32 + ) + assert spec is not None + assert spec["seq_len"] == k + 1 + + def test_build_decode_spec_for_k_num_logits_to_keep(self): + """_build_decode_spec_for_k sets num_logits_to_keep = k+1.""" + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model) + qeff.is_tlm = True + for k in [0, 1, 3]: + spec = qeff._build_decode_spec_for_k( + k=k, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=32 + ) + assert spec["num_logits_to_keep"] == k + 1 + + def test_build_decode_spec_for_k_returns_none_when_duplicate_prefill(self): + """Returns None when seq_len == prefill_seq_len and no continuous batching.""" + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model) + qeff.is_tlm = True + # k=0 → seq_len=1 == prefill_seq_len=1 → should be None + spec = qeff._build_decode_spec_for_k(k=0, ctx_len=128, batch_size=1, kv_cache_batch_size=1, prefill_seq_len=1) + assert spec is None + + def test_build_decode_spec_for_k_not_none_with_continuous_batching(self): + """Returns spec even when seq_len == prefill_seq_len if continuous_batching is enabled.""" + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + qeff.is_tlm = True + # k=0 → seq_len=1 == prefill_seq_len=1, but CB is True → should not be None + spec = qeff._build_decode_spec_for_k( + k=0, ctx_len=128, batch_size=1, kv_cache_batch_size=2, full_batch_size=2, prefill_seq_len=1 + ) + assert spec is not None + + # ---- compile() specialization count via mock ---- + + def test_compile_list_produces_correct_spec_count(self): + """compile(num_speculative_tokens=[0, 3]) → 1 prefill + 2 decode specializations.""" + from unittest.mock import patch + + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + captured = {} + + with patch.object( + type(qeff), + "_compile", + side_effect=lambda *args, **kw: ( + captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc" + ), + ): + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[0, 3]) + + assert captured.get("specializations") is not None, "_compile was not reached" + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 2, f"Expected 2 decode specs, got {len(decode_specs)}: {specs}" + + def test_compile_deduplication(self): + """compile(num_speculative_tokens=[3, 3, 3]) → only one decode spec for K=3.""" + from unittest.mock import patch + + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + captured = {} + + with patch.object( + type(qeff), + "_compile", + side_effect=lambda *args, **kw: ( + captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc" + ), + ): + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 3, 3]) + + assert captured.get("specializations") is not None, "_compile was not reached" + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 1, f"Expected 1 decode spec (deduplicated), got: {decode_specs}" + assert decode_specs[0]["seq_len"] == 4 + + def test_compile_sorting(self): + """compile(num_speculative_tokens=[3, 1, 2]) → decode specs in ascending seq_len order.""" + from unittest.mock import patch + + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + captured = {} + + with patch.object( + type(qeff), + "_compile", + side_effect=lambda *args, **kw: ( + captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc" + ), + ): + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=[3, 1, 2]) + + assert captured.get("specializations") is not None, "_compile was not reached" + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 3 + seq_lens = [s["seq_len"] for s in decode_specs] + assert seq_lens == sorted(seq_lens), f"Decode specs not in sorted order: {seq_lens}" + + def test_compile_int_backward_compat(self): + """compile(num_speculative_tokens=3) as plain int still works (treated as [3]).""" + from unittest.mock import patch + + model, _ = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + captured = {} + + with patch.object( + type(qeff), + "_compile", + side_effect=lambda *args, **kw: ( + captured.update({"specializations": kw.get("specializations")}) or "/fake/qpc" + ), + ): + qeff.compile(prefill_seq_len=32, ctx_len=128, num_speculative_tokens=3) + + assert captured.get("specializations") is not None, "_compile was not reached" + specs = captured["specializations"] + decode_specs = [s for s in specs if s.get("seq_len", 0) != 32] + assert len(decode_specs) == 1, f"Expected 1 decode spec for int input, got: {decode_specs}" + assert decode_specs[0]["seq_len"] == 4 # k=3 → seq_len=4 diff --git a/tests/unit_test/transforms/test_speculative_decoding.py b/tests/unit_test/transforms/test_speculative_decoding.py index cdffb7c46a..3c33fbaec9 100644 --- a/tests/unit_test/transforms/test_speculative_decoding.py +++ b/tests/unit_test/transforms/test_speculative_decoding.py @@ -380,6 +380,75 @@ def test_tlm_forward_greedy_tokens_in_valid_range(self): assert (greedy_tokens >= 0).all() assert (greedy_tokens < VOCAB_SIZE).all() + @pytest.mark.parametrize("num_spec_tokens", [1, 2, 3, 5]) + def test_tlm_multi_spec_logit_consistency(self, num_spec_tokens): + """ + The anchor-token logit from seq_len=1 must equal the anchor-token logit at + position 0 from seq_len=K+1 — for the same input and standard causal attention. + + This is the core correctness guarantee for multi-spec dispatch on QAIC hardware. + + We test this using the raw HuggingFace LlamaForCausalLM (no QEffDynamicCache) + because the eager-mode QEffDynamicCache simulation uses max(position_ids) as the + KV gather limit, which exposes speculative positions to the anchor query and breaks + the property in Python. On QAIC hardware, per-query causal masking is applied + correctly by the hardware attention kernel — the property is verified empirically + by test_few_spd_inference, which asserts mean_num_accepted_tokens == K+1 + (100% acceptance rate when TLM == DLM). + + Why it holds: Standard causal attention masks position P from seeing positions + P+1..P+K, so the hidden state at P is identical regardless of what follows it. + SpDTransform's filter_hidden_states extracts this hidden state at index 0 of the + K+1 output, so the accepted token is always the same. + """ + from transformers import LlamaConfig, LlamaForCausalLM + + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + raw_model = LlamaForCausalLM(cfg).eval() + + batch = 1 + anchor_token = torch.randint(0, VOCAB_SIZE, (batch, 1)) + anchor_pos = torch.tensor([[0]], dtype=torch.long) # start of sequence, no past + + # ── seq_len=1: just the anchor ─────────────────────────────────────────────── + with torch.no_grad(): + out_k0 = raw_model( + input_ids=anchor_token, + position_ids=anchor_pos, + ) + logit_k0 = out_k0.logits[:, 0:1, :] # [batch, 1, vocab] + + # ── seq_len=K+1: anchor at position 0, K random speculative tokens ────────── + spec_tokens = torch.randint(0, VOCAB_SIZE, (batch, num_spec_tokens)) + full_input_ids = torch.cat([anchor_token, spec_tokens], dim=1) + full_pos_ids = torch.arange(num_spec_tokens + 1).unsqueeze(0).expand(batch, -1) + + with torch.no_grad(): + out_kK = raw_model( + input_ids=full_input_ids, + position_ids=full_pos_ids, + ) + logit_kK_anchor = out_kK.logits[:, 0:1, :] # anchor is at index 0 + + # The anchor logit must be numerically identical regardless of K + assert torch.allclose(logit_k0, logit_kK_anchor, atol=1e-5), ( + f"Causal property violated: anchor logit differs between seq_len=1 and " + f"seq_len={num_spec_tokens + 1}: " + f"max_diff={(logit_k0 - logit_kK_anchor).abs().max().item():.2e}" + ) + # Accepted token (greedy argmax) must also be identical + assert logit_k0.argmax(dim=-1).eq(logit_kK_anchor.argmax(dim=-1)).all(), ( + "Accepted token differs between seq_len=1 and seq_len=K+1 — causal property violated in raw model" + ) + # --------------------------------------------------------------------------- # Tests: SpDTransform for Qwen2