diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 811b3f84d5..347b466ed0 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -289,10 +289,14 @@ def forward( sin = self.sin_cached[position_ids].unsqueeze(1) cos = self.cos_cached[position_ids].unsqueeze(1) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + self.target_layer_ids = getattr(self, "target_layer_ids", None) + target_hidden_list = [] + + for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): if output_hidden_states: all_hidden_states += (hidden_states,) - + if self.target_layer_ids and idx in self.target_layer_ids: + target_hidden_list.append(hidden_states) hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, @@ -316,6 +320,16 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() + if self.target_layer_ids: + target_hidden = torch.cat(target_hidden_list, dim=-1) + target_hidden_fc = self.fc(target_hidden) + target_hidden_final = self.hidden_norm(target_hidden_fc) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=target_hidden_final, + ) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, @@ -354,6 +368,10 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + + if getattr(self.model, "target_layer_ids", None): + output_hidden_states = False + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -371,6 +389,20 @@ def forward( # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + + if getattr(self.model, "target_layer_ids", None): + target_hidden = outputs.hidden_states + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + predicted_token_ids = logits.argmax(dim=-1).to(torch.int32) + return CausalLMOutputWithPast( + loss=None, + logits=predicted_token_ids, + past_key_values=outputs.past_key_values, + hidden_states=target_hidden, + attentions=outputs.attentions, + ) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] logits = self.lm_head(hidden_states).float() diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2668be8a1e..c49ce7125b 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -47,6 +47,8 @@ ) from QEfficient.transformers.models.pytorch_transforms import ( CustomOpsTransform, + DFlashTLMTransform, + DFlashTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, @@ -2878,6 +2880,20 @@ def __init__( if self.is_tlm: self.model.qaic_config["return_pdfs"] = True + self.dflash_dlm = False + self.hidden_size = self.model.config.hidden_size + self.vocab_size = self.model.config.vocab_size + if qaic_config is not None: + self.dflash_dlm = qaic_config.get("dflash_dlm", False) + if self.dflash_dlm: + self.model, _ = DFlashTransform.apply(self.model, qaic_config) + + self.dflash_tlm = False + if qaic_config is not None: + self.dflash_tlm = bool(qaic_config.get("target_layer_ids", None)) + if self.dflash_tlm: + self.model, _ = DFlashTLMTransform.apply(self.model, qaic_config) + def __repr__(self) -> str: return self.__class__.__name__ + "\n" + self.model.__repr__() @@ -3079,7 +3095,7 @@ def export( fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( - self.model.config, fbs if self.continuous_batching else bs, seq_len + self.model.config, fbs if self.continuous_batching else bs, seq_len * 2 ) enable_chunking = kwargs.get("enable_chunking", False) if ( @@ -3133,6 +3149,22 @@ def export( "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, } + + if self.dflash_dlm: + example_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "target_hidden": torch.ones((bs, seq_len, self.hidden_size), dtype=torch.float), + "position_ids": torch.arange(seq_len, 2 * seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), + "position_ids_target": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), + "past_key_values": [[] for _ in range(self.num_layers)], + } + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "target_hidden": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids_target": {0: "batch_size", 1: "seq_len"}, + } + if self.ccl_enabled: example_inputs["comp_ctx_lengths"] = torch.randint(0, 127, (512,), dtype=torch.int8) dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} @@ -3250,6 +3282,9 @@ def export( qaic_config=self.model.qaic_config, ) + if self.dflash_tlm: + output_names.append("hidden_states") + return self._export( example_inputs, output_names=output_names, @@ -3334,6 +3369,7 @@ def build_decode_specialization( kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, num_speculative_tokens: Optional[int] = None, + dflash_block_size: Optional[int] = None, **kwargs, ): """ @@ -3377,6 +3413,9 @@ def build_decode_specialization( spec["num_logits_to_keep"] = (num_speculative_tokens + 1) if self.is_tlm else None + if self.dflash_tlm or self.dflash_dlm: + spec["seq_len"] = dflash_block_size + if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -3397,6 +3436,7 @@ def compile( batch_size: int = 1, full_batch_size: Optional[int] = None, kv_cache_batch_size: Optional[int] = None, + dflash_block_size: Optional[int] = None, num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, @@ -3612,6 +3652,7 @@ def compile( kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, + dflash_block_size=dflash_block_size, ) if decode_spec: specializations.append(decode_spec) @@ -3624,6 +3665,7 @@ def compile( kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, + dflash_block_size=dflash_block_size, prefill_only=prefill_only, ) if decode_spec: diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index ec34ebb046..6bb293f6a0 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -458,6 +458,18 @@ QEffQwen3ForCausalLM, QEffQwen3Model, ) +from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import ( + QEffQwen3Attention as QEffQwen3DFlashAttention, +) +from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import ( + QEffQwen3DecoderLayer as QEffQwen3DFlashDecoderLayer, +) +from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import ( + QEffQwen3ForCausalLM as QEffQwen3DFlashForCausalLM, +) +from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import ( + QEffQwen3Model as QEffQwen3DFlashModel, +) from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( QEffPrefillChunkedQwen3MoeSparseMoeBlock, QEffQwen3MoeAttention, @@ -954,6 +966,138 @@ def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) - return model, transformed +class DFlashTransform(ModuleMappingTransform): + """ + Replaces standard QEff Qwen3 modules with DFlash-specialized versions when dflash_dlm=True. + Applied after KVCacheTransform so it operates on already-transformed QEff modules. + """ + + _module_mapping = { + QEffQwen3Attention: QEffQwen3DFlashAttention, + QEffQwen3DecoderLayer: QEffQwen3DFlashDecoderLayer, + QEffQwen3Model: QEffQwen3DFlashModel, + QEffQwen3ForCausalLM: QEffQwen3DFlashForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: + if not (qaic_config and qaic_config.get("dflash_dlm", False)): + return model, False + return super().apply(model) + + +class DFlashTLMTransform: + """ + Adds fc and hidden_norm layers to the inner model, loads their weights from the TLM + checkpoint, and sets target_layer_ids — activating the TLM hidden-state collection + path in QEffQwen3Model / QEffLlamaModel forward. + + Triggered when target_layer_ids is present in qaic_config. + Works without modifying the transformers library: weights that were silently dropped + as unexpected keys during from_pretrained are re-read here from the checkpoint. + """ + + @classmethod + def _load_tlm_weights(cls, checkpoint_path: str) -> dict: + """Read only fc and hidden_norm weights from a local checkpoint directory.""" + from pathlib import Path + + import torch + + path = Path(checkpoint_path) + target_keys = {"model.fc.weight", "model.hidden_norm.weight"} + dlm_model_weights: dict = {} + + # safetensors (preferred modern format) + sf_files = sorted(path.glob("*.safetensors")) + if sf_files: + try: + from safetensors import safe_open + + for sf in sf_files: + with safe_open(str(sf), framework="pt", device="cpu") as f: + for key in f.keys(): + if key in target_keys and key not in dlm_model_weights: + dlm_model_weights[key] = f.get_tensor(key) + if len(dlm_model_weights) == len(target_keys): + break + return dlm_model_weights + except ImportError: + warnings.warn("safetensors not installed; falling back to .bin loading.") + + # pytorch .bin fallback + bin_files = sorted(path.glob("pytorch_model*.bin")) + for bf in bin_files: + sd = torch.load(str(bf), map_location="cpu", weights_only=True) + for key in target_keys: + if key in sd and key not in dlm_model_weights: + dlm_model_weights[key] = sd[key] + if len(dlm_model_weights) == len(target_keys): + break + + return dlm_model_weights + + @classmethod + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: + target_layer_ids = qaic_config.get("target_layer_ids", None) if qaic_config else None + if not target_layer_ids: + return model, False + + n = len(target_layer_ids) + inner = model.model # QEffQwen3Model or QEffLlamaModel + hidden_size = model.config.hidden_size + model_type = getattr(model.config, "model_type", "") + + # --- add fc and hidden_norm only if not already present --- + # tlm_maker.py (and similar scripts) inject weights before calling the + # constructor directly; in that case we must not overwrite them. + layers_already_present = hasattr(inner, "fc") and hasattr(inner, "hidden_norm") + + if not layers_already_present: + # add fc + inner.fc = nn.Linear(n * hidden_size, hidden_size, bias=False) + + # add hidden_norm using the same RMSNorm class as the model + if "qwen3" in model_type: + from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm + + inner.hidden_norm = Qwen3RMSNorm(hidden_size, eps=model.config.rms_norm_eps) + elif "llama" in model_type: + from transformers.models.llama.modeling_llama import LlamaRMSNorm + + inner.hidden_norm = LlamaRMSNorm(hidden_size, eps=model.config.rms_norm_eps) + else: + warnings.warn( + f"DFlashTLMTransform: unknown model_type '{model_type}'. Using nn.RMSNorm as hidden_norm fallback." + ) + inner.hidden_norm = nn.RMSNorm(hidden_size, eps=getattr(model.config, "rms_norm_eps", 1e-6)) + + # load fc / hidden_norm weights from checkpoint + # pretrained_model_name_or_path is stored in qaic_config by from_pretrained + ckpt_path = (qaic_config or {}).get("pretrained_model_name_or_path", None) + if ckpt_path: + weights = cls._load_tlm_weights(ckpt_path) + if "model.fc.weight" in weights: + inner.fc.weight.data.copy_(weights["model.fc.weight"]) + if "model.hidden_norm.weight" in weights: + inner.hidden_norm.weight.data.copy_(weights["model.hidden_norm.weight"]) + if not weights: + warnings.warn( + "DFlashTLMTransform: fc/hidden_norm weights not found in checkpoint " + f"at '{ckpt_path}'. Layers are randomly initialized." + ) + else: + warnings.warn( + "DFlashTLMTransform: no checkpoint path available — " + "use QEFFAutoModelForCausalLM.from_pretrained() so the path is " + "stored automatically, or set qaic_config['pretrained_model_name_or_path']." + ) + + # --- activate TLM collection path in forward --- + inner.target_layer_ids = target_layer_ids + return model, True + + class VlmKVOffloadTransform(ModuleMappingTransform): # supported architectures _module_mapping = { diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index 9844c91016..c2b0ab668d 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -324,10 +324,16 @@ def forward( sin = self.sin_cached[position_ids].unsqueeze(1) cos = self.cos_cached[position_ids].unsqueeze(1) - for decoder_layer in self.layers: + self.target_layer_ids = getattr(self, "target_layer_ids", None) + target_hidden_list = [] + + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) + if self.target_layer_ids and idx in self.target_layer_ids: + target_hidden_list.append(hidden_states) + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, @@ -350,6 +356,16 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() + if self.target_layer_ids: + target_hidden = torch.cat(target_hidden_list, dim=-1) + target_hidden_fc = self.fc(target_hidden) + target_hidden_final = self.hidden_norm(target_hidden_fc) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=target_hidden_final, + ) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -393,6 +409,9 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + if getattr(self.model, "target_layer_ids", None): + output_hidden_states = False + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -408,6 +427,20 @@ def forward( # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + + if getattr(self.model, "target_layer_ids", None): + target_hidden = outputs.hidden_states + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + predicted_token_ids = logits.argmax(dim=-1).to(torch.int32) + return CausalLMOutputWithPast( + loss=None, + logits=predicted_token_ids, + past_key_values=outputs.past_key_values, + hidden_states=target_hidden, + attentions=outputs.attentions, + ) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] logits = self.lm_head(hidden_states).float() diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py b/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py new file mode 100644 index 0000000000..11801490c1 --- /dev/null +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py @@ -0,0 +1,547 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""PyTorch Qwen3 model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3Attention, + Qwen3Config, + Qwen3DecoderLayer, + Qwen3ForCausalLM, + Qwen3Model, + Qwen3RotaryEmbedding, + repeat_kv, + rotate_half, +) + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE + + +def _create_mask( + position_ids: torch.Tensor, + target_length: int, + # valid_kv_length: int, + sliding_window: Optional[int] = None, + start_index: Optional[int] = 0, +): + """ + Args: + position_ids: [1, target_length] + target_length: 3 * block_size + valid_kv_length: number of valid KV cache positions + sliding_window: optional local attention window + start_index: offset into KV cache (default 0) + + Returns: + attention_mask: [1, 1, num_queries, target_length] + """ + + device = position_ids.device + + num_queries = position_ids.shape[1] # = block_size (B) + + # ---- Step 1: Create base KV validity mask ---- + # Shape: [target_length] + + kv_positions = torch.arange(start_index, start_index + target_length, device=device) + valid_kv_mask = kv_positions > ( + start_index + position_ids.max() + ) # [position_max is 32( input query podtion_ids), 0 to 31] + + # ---- Step 2: Expand to [num_queries, target_length] ---- + attention_mask = valid_kv_mask.unsqueeze(0).expand(num_queries, target_length) + + # ---- Step 4: Add batch & head dimensions ---- + # Final shape: [1, 1, B, 3B] + attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) + + return attention_mask + + +# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology +class QEffQwen3RotaryEmbedding(Qwen3RotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: Qwen3Config, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + # print_stats(emb, "RotaryEmbedding/emb") + cos_cached = emb.cos().to(dtype) + sin_cached = emb.sin().to(dtype) + + self.register_buffer("cos_cached", cos_cached, persistent=False) + self.register_buffer("sin_cached", sin_cached, persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + cos_out = self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling + sin_out = self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling + + return (cos_out, sin_out) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_len = q.size(-2) + + q_embed = (q * cos[..., :q_len, :]) + (rotate_half(q) * sin[..., :q_len, :]) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def qeff_apply_rope_two_streams(q_noise, k_ctx, k_noise, cos, sin, pos_ctx, pos_noise, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos_t = cos[pos_ctx].unsqueeze(unsqueeze_dim) + sin_t = sin[pos_ctx].unsqueeze(unsqueeze_dim) + + rotate_half_k_ctx = rotate_half(k_ctx) + + k_t_embed = k_ctx * cos_t + rotate_half_k_ctx * sin_t + + # ---- NOISE ---- + cos_n = cos[pos_noise].unsqueeze(unsqueeze_dim) + sin_n = sin[pos_noise].unsqueeze(unsqueeze_dim) + + rotate_half_q_noise = rotate_half(q_noise) + q_n_embed = q_noise * cos_n + rotate_half_q_noise * sin_n + + rotate_half_k_noise = rotate_half(k_noise) + + k_n_embed = k_noise * cos_n + rotate_half_k_noise * sin_n + + return q_n_embed, k_t_embed, k_n_embed + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class QEffQwen3Attention(Qwen3Attention): + """ + Copied from Qwen3Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + """ + + def __qeff_init__(self): + self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config) + self.dflash_dlm = True + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor, + position_ids_target: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + bsz, q_len = hidden_states.shape[:-1] + ctx_len = target_hidden.shape[1] + + kwargs.pop("output_attentions", None) + kwargs.pop("return_dict", None) + kwargs.pop("labels", None) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + + k_ctx = self.k_norm(k_ctx.view(bsz, ctx_len, -1, self.head_dim)).transpose(1, 2) + k_noise = self.k_norm(k_noise.view(bsz, q_len, -1, self.head_dim)).transpose(1, 2) + + v_ctx = (v_ctx.view(bsz, ctx_len, -1, self.head_dim)).transpose(1, 2) + v_noise = (v_noise.view(bsz, q_len, -1, self.head_dim)).transpose(1, 2) + + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # Assuming position_id [77,78,79,80, 75,76,-1,-1] first 4 pos id of noise next four position_id for target + + cos, sin = self.rotary_emb(v_ctx, seq_len=kv_seq_len) + query_states, k_ctx, k_noise = qeff_apply_rope_two_streams( + query_states, k_ctx, k_noise, cos, sin, position_ids_target, position_ids + ) + + if past_key_value is not None: + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids_target} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + + # first write for target positon_id + past_key_value.write_only(k_ctx, v_ctx, self.layer_idx, cache_kwargs) + + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(k_noise, v_noise, self.layer_idx, cache_kwargs) + + attention_interface = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class QEffQwen3DecoderLayer(Qwen3DecoderLayer): + """ + Copied from Qwen3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update the hidden_states, and fix for onnx model + """ + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor = None, + position_ids_target: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + target_hidden=target_hidden, + position_ids_target=position_ids_target, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class QEffQwen3Model(Qwen3Model): + """ + Copied from Qwen3Model: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update causal attention mask + """ + + def forward( + self, + target_hidden: torch.Tensor = None, + noise_embeds: torch.FloatTensor = None, + position_ids_target: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (noise_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if input_ids is not None and noise_embeds is None: + noise_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: ####? + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + noise_embeds.shape[1], device=noise_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze( + 0 + ) ###? no need for this because we input it ( where the tokens will be filled ) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) + + hidden_states = noise_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = decoder_layer( + hidden_states, + target_hidden=target_hidden, + position_ids_target=position_ids_target, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class QEffQwen3ForCausalLM(Qwen3ForCausalLM): + """ + Copied from Qwen3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update the hidden_states, and fix for onnx model + """ + + def forward( + self, + target_hidden: torch.Tensor = None, + noise_embeds: torch.FloatTensor = None, + position_ids_target: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # block_size: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + target_hidden=target_hidden, + noise_embeds=noise_embeds, + position_ids_target=position_ids_target, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + # block_size=block_size, + output_hidden_states=output_hidden_states, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 339e4f4dac..58cef2caa9 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -14,7 +14,7 @@ QEFF_CACHE_DIR_NAME = "qeff_cache" ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 -ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 +ONNX_EXPORT_EXAMPLE_SEQ_LEN = 16 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep ONNX_EXPORT_MAX_NUM_IMAGES = 1 diff --git a/examples/performance/dflash/README.md b/examples/performance/dflash/README.md new file mode 100644 index 0000000000..0a17946ca0 --- /dev/null +++ b/examples/performance/dflash/README.md @@ -0,0 +1,87 @@ +# DFlash SPD Examples + +Two entry points wrap the SPD compile + run pipeline. + +#### basic_inference.py +Basic DFlash usage with dense language models. + +**Supported Models:** +- Llama3.1-8B-Instruct +- Qwen3-4B +- Qwen3-8B + +## Single prompt + +```bash +python basic_inference.py --model_name Qwen3-4B \ + --prompt "Explain speculative decoding in two sentences." +``` + +## Benchmark (dataset) + +```bash +python benchmark.py --model_name Qwen3-4B --dataset humaneval +``` + +## `--model_name` + +Accepts either the short key or the full HF repo path (case-insensitive): + +``` +Qwen3-4B +Qwen/Qwen3-4B +qwen3-4b +``` + +Run either script with `--help` to see the full supported list. + +## Skipping compile (reuse QPCs) + +Pass either or both. Whichever side is supplied skips its compile step; the +other side still compiles. + +```bash +python basic_inference.py --model_name Qwen3-4B \ + --tlm_qpc /path/to/tlm/qpc \ + --dlm_qpc /path/to/dlm/qpc \ + --prompt "Hello" +``` + +## Common flags + +| Flag | Default | Notes | +|---|---|---| +| `--tlm_devices` | `0,1,2,3` | TLM device IDs | +| `--dlm_devices` | `0,1,2,3` | DLM device IDs | +| `--tlm_cores` / `--dlm_cores` | `8` | per-side core count | +| `--ctx_len` | `4096` | | +| `--prefill_seq_len` | `128` | | +| `--generation_len` | `1024` (benchmark) / `256` (single) | | +| `--hf_token` | `$HF_TOKEN` | required for gated repos | +| `--tlm_hf_path` | from `MODEL_MAP` | required when the map entry has `None` | + +`benchmark.py` only: + +| Flag | Default | +|---|---| +| `--dataset` | `humaneval` (also: `gsm8k`, `math500`) | +| `--num_samples` | `0` (= all) | +| `--iteration` | `300` | +| `--output_dir` | `./results-` | + +`basic_inference.py` only: + +| Flag | Default | +|---|---| +| `--prompt` | *(required)* | +| `--category` | `""` (math / coding / reasoning / …) | + +## Adding a new model + +Edit `MODEL_MAP` in `benchmark.py`: + +```python +"": ("", ""), +``` + +`basic_inference.py` reuses the same map automatically. diff --git a/examples/performance/dflash/basic_inference.py b/examples/performance/dflash/basic_inference.py new file mode 100644 index 0000000000..712788af7b --- /dev/null +++ b/examples/performance/dflash/basic_inference.py @@ -0,0 +1,179 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Single-entry SPD single-prompt runner. + +Given a TLM model_name (short name OR full HF repo path) and a prompt, this +script: + 1. Looks up the matching DFlash DLM repo on Hugging Face. + 2. Compiles TLM + DLM QPCs (only the side(s) not provided via + --tlm_qpc / --dlm_qpc). + 3. Runs the SPD single-prompt inference script. + +Examples: + # Compile + run with all defaults + python basic_inference.py --model_name Qwen3-4B \\ + --prompt "Explain speculative decoding in two sentences." + + # Full HF path also accepted + python basic_inference.py --model_name Qwen/Qwen3-4B \\ + --prompt "Hello" + + # Reuse pre-compiled QPCs + python basic_inference.py --model_name Qwen3-4B \\ + --tlm_qpc /path/to/tlm/qpc --dlm_qpc /path/to/dlm/qpc \\ + --prompt "What is 17 * 23?" +""" + +import argparse +import os +import subprocess +import sys + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.abspath(os.path.join(THIS_DIR, "..", "..")) +sys.path.insert(0, REPO_ROOT) +sys.path.insert(0, THIS_DIR) + +from utils import MODEL_MAP, compile_dlm_qpc, compile_tlm_qpc, resolve_model_name # noqa: E402 + +from QEfficient.utils.logging_utils import logger # noqa: E402 + + +def parse_device_list(s): + return [int(x) for x in s.split(",") if x.strip() != ""] + + +def parse_args(): + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument( + "--model_name", + required=True, + type=resolve_model_name, + help="TLM name — either the short key (e.g. 'Qwen3-4B') or " + "the full HF repo path (e.g. 'Qwen/Qwen3-4B'). " + f"Supported: {', '.join(MODEL_MAP.keys())}", + ) + p.add_argument("--prompt", required=True, help="Input prompt text.") + p.add_argument("--category", default="", help="Prompt category for formatting (math, coding, reasoning, …).") + p.add_argument( + "--format_prompt", + action="store_true", + help="If set, wrap the prompt with the category-specific reasoning/coding template before sending to the model. " + "Off by default — the prompt is used verbatim.", + ) + p.add_argument("--tlm_hf_path", default=None, help="Override TLM HF repo (required if mapping has None).") + + # Optional pre-built QPCs (skip compilation) + p.add_argument("--tlm_qpc", default=None, help="Pre-compiled TLM qpc dir (skip TLM compile).") + p.add_argument("--dlm_qpc", default=None, help="Pre-compiled DLM qpc dir (skip DLM compile).") + + # Devices / cores + p.add_argument( + "--tlm_devices", + type=parse_device_list, + default=[0, 1, 2, 3], + help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", + ) + p.add_argument( + "--dlm_devices", + type=parse_device_list, + default=[0, 1, 2, 3], + help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", + ) + p.add_argument("--tlm_cores", type=int, default=8) + p.add_argument("--dlm_cores", type=int, default=8) + + # Compile / run knobs + p.add_argument("--ctx_len", type=int, default=4096) + p.add_argument("--prefill_seq_len", type=int, default=128) + p.add_argument("--generation_len", type=int, default=256) + p.add_argument("--iteration", type=int, default=300) + + p.add_argument("--hf_token", default=os.environ.get("HF_TOKEN")) + return p.parse_args() + + +def main(): + args = parse_args() + + tlm_repo_default, dlm_repo = MODEL_MAP[args.model_name] + tlm_repo = args.tlm_hf_path or tlm_repo_default + if tlm_repo is None: + raise SystemExit(f"No default TLM HF path for '{args.model_name}'. Pass --tlm_hf_path.") + + if args.tlm_qpc: + logger.info(f"[skip compile] using provided TLM qpc: {args.tlm_qpc}") + tlm_qpc = args.tlm_qpc + else: + tlm_qpc = compile_tlm_qpc( + tlm_repo, + dlm_repo, + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.tlm_cores, + num_devices=len(args.tlm_devices), + hf_token=args.hf_token, + ) + + if args.dlm_qpc: + logger.info(f"[skip compile] using provided DLM qpc: {args.dlm_qpc}") + dlm_qpc = args.dlm_qpc + else: + dlm_qpc = compile_dlm_qpc( + tlm_repo, + dlm_repo, + ctx_len=args.ctx_len, + num_cores=args.dlm_cores, + num_devices=len(args.dlm_devices), + hf_token=args.hf_token, + ) + logger.info(f"TLM qpc : {tlm_qpc}") + logger.info(f"DLM qpc : {dlm_qpc}") + + eval_script = os.path.join(THIS_DIR, "dflash_spd_single_prompt.py") + cmd = [ + sys.executable, + eval_script, + "--prompt", + args.prompt, + "--tlm_qpc", + tlm_qpc, + "--dlm_qpc", + dlm_qpc, + "--tlm_model_name", + tlm_repo, + "--dlm_model_name", + dlm_repo, + "--iteration", + str(args.iteration), + "--ctx_len", + str(args.ctx_len), + "--generation_len", + str(args.generation_len), + "--tlm_devices", + *[str(d) for d in args.tlm_devices], + "--dlm_devices", + *[str(d) for d in args.dlm_devices], + ] + if args.hf_token: + cmd += ["--hf_token", args.hf_token] + if args.category: + cmd += ["--category", args.category] + if args.format_prompt: + cmd += ["--format_prompt"] + + logger.info("\n>>> launching SPD single-prompt inference:") + logger.info(" ".join(cmd)) + rc = subprocess.run(cmd, check=False).returncode + if rc != 0: + raise SystemExit(f"single-prompt inference exited with rc={rc}") + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/benchmark.py b/examples/performance/dflash/benchmark.py new file mode 100644 index 0000000000..49cf98d2ca --- /dev/null +++ b/examples/performance/dflash/benchmark.py @@ -0,0 +1,173 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Single-entry SPD runner. + +Given just a TLM model_name (short name from the supported table), this script: + 1. Looks up the matching DFlash DLM repo on Hugging Face. + 2. Reads hidden_size and block_size from the DLM config. + 3. Compiles TLM + DLM QPCs (if --tlm_qpc / --dlm_qpc are not supplied). + 4. Runs the SPD benchmark on the chosen dataset (default: humaneval). + +Examples: + # Compile + run with all defaults + python benchmark.py --model_name Qwen3-4B + + # Reuse pre-compiled QPCs (no compilation step) + python benchmark.py --model_name Qwen3-4B \\ + --tlm_qpc /path/to/tlm/qpc --dlm_qpc /path/to/dlm/qpc + + # Custom devices / cores / dataset + python benchmark.py --model_name Llama-3.1-8B-Instruct \\ + --tlm_devices 0,1,2,3 --dlm_devices 4,5,6,7 \\ + --tlm_cores 8 --dlm_cores 8 --dataset gsm8k +""" + +import argparse +import os +import subprocess +import sys + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.abspath(os.path.join(THIS_DIR, "..", "..")) +sys.path.insert(0, REPO_ROOT) +sys.path.insert(0, THIS_DIR) + +from utils import MODEL_MAP, compile_dlm_qpc, compile_tlm_qpc, resolve_model_name # noqa: E402 + + +def parse_device_list(s): + return [int(x) for x in s.split(",") if x.strip() != ""] + + +def parse_args(): + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument( + "--model_name", + required=True, + type=resolve_model_name, + help="TLM name — either the short key (e.g. 'Qwen3-4B') or " + "the full HF repo path (e.g. 'Qwen/Qwen3-4B'). " + f"Supported: {', '.join(MODEL_MAP.keys())}", + ) + p.add_argument("--tlm_hf_path", default=None, help="Override TLM HF repo (required if mapping has None).") + + # Optional pre-built QPCs (skip compilation) + p.add_argument("--tlm_qpc", default=None, help="Pre-compiled TLM qpc dir (skip TLM compile).") + p.add_argument("--dlm_qpc", default=None, help="Pre-compiled DLM qpc dir (skip DLM compile).") + + # Devices / cores + p.add_argument( + "--tlm_devices", + type=parse_device_list, + default=[0, 1, 2, 3], + help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", + ) + p.add_argument( + "--dlm_devices", + type=parse_device_list, + default=[0, 1, 2, 3], + help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", + ) + p.add_argument("--tlm_cores", type=int, default=8) + p.add_argument("--dlm_cores", type=int, default=8) + + # Compile / run knobs + p.add_argument("--ctx_len", type=int, default=4096) + p.add_argument("--prefill_seq_len", type=int, default=128) + p.add_argument("--generation_len", type=int, default=1024) + p.add_argument("--iteration", type=int, default=300) + + # Dataset / output + p.add_argument("--dataset", default="humaneval", choices=["humaneval", "gsm8k", "math500"]) + p.add_argument("--num_samples", type=int, default=0, help="0 = all samples") + p.add_argument("--output_dir", default=None, help="Default: ./results-") + p.add_argument("--hf_token", default=os.environ.get("HF_TOKEN")) + return p.parse_args() + + +def main(): + args = parse_args() + + tlm_repo_default, dlm_repo = MODEL_MAP[args.model_name] + tlm_repo = args.tlm_hf_path or tlm_repo_default + if tlm_repo is None: + raise SystemExit(f"No default TLM HF path for '{args.model_name}'. Pass --tlm_hf_path.") + + if args.tlm_qpc: + print(f"[skip compile] using provided TLM qpc: {args.tlm_qpc}") + tlm_qpc = args.tlm_qpc + else: + tlm_qpc = compile_tlm_qpc( + tlm_repo, + dlm_repo, + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.tlm_cores, + num_devices=len(args.tlm_devices), + hf_token=args.hf_token, + ) + + if args.dlm_qpc: + print(f"[skip compile] using provided DLM qpc: {args.dlm_qpc}") + dlm_qpc = args.dlm_qpc + else: + dlm_qpc = compile_dlm_qpc( + tlm_repo, + dlm_repo, + ctx_len=args.ctx_len, + num_cores=args.dlm_cores, + num_devices=len(args.dlm_devices), + hf_token=args.hf_token, + ) + print(f"TLM qpc : {tlm_qpc}") + print(f"DLM qpc : {dlm_qpc}") + + output_dir = args.output_dir or os.path.join(THIS_DIR, f"results-{args.model_name}") + + eval_script = os.path.join(THIS_DIR, "dflash_spd_benchmark.py") + cmd = [ + sys.executable, + eval_script, + "--dataset", + args.dataset, + "--tlm_qpc", + tlm_qpc, + "--dlm_qpc", + dlm_qpc, + "--tlm_model_name", + tlm_repo, + "--dlm_model_name", + dlm_repo, + "--iteration", + str(args.iteration), + "--ctx_len", + str(args.ctx_len), + "--generation_len", + str(args.generation_len), + "--tlm_devices", + *[str(d) for d in args.tlm_devices], + "--dlm_devices", + *[str(d) for d in args.dlm_devices], + "--output_dir", + output_dir, + ] + if args.hf_token: + cmd += ["--hf_token", args.hf_token] + if args.num_samples and args.num_samples > 0: + cmd += ["--num_samples", str(args.num_samples)] + + print("\n>>> launching SPD eval:") + print(" ".join(cmd)) + rc = subprocess.run(cmd, check=False).returncode + if rc != 0: + raise SystemExit(f"SPD eval exited with rc={rc}") + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/dflash_spd_benchmark.py b/examples/performance/dflash/dflash_spd_benchmark.py new file mode 100644 index 0000000000..643e05574d --- /dev/null +++ b/examples/performance/dflash/dflash_spd_benchmark.py @@ -0,0 +1,601 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +import csv +import os +import time +from typing import Optional + +import numpy as np +import torch +import transformers +from rich.console import Console +from rich.markup import escape +from utils import load_and_process_dataset + +from QEfficient.generation.cloud_infer import QAICInferenceSession + +torch.manual_seed(42) +np.random.seed(42) + +console = Console() + + +# ===== METRICS ===== + + +class SpecDecodingMetrics: + def __init__(self, block_size: int = 10): + self.block_size = block_size + self.total_prefill_time = 0.0 + self.tlm_decode_time = 0.0 + self.dlm_decode_time = 0.0 + self.total_accepted_tokens = 0 + self.total_rejected_tokens = 0 + self.total_generated_tokens = 0 + self.num_total_iters = 0 + self.acceptance_history = [] + self.generated_ids: list = [] + self.generated_sources: list = [] # "dlm" or "tlm" per token + + def acceptance_rate(self) -> float: + if self.num_total_iters == 0: + return 0.0 + return self.total_generated_tokens / self.num_total_iters + + def dlm_tok_rate(self) -> float: + if self.dlm_decode_time <= 0: + return 0.0 + num_tok_drafted = self.block_size * self.num_total_iters + return num_tok_drafted / self.dlm_decode_time + + def tlm_tok_rate(self) -> float: + if self.tlm_decode_time <= 0: + return 0.0 + ar = self.acceptance_rate() + num_tok_tlm = self.total_generated_tokens / (1 + ar) if (1 + ar) > 0 else 0.0 + return num_tok_tlm / self.tlm_decode_time + + def spd_tok_rate(self) -> float: + total_decode_s = self.tlm_decode_time + self.dlm_decode_time + if total_decode_s <= 0: + return 0.0 + return self.total_generated_tokens / total_decode_s + + +# ===== INFERENCE ===== + + +def run_spd_inference_single( + prompt_text: str, + tokenizer, + dlm_session: QAICInferenceSession, + tlm_session: QAICInferenceSession, + mask_token_id: int, + vocab_size: int, + prompt_chunk_size: int, + ctx_len: int = 4096, + block_size: int = 16, + max_iterations: int = 300, + hidden_size: int = 4096, + generation_len: int = 256, +) -> SpecDecodingMetrics: + eos_token_ids = {tokenizer.eos_token_id} if tokenizer.eos_token_id is not None else set() + + prompt = [prompt_text] + batch_size = 1 + metrics = SpecDecodingMetrics(block_size=block_size) + + # Tokenize + tlm_inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = tlm_inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -prompt_chunk_size) # ceil divide without float + padded_len = num_chunks * prompt_chunk_size # Convert to a multiple of padded_len + tlm_inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + tlm_inputs["position_ids"] = np.where(tlm_inputs.pop("attention_mask"), np.arange(padded_len), -1) + + tlm_inputs.pop("token_type_ids", None) + tlm_inputs = {k: torch.from_numpy(v) for k, v in tlm_inputs.items()} + tlm_inputs.pop("past_key_values", None) + tlm_inputs = {k: v.detach().numpy() for k, v in tlm_inputs.items()} + prompt_len = padded_len + + generated_ids = np.full((batch_size, ctx_len - prompt_len), tokenizer.pad_token_id) + + # Set output buffers + tlm_session.set_buffers({"logits": np.zeros((batch_size, prompt_chunk_size), dtype=np.int32)}) + tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) + dlm_session.set_buffers({"logits": np.zeros((batch_size, block_size, vocab_size), dtype=np.float32)}) + + tlm_cache_index = np.array([0]) + dlm_cache_index = np.array([0]) + dlm_inputs = {} + + # ===== PREFILL ===== + prefill_start = time.time() + num_sub_blocks = prompt_chunk_size // block_size + remainder = prompt_chunk_size % block_size + + for pi in range(num_chunks - 1): + chunk_inputs = { + "input_ids": tlm_inputs["input_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + "position_ids": tlm_inputs["position_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + } + tlm_prefill_outputs = tlm_session.run(chunk_inputs) + ## Add support for when the prefill_seq_len is more than block_size + for sub_i in range(num_sub_blocks): + sub_start = sub_i * block_size + dlm_inputs["target_hidden"] = tlm_prefill_outputs["hidden_states"][:, sub_start : sub_start + block_size, :] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) + dlm_session.run(dlm_inputs) + + ## Add support when prefill_seq_len is not a multiple of block_size + if remainder > 0: + sub_start = num_sub_blocks * block_size + target_hidden_rem = np.zeros((1, block_size, hidden_size), dtype=np.float32) + target_hidden_rem[:, :remainder, :] = tlm_prefill_outputs["hidden_states"][:, sub_start:, :] + pos_ids_target_rem = np.full((1, block_size), -1, dtype=tlm_inputs["position_ids"].dtype) + pos_ids_target_rem[:, :remainder] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + remainder + ] + dlm_inputs["target_hidden"] = target_hidden_rem + dlm_inputs["position_ids_target"] = pos_ids_target_rem + dlm_inputs["position_ids"] = pos_ids_target_rem + block_size + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) + dlm_session.run(dlm_inputs) + tlm_cache_index[0] += prompt_chunk_size + dlm_cache_index[0] += prompt_chunk_size + + # Last prefill chunk + chunk_inputs = { + "input_ids": tlm_inputs["input_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + "position_ids": tlm_inputs["position_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + } + tlm_last_prefill_outputs = tlm_session.run(chunk_inputs) + last_prefill_pos_in_chunk = chunk_inputs["position_ids"].argmax() + new_tlm_token = tlm_last_prefill_outputs["logits"][:, last_prefill_pos_in_chunk] + + ## Add support for when the prefill_seq_len is more than block_size + last_sub = last_prefill_pos_in_chunk // block_size + for sub_i in range(last_sub): + sub_start = sub_i * block_size + dlm_inputs["target_hidden"] = tlm_last_prefill_outputs["hidden_states"][ + :, sub_start : sub_start + block_size, : + ] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) + dlm_session.run(dlm_inputs) + + input_ids = np.full((1, block_size), mask_token_id, dtype=np.int64) + input_ids[:, 0] = new_tlm_token + sub_start = last_sub * block_size + + ## Add support when prefill_seq_len is not a multiple of block_size + if last_sub < num_sub_blocks: + target_hidden = tlm_last_prefill_outputs["hidden_states"][:, sub_start : sub_start + block_size, :] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + else: + target_hidden = np.zeros((1, block_size, hidden_size), dtype=np.float32) + target_hidden[:, :remainder, :] = tlm_last_prefill_outputs["hidden_states"][:, sub_start:, :] + pos_ids_target = np.full((1, block_size), -1, dtype=tlm_inputs["position_ids"].dtype) + pos_ids_target[:, :remainder] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + remainder + ] + dlm_inputs["position_ids_target"] = pos_ids_target + + dlm_inputs["position_ids"] = np.arange( + tlm_cache_index[0] + last_prefill_pos_in_chunk + 1, + tlm_cache_index[0] + last_prefill_pos_in_chunk + 1 + block_size, + ).reshape(1, -1) + dlm_inputs["input_ids"] = input_ids + dlm_inputs["target_hidden"] = target_hidden + dlm_outputs = dlm_session.run(dlm_inputs) + + metrics.total_prefill_time += time.time() - prefill_start + dlm_candidates = dlm_outputs["logits"].argmax(axis=-1) + + # ===== DECODE ===== + spd_counter_idx = tlm_cache_index[0] + last_prefill_pos_in_chunk + gen_idx = 0 + iteration_count = 0 + continue_generation = True + + tlm_session.set_buffers({"logits": np.zeros((batch_size, block_size), dtype=np.int32)}) + tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) + + while gen_idx < generation_len and iteration_count < max_iterations and continue_generation: + iteration_count += 1 + dlm_candidates[:, 0] = new_tlm_token + + tlm_decode_start = time.time() + tlm_decode_outputs = tlm_session.run( + { + "input_ids": dlm_candidates, + "position_ids": dlm_inputs["position_ids"], + } + ) + metrics.tlm_decode_time += time.time() - tlm_decode_start + + tlm_logits = tlm_decode_outputs["logits"] + target_hidden = tlm_decode_outputs["hidden_states"] + + accepted_length = 0 + rejected_flag = False + + for spec_idx in range(block_size - 1): + tlm_token = tlm_logits[:, spec_idx] + dlm_token = dlm_candidates[:, spec_idx + 1] + if tlm_token == dlm_token: + accepted_length += 1 + metrics.total_accepted_tokens += 1 + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = dlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(dlm_token[0])) + metrics.generated_sources.append("dlm") + else: + metrics.total_rejected_tokens += block_size - spec_idx - 1 + rejected_flag = True + new_tlm_token = tlm_token + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = tlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(tlm_token[0])) + metrics.generated_sources.append("tlm") + break + + metrics.acceptance_history.append(accepted_length) + metrics.total_generated_tokens += accepted_length + 1 + + if not rejected_flag: + new_tlm_token = tlm_logits[:, block_size - 1] + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = new_tlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(new_tlm_token[0])) + metrics.generated_sources.append("tlm") + + # EOS check + dlm_candidate_ids = list(dlm_candidates[0, 1 : accepted_length + 1]) + this_iter_gen_ids = dlm_candidate_ids + [new_tlm_token[0]] + for tok_id in this_iter_gen_ids: + if tok_id in eos_token_ids: + continue_generation = False + break + + if not continue_generation: + break + + # Next DLM iteration + dlm_decode_start = time.time() + dlm_inputs["position_ids_target"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape( + 1, -1 + ) + spd_counter_idx += accepted_length + 1 + dlm_inputs["position_ids_target"][:, accepted_length + 1 :] = -1 + dlm_inputs["position_ids"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape(1, -1) + input_ids[:, 0] = new_tlm_token + dlm_inputs["input_ids"] = input_ids + dlm_inputs["target_hidden"] = target_hidden + dlm_outputs = dlm_session.run(dlm_inputs) + metrics.dlm_decode_time += time.time() - dlm_decode_start + + dlm_candidates = dlm_outputs["logits"].argmax(axis=-1) + + metrics.num_total_iters = iteration_count + return metrics + + +# ===== DATASET CONFIG ===== + +DATASET_CONFIG = { + "gsm8k": { + "hf_path": "openai/gsm8k", + "hf_name": "main", + "split": "test", + "prompt_field": "question", + "label": "GSM8K", + }, + "math500": { + "hf_path": "HuggingFaceH4/MATH-500", + "hf_name": None, + "split": "test", + "prompt_field": "problem", + "label": "MATH-500", + }, + "humaneval": { + "hf_path": "openai/openai_humaneval", + "hf_name": None, + "split": "test", + "prompt_field": "prompt", + "label": "HumanEval", + }, +} + + +# ===== EVALUATION LOOP ===== + +PER_SAMPLE_FIELDS = [ + "dataset", + "sample_idx", + "acceptance_rate", + "dlm_tps", + "tlm_tps", + "spd_tps", + "total_generated_tokens", + "num_iters", + "prefill_time_s", + "tlm_decode_time_s", + "dlm_decode_time_s", +] + +SUMMARY_FIELDS = [ + "dataset", + "num_evaluated", + "num_total", + "avg_acceptance_rate", + "min_acceptance_rate", + "max_acceptance_rate", + "avg_dlm_tps", + "min_dlm_tps", + "max_dlm_tps", + "avg_tlm_tps", + "min_tlm_tps", + "max_tlm_tps", + "avg_spd_tps", + "min_spd_tps", + "max_spd_tps", +] + + +def _write_per_sample_csv(rows: list, path: str): + with open(path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=PER_SAMPLE_FIELDS) + writer.writeheader() + writer.writerows(rows) + console.print(f"[green]Per-sample CSV → {path}[/green]") + + +def _append_summary_csv(row: dict, path: str): + write_header = not os.path.exists(path) + with open(path, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=SUMMARY_FIELDS) + if write_header: + writer.writeheader() + writer.writerow(row) + console.print(f"[green]Summary CSV → {path}[/green]") + + +def evaluate_dataset( + dataset_name: str, + tokenizer, + dlm_session, + tlm_session, + mask_token_id: int, + vocab_size: int, + prompt_chunk_size: int, + ctx_len: int = 4096, + block_size: int = 10, + max_iterations: int = 300, + hidden_size: int = 4096, + generation_len: int = 1024, + num_samples: Optional[int] = None, + output_dir: str = "./results", +): + cfg = DATASET_CONFIG[dataset_name] + console.print(f"[bold blue]Loading {cfg['label']} dataset...[/bold blue]") + dataset = load_and_process_dataset(dataset_name) + + if num_samples is not None: + # dataset = dataset.shuffle(seed=0).select(range(min(num_samples, len(dataset)))) + dataset = dataset.select(range(min(num_samples, len(dataset)))) + console.print(f"[green]✓ Loaded {len(dataset)} {cfg['label']} problems[/green]") + + all_ar, all_dlm_tps, all_tlm_tps, all_spd_tps = [], [], [], [] + per_sample_rows = [] + + for i, sample in enumerate(dataset): + user_content = sample["turns"][0] + messages = [{"role": "user", "content": user_content}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False + ) + console.print(f"[cyan]({i + 1}/{len(dataset)})[/cyan] Input: {user_content[:80].strip()}") + + # try: + metrics = run_spd_inference_single( + prompt_text=prompt_text, + tokenizer=tokenizer, + dlm_session=dlm_session, + tlm_session=tlm_session, + vocab_size=vocab_size, + prompt_chunk_size=prompt_chunk_size, + ctx_len=ctx_len, + block_size=block_size, + max_iterations=max_iterations, + hidden_size=hidden_size, + generation_len=generation_len, + mask_token_id=mask_token_id, + ) + + ar = metrics.acceptance_rate() + dlm_tps = metrics.dlm_tok_rate() + tlm_tps = metrics.tlm_tok_rate() + spd_tps = metrics.spd_tok_rate() + + all_ar.append(ar) + all_dlm_tps.append(dlm_tps) + all_tlm_tps.append(tlm_tps) + all_spd_tps.append(spd_tps) + + per_sample_rows.append( + { + "dataset": dataset_name, + "sample_idx": i, + "acceptance_rate": round(ar, 4), + "dlm_tps": round(dlm_tps, 2), + "tlm_tps": round(tlm_tps, 2), + "spd_tps": round(spd_tps, 2), + "total_generated_tokens": metrics.total_generated_tokens, + "num_iters": metrics.num_total_iters, + "prefill_time_s": round(metrics.total_prefill_time, 4), + "tlm_decode_time_s": round(metrics.tlm_decode_time, 4), + "dlm_decode_time_s": round(metrics.dlm_decode_time, 4), + } + ) + + console.print(f" AR={ar:.2f} DLM={dlm_tps:.1f} tok/s TLM={tlm_tps:.1f} tok/s SPD={spd_tps:.1f} tok/s") + + output_parts = ["Output: "] + for tok_id, source in zip(metrics.generated_ids, metrics.generated_sources): + text = escape(tokenizer.decode([tok_id], skip_special_tokens=True)) + if source == "dlm": + output_parts.append(f"[blue]{text}[/blue]") + else: + output_parts.append(f"[white]{text}[/white]") + console.print("".join(output_parts)) + + # except Exception as e: + # console.print(f"[red] ✗ Error on sample {i}: {e}[/red]") + + # ===== SUMMARY ===== + if all_ar: + w = 46 + print("\n" + "=" * w) + print(f" {cfg['label']} SPD Evaluation — Averages") + print("=" * w) + print(f" {'Metric':<30} {'Avg':>6} {'Min':>6} {'Max':>6}") + print("-" * w) + for name, vals in [ + ("Acceptance Rate (tok/iter)", all_ar), + ("DLM Throughput (tok/s)", all_dlm_tps), + ("TLM Throughput (tok/s)", all_tlm_tps), + ("SPD Decode Speed (tok/s)", all_spd_tps), + ]: + print(f" {name:<30} {np.mean(vals):>6.2f} {np.min(vals):>6.2f} {np.max(vals):>6.2f}") + print("=" * w) + print(f" Evaluated {len(all_ar)} / {len(dataset)} samples successfully.") + print("=" * w + "\n") + + # ===== SAVE CSV ===== + os.makedirs(output_dir, exist_ok=True) + _write_per_sample_csv( + per_sample_rows, + os.path.join(output_dir, f"{dataset_name}_per_sample.csv"), + ) + _append_summary_csv( + { + "dataset": dataset_name, + "num_evaluated": len(all_ar), + "num_total": len(dataset), + "avg_acceptance_rate": round(float(np.mean(all_ar)), 4), + "min_acceptance_rate": round(float(np.min(all_ar)), 4), + "max_acceptance_rate": round(float(np.max(all_ar)), 4), + "avg_dlm_tps": round(float(np.mean(all_dlm_tps)), 2), + "min_dlm_tps": round(float(np.min(all_dlm_tps)), 2), + "max_dlm_tps": round(float(np.max(all_dlm_tps)), 2), + "avg_tlm_tps": round(float(np.mean(all_tlm_tps)), 2), + "min_tlm_tps": round(float(np.min(all_tlm_tps)), 2), + "max_tlm_tps": round(float(np.max(all_tlm_tps)), 2), + "avg_spd_tps": round(float(np.mean(all_spd_tps)), 2), + "min_spd_tps": round(float(np.min(all_spd_tps)), 2), + "max_spd_tps": round(float(np.max(all_spd_tps)), 2), + }, + os.path.join(output_dir, "summary.csv"), + ) + else: + print("No successful results.") + + +# ===== ARGUMENT PARSING ===== + + +def parse_args(): + parser = argparse.ArgumentParser(description="SPD benchmark — gsm8k / math500 / humaneval") + parser.add_argument("--dataset", required=True, choices=list(DATASET_CONFIG.keys())) + parser.add_argument("--tlm_qpc", required=True) + parser.add_argument("--dlm_qpc", required=True) + parser.add_argument("--tlm_model_name", required=True) + parser.add_argument("--dlm_model_name", required=True) + parser.add_argument("--iteration", type=int, default=300) + parser.add_argument("--ctx_len", type=int, default=4096) + parser.add_argument("--generation_len", type=int, default=1024) + parser.add_argument("--tlm_devices", nargs="+", type=int, required=True) + parser.add_argument("--dlm_devices", nargs="+", type=int, required=True) + parser.add_argument("--hf_token", default=None) + parser.add_argument("--num_samples", type=int, default=0, help="Number of samples to run (0 = all)") + parser.add_argument("--output_dir", default="./results", help="Directory for CSV output (default: ./results)") + return parser.parse_args() + + +# ===== MAIN ===== + + +def main(): + args = parse_args() + num_samples = args.num_samples if args.num_samples > 0 else None + + console.print("[bold blue]Loading tokenizer and config...[/bold blue]") + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.tlm_model_name, token=args.hf_token, trust_remote_code=True + ) + config = transformers.AutoConfig.from_pretrained(args.dlm_model_name, token=args.hf_token, trust_remote_code=True) + vocab_size = config.vocab_size + hidden_size = config.hidden_size + block_size = config.block_size + dflash_cfg = getattr(config, "dflash_config", None) or config.to_dict().get("dflash_config", {}) + mask_token_id = dflash_cfg["mask_token_id"] if isinstance(dflash_cfg, dict) else dflash_cfg.mask_token_id + + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + console.print("[bold blue]Loading QAIC inference sessions...[/bold blue]") + dlm_session = QAICInferenceSession(args.dlm_qpc, args.dlm_devices) + tlm_session = QAICInferenceSession(args.tlm_qpc, args.tlm_devices) + dlm_session.skip_buffers( + set([x for x in dlm_session.input_names + dlm_session.output_names if x.startswith("past_")]) + ) + tlm_session.skip_buffers( + set([x for x in tlm_session.input_names + tlm_session.output_names if x.startswith("past_")]) + ) + + prompt_chunk_size = max( + [x[tlm_session.binding_index_map["input_ids"]][1][1] for x in tlm_session.allowed_shapes] + + [tlm_session.bindings[tlm_session.binding_index_map["input_ids"]].dims[1]] + ) + console.print(f"prompt_chunk_size = {prompt_chunk_size}") + + evaluate_dataset( + dataset_name=args.dataset, + tokenizer=tokenizer, + dlm_session=dlm_session, + tlm_session=tlm_session, + vocab_size=vocab_size, + prompt_chunk_size=prompt_chunk_size, + mask_token_id=mask_token_id, + ctx_len=args.ctx_len, + block_size=block_size, + max_iterations=args.iteration, + hidden_size=hidden_size, + generation_len=args.generation_len, + num_samples=num_samples, + output_dir=args.output_dir, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/dflash_spd_single_prompt.py b/examples/performance/dflash/dflash_spd_single_prompt.py new file mode 100644 index 0000000000..bcc11eb72b --- /dev/null +++ b/examples/performance/dflash/dflash_spd_single_prompt.py @@ -0,0 +1,406 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +import time + +import numpy as np +import torch +import transformers +from rich.console import Console +from rich.markup import escape +from utils import format_prompt + +from QEfficient.generation.cloud_infer import QAICInferenceSession + +torch.manual_seed(42) +np.random.seed(42) + +console = Console() + + +# ===== METRICS ===== + + +class SpecDecodingMetrics: + def __init__(self, block_size: int = 10): + self.block_size = block_size + self.total_prefill_time = 0.0 + self.tlm_decode_time = 0.0 + self.dlm_decode_time = 0.0 + self.total_accepted_tokens = 0 + self.total_rejected_tokens = 0 + self.total_generated_tokens = 0 + self.num_total_iters = 0 + self.acceptance_history = [] + self.generated_ids: list = [] + self.generated_sources: list = [] + + def acceptance_rate(self) -> float: + if self.num_total_iters == 0: + return 0.0 + return self.total_generated_tokens / self.num_total_iters + + def dlm_tok_rate(self) -> float: + if self.dlm_decode_time <= 0: + return 0.0 + return (self.block_size * self.num_total_iters) / self.dlm_decode_time + + def tlm_tok_rate(self) -> float: + if self.tlm_decode_time <= 0: + return 0.0 + ar = self.acceptance_rate() + num_tok_tlm = self.total_generated_tokens / (1 + ar) if (1 + ar) > 0 else 0.0 + return num_tok_tlm / self.tlm_decode_time + + def spd_tok_rate(self) -> float: + total_decode_s = self.tlm_decode_time + self.dlm_decode_time + if total_decode_s <= 0: + return 0.0 + return self.total_generated_tokens / total_decode_s + + +# ===== INFERENCE ===== + + +def run_spd_inference_single( + prompt_text: str, + tokenizer, + dlm_session: QAICInferenceSession, + tlm_session: QAICInferenceSession, + mask_token_id: int, + vocab_size: int, + prompt_chunk_size: int, + ctx_len: int = 4096, + block_size: int = 16, + max_iterations: int = 300, + hidden_size: int = 4096, + generation_len: int = 256, +) -> SpecDecodingMetrics: + eos_token_ids = {tokenizer.eos_token_id} if tokenizer.eos_token_id is not None else set() + + prompt = [prompt_text] + batch_size = 1 + metrics = SpecDecodingMetrics(block_size=block_size) + + # Tokenize + tlm_inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = tlm_inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -prompt_chunk_size) + padded_len = num_chunks * prompt_chunk_size + tlm_inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + tlm_inputs["position_ids"] = np.where(tlm_inputs.pop("attention_mask"), np.arange(padded_len), -1) + + tlm_inputs.pop("token_type_ids", None) + tlm_inputs = {k: torch.from_numpy(v) for k, v in tlm_inputs.items()} + tlm_inputs.pop("past_key_values", None) + tlm_inputs = {k: v.detach().numpy() for k, v in tlm_inputs.items()} + + generated_ids = np.full((batch_size, ctx_len - padded_len), tokenizer.pad_token_id) + + # Set output buffers + tlm_session.set_buffers({"logits": np.zeros((batch_size, prompt_chunk_size), dtype=np.int32)}) + tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) + dlm_session.set_buffers({"logits": np.zeros((batch_size, block_size, vocab_size), dtype=np.float32)}) + + tlm_cache_index = np.array([0]) + dlm_cache_index = np.array([0]) + dlm_inputs = {} + + # ===== PREFILL ===== + prefill_start = time.time() + num_sub_blocks = prompt_chunk_size // block_size + remainder = prompt_chunk_size % block_size + + for pi in range(num_chunks - 1): + chunk_inputs = { + "input_ids": tlm_inputs["input_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + "position_ids": tlm_inputs["position_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + } + tlm_prefill_outputs = tlm_session.run(chunk_inputs) + for sub_i in range(num_sub_blocks): + sub_start = sub_i * block_size + dlm_inputs["target_hidden"] = tlm_prefill_outputs["hidden_states"][:, sub_start : sub_start + block_size, :] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) + dlm_session.run(dlm_inputs) + if remainder > 0: + sub_start = num_sub_blocks * block_size + target_hidden_rem = np.zeros((1, block_size, hidden_size), dtype=np.float32) + target_hidden_rem[:, :remainder, :] = tlm_prefill_outputs["hidden_states"][:, sub_start:, :] + pos_ids_target_rem = np.full((1, block_size), -1, dtype=tlm_inputs["position_ids"].dtype) + pos_ids_target_rem[:, :remainder] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + remainder + ] + dlm_inputs["target_hidden"] = target_hidden_rem + dlm_inputs["position_ids_target"] = pos_ids_target_rem + dlm_inputs["position_ids"] = pos_ids_target_rem + block_size + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) + dlm_session.run(dlm_inputs) + tlm_cache_index[0] += prompt_chunk_size + dlm_cache_index[0] += prompt_chunk_size + + # Last prefill chunk + chunk_inputs = { + "input_ids": tlm_inputs["input_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + "position_ids": tlm_inputs["position_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + } + tlm_last_prefill_outputs = tlm_session.run(chunk_inputs) + last_prefill_pos_in_chunk = chunk_inputs["position_ids"].argmax() + new_tlm_token = tlm_last_prefill_outputs["logits"][:, last_prefill_pos_in_chunk] + + last_sub = last_prefill_pos_in_chunk // block_size + for sub_i in range(last_sub): + sub_start = sub_i * block_size + dlm_inputs["target_hidden"] = tlm_last_prefill_outputs["hidden_states"][ + :, sub_start : sub_start + block_size, : + ] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) + dlm_session.run(dlm_inputs) + + input_ids = np.full((1, block_size), mask_token_id, dtype=np.int64) + input_ids[:, 0] = new_tlm_token + sub_start = last_sub * block_size + if last_sub < num_sub_blocks: + target_hidden = tlm_last_prefill_outputs["hidden_states"][:, sub_start : sub_start + block_size, :] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + else: + target_hidden = np.zeros((1, block_size, hidden_size), dtype=np.float32) + target_hidden[:, :remainder, :] = tlm_last_prefill_outputs["hidden_states"][:, sub_start:, :] + pos_ids_target = np.full((1, block_size), -1, dtype=tlm_inputs["position_ids"].dtype) + pos_ids_target[:, :remainder] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + remainder + ] + dlm_inputs["position_ids_target"] = pos_ids_target + dlm_inputs["position_ids"] = np.arange( + tlm_cache_index[0] + last_prefill_pos_in_chunk + 1, + tlm_cache_index[0] + last_prefill_pos_in_chunk + 1 + block_size, + ).reshape(1, -1) + dlm_inputs["input_ids"] = input_ids + dlm_inputs["target_hidden"] = target_hidden + dlm_outputs = dlm_session.run(dlm_inputs) + + metrics.total_prefill_time += time.time() - prefill_start + dlm_candidates = dlm_outputs["logits"].argmax(axis=-1) + + # ===== DECODE ===== + spd_counter_idx = tlm_cache_index[0] + last_prefill_pos_in_chunk + gen_idx = 0 + iteration_count = 0 + continue_generation = True + + tlm_session.set_buffers({"logits": np.zeros((batch_size, block_size), dtype=np.int32)}) + tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) + + while gen_idx < generation_len and iteration_count < max_iterations and continue_generation: + iteration_count += 1 + dlm_candidates[:, 0] = new_tlm_token + + tlm_decode_start = time.time() + tlm_decode_outputs = tlm_session.run( + { + "input_ids": dlm_candidates, + "position_ids": dlm_inputs["position_ids"], + } + ) + metrics.tlm_decode_time += time.time() - tlm_decode_start + + tlm_logits = tlm_decode_outputs["logits"] + target_hidden = tlm_decode_outputs["hidden_states"] + + accepted_length = 0 + rejected_flag = False + + for spec_idx in range(block_size - 1): + tlm_token = tlm_logits[:, spec_idx] + dlm_token = dlm_candidates[:, spec_idx + 1] + if tlm_token == dlm_token: + accepted_length += 1 + metrics.total_accepted_tokens += 1 + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = dlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(dlm_token[0])) + metrics.generated_sources.append("dlm") + else: + metrics.total_rejected_tokens += block_size - spec_idx - 1 + rejected_flag = True + new_tlm_token = tlm_token + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = tlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(tlm_token[0])) + metrics.generated_sources.append("tlm") + break + + metrics.acceptance_history.append(accepted_length) + metrics.total_generated_tokens += accepted_length + 1 + + if not rejected_flag: + new_tlm_token = tlm_logits[:, block_size - 1] + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = new_tlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(new_tlm_token[0])) + metrics.generated_sources.append("tlm") + + dlm_candidate_ids = list(dlm_candidates[0, 1 : accepted_length + 1]) + this_iter_gen_ids = dlm_candidate_ids + [new_tlm_token[0]] + for tok_id in this_iter_gen_ids: + if tok_id in eos_token_ids: + continue_generation = False + break + + if not continue_generation: + break + + dlm_decode_start = time.time() + dlm_inputs["position_ids_target"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape( + 1, -1 + ) + spd_counter_idx += accepted_length + 1 + dlm_inputs["position_ids_target"][:, accepted_length + 1 :] = -1 + dlm_inputs["position_ids"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape(1, -1) + input_ids[:, 0] = new_tlm_token + dlm_inputs["input_ids"] = input_ids + dlm_inputs["target_hidden"] = target_hidden + dlm_outputs = dlm_session.run(dlm_inputs) + metrics.dlm_decode_time += time.time() - dlm_decode_start + + dlm_candidates = dlm_outputs["logits"].argmax(axis=-1) + + metrics.num_total_iters = iteration_count + return metrics + + +# ===== ARGUMENT PARSING ===== + + +def parse_args(): + parser = argparse.ArgumentParser(description="SPD single-prompt inference") + parser.add_argument("--prompt", required=True, help="Input prompt text") + parser.add_argument("--tlm_qpc", required=True) + parser.add_argument("--dlm_qpc", required=True) + parser.add_argument("--tlm_model_name", required=True) + parser.add_argument("--dlm_model_name", required=True) + parser.add_argument("--iteration", type=int, default=300) + parser.add_argument("--ctx_len", type=int, default=4096) + parser.add_argument("--generation_len", type=int, default=256) + parser.add_argument("--tlm_devices", nargs="+", type=int, required=True) + parser.add_argument("--dlm_devices", nargs="+", type=int, required=True) + parser.add_argument("--hf_token", default=None) + parser.add_argument( + "--category", + default="", + help="Prompt category for formatting (math, coding, reasoning, …). Defaults to the general reasoning format.", + ) + parser.add_argument( + "--format_prompt", + action="store_true", + help="If set, wrap the prompt with the category-specific template from utils.format_prompt. " + "Off by default — the prompt is used verbatim.", + ) + return parser.parse_args() + + +# ===== MAIN ===== + + +def main(): + args = parse_args() + + console.print("[bold blue]Loading tokenizer and config...[/bold blue]") + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.tlm_model_name, token=args.hf_token, trust_remote_code=True + ) + config = transformers.AutoConfig.from_pretrained(args.dlm_model_name, token=args.hf_token, trust_remote_code=True) + vocab_size = config.vocab_size + hidden_size = config.hidden_size + block_size = config.block_size + dflash_cfg = getattr(config, "dflash_config", None) or config.to_dict().get("dflash_config", {}) + mask_token_id = dflash_cfg["mask_token_id"] if isinstance(dflash_cfg, dict) else dflash_cfg.mask_token_id + + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + console.print("[bold blue]Loading QAIC inference sessions...[/bold blue]") + dlm_session = QAICInferenceSession(args.dlm_qpc, args.dlm_devices) + tlm_session = QAICInferenceSession(args.tlm_qpc, args.tlm_devices) + dlm_session.skip_buffers( + set([x for x in dlm_session.input_names + dlm_session.output_names if x.startswith("past_")]) + ) + tlm_session.skip_buffers( + set([x for x in tlm_session.input_names + tlm_session.output_names if x.startswith("past_")]) + ) + + prompt_chunk_size = max( + [x[tlm_session.binding_index_map["input_ids"]][1][1] for x in tlm_session.allowed_shapes] + + [tlm_session.bindings[tlm_session.binding_index_map["input_ids"]].dims[1]] + ) + console.print(f"prompt_chunk_size = {prompt_chunk_size}") + + user_content = format_prompt(args.prompt, args.category) if args.format_prompt else args.prompt + messages = [{"role": "user", "content": user_content}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False + ) + console.print(f"[cyan]Input:[/cyan] {args.prompt[:120].strip()}") + + metrics = run_spd_inference_single( + prompt_text=prompt_text, + tokenizer=tokenizer, + dlm_session=dlm_session, + tlm_session=tlm_session, + vocab_size=vocab_size, + prompt_chunk_size=prompt_chunk_size, + ctx_len=args.ctx_len, + block_size=block_size, + max_iterations=args.iteration, + hidden_size=hidden_size, + generation_len=args.generation_len, + mask_token_id=mask_token_id, + ) + + output_parts = ["Output: "] + for tok_id, source in zip(metrics.generated_ids, metrics.generated_sources): + text = escape(tokenizer.decode([tok_id], skip_special_tokens=True)) + if source == "dlm": + output_parts.append(f"[blue]{text}[/blue]") + else: + output_parts.append(f"[white]{text}[/white]") + console.print("".join(output_parts)) + + ar = metrics.acceptance_rate() + dlm_tps = metrics.dlm_tok_rate() + tlm_tps = metrics.tlm_tok_rate() + spd_tps = metrics.spd_tok_rate() + + w = 46 + print("\n" + "=" * w) + print(" SPD Inference — Metrics") + print("=" * w) + print(f" {'Acceptance Rate (tok/iter)':<30} {ar:>6.2f}") + print(f" {'DLM Throughput (tok/s)':<30} {dlm_tps:>6.1f}") + print(f" {'TLM Throughput (tok/s)':<30} {tlm_tps:>6.1f}") + print(f" {'SPD Decode Speed (tok/s)':<30} {spd_tps:>6.1f}") + print(f" {'Generated tokens':<30} {metrics.total_generated_tokens:>6}") + print(f" {'Iterations':<30} {metrics.num_total_iters:>6}") + print(f" {'Prefill time (s)':<30} {metrics.total_prefill_time:>6.3f}") + print("=" * w + "\n") + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/make_models.py b/examples/performance/dflash/make_models.py new file mode 100644 index 0000000000..8ae7fcae4c --- /dev/null +++ b/examples/performance/dflash/make_models.py @@ -0,0 +1,146 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Build and compile TLM + DLM QPC packages for DFlash speculative decoding. + +Usage: + python make_models.py # build both (TLM + DLM) in separate subprocesses + python make_models.py --mode tlm # build TLM only (single process) + python make_models.py --mode dlm # build DLM only (single process) + python make_models.py --mode both # alias for default + +The default 'both' mode launches one subprocess per model so that compiler +state from the TLM build cannot affect the DLM build (back-to-back compiles +in the same process have been observed to segfault). +""" + +import argparse +import os +import subprocess +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +import torch +from transformers import AutoModelForCausalLM +from utils import build_dlm_model, build_tlm_model, extract_embed, extract_lm_head, load_dflash_checkpoint + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.logging_utils import logger + +# ── Paths ───────────────────────────────────────────────────────────────────── +TLM_MODEL_PATH = "Qwen/Qwen3-4B" +# TLM_MODEL_PATH = "Qwen/Qwen3-8B" +# TLM_MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct" + +DFLASH_MODEL_PATH = "z-lab/Qwen3-4B-DFlash-b16" +# DFLASH_MODEL_PATH = "z-lab/Qwen3-8B-DFlash-b16" +# DFLASH_MODEL_PATH = "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat" + +# ── Compile options ─────────────────────────────────────────────────────────── +COMPILE_KWARGS = dict( + ctx_len=4096, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=4, + mos=1, +) + +TLM_NUM_CORES = 8 +DLM_NUM_CORES = 16 + + +def _load_dflash_meta(): + dflash_state_dict, cfg = load_dflash_checkpoint(DFLASH_MODEL_PATH) + target_layer_ids = cfg.get("dflash_config", {}).get("target_layer_ids", []) + mask_token_id = cfg.get("dflash_config", {}).get("mask_token_id", []) + block_size = cfg.get("block_size", None) + logger.info(f" target_layer_ids : {target_layer_ids}") + logger.info(f" mask_token_id : {mask_token_id}") + logger.info(f" block_size : {block_size}") + return dflash_state_dict, target_layer_ids, block_size + + +def build_tlm(): + logger.info(f"Loading DFlash checkpoint: {DFLASH_MODEL_PATH}") + dflash_state_dict, target_layer_ids, block_size = _load_dflash_meta() + + logger.info(f"\nLoading base model: {TLM_MODEL_PATH}") + base_model = AutoModelForCausalLM.from_pretrained(TLM_MODEL_PATH, torch_dtype=torch.float32) + + logger.info("\n=== TLM ===") + tlm_target_ids = [i + 1 for i in target_layer_ids] + build_tlm_model(base_model, dflash_state_dict, tlm_target_ids) + + tlm_qeff = QEFFAutoModelForCausalLM(base_model, qaic_config={"target_layer_ids": tlm_target_ids}) + tlm_qpc_path = tlm_qeff.compile( + prefill_seq_len=128, + num_cores=TLM_NUM_CORES, + dflash_block_size=block_size, + **COMPILE_KWARGS, + ) + logger.info(f"tlm_qpc_path: {tlm_qpc_path}") + return tlm_qpc_path + + +def build_dlm(): + logger.info(f"Loading DFlash checkpoint: {DFLASH_MODEL_PATH}") + _, _, block_size = _load_dflash_meta() + + logger.info(f"\nLoading base model (for lm_head): {TLM_MODEL_PATH}") + base_model = AutoModelForCausalLM.from_pretrained(TLM_MODEL_PATH, torch_dtype=torch.float32) + lm_head_weight, lm_head_bias = extract_lm_head(base_model) + embed_weight = extract_embed(base_model) + del base_model + + logger.info("\n=== DLM ===") + dlm_model = build_dlm_model(DFLASH_MODEL_PATH, lm_head_weight, lm_head_bias, embed_weight) + + dlm_qeff = QEFFAutoModelForCausalLM(dlm_model, qaic_config={"dflash_dlm": True}) + dlm_qpc_path = dlm_qeff.compile( + prefill_seq_len=block_size, + num_cores=DLM_NUM_CORES, + prefill_only=True, + **COMPILE_KWARGS, + ) + logger.info(f"dlm_qpc_path: {dlm_qpc_path}") + return dlm_qpc_path + + +def _run_subprocess(mode: str): + logger.info(f"\n>>> Spawning subprocess: --mode {mode}") + result = subprocess.run( + [sys.executable, os.path.abspath(__file__), "--mode", mode], + check=False, + ) + if result.returncode != 0: + raise SystemExit(f"Subprocess for --mode {mode} exited with code {result.returncode}") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--mode", + choices=["tlm", "dlm", "both"], + default="both", + help="Which model(s) to build. 'both' (default) runs TLM and DLM in separate subprocesses.", + ) + args = parser.parse_args() + + if args.mode == "tlm": + build_tlm() + elif args.mode == "dlm": + build_dlm() + else: + _run_subprocess("tlm") + _run_subprocess("dlm") + logger.info("\n=== Done ===") + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/utils.py b/examples/performance/dflash/utils.py new file mode 100644 index 0000000000..711b49aef9 --- /dev/null +++ b/examples/performance/dflash/utils.py @@ -0,0 +1,444 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +import json +import warnings +from typing import Optional + +import torch +from datasets import Features, Sequence, Value, load_dataset +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from torch import nn +from transformers import AutoModelForCausalLM + +from QEfficient.utils.logging_utils import logger + +# ───────────────────────────────────────────────────────────────────────────── +# model_name (TLM short) → (TLM HF repo, DLM HF repo) +# DLM column is the canonical DFlash repo. TLM column is the standard HF repo +# when known; otherwise None and must be supplied via --tlm_hf_path. +# ───────────────────────────────────────────────────────────────────────────── +MODEL_MAP = { + "gemma-4-31B-it": (None, "z-lab/gemma-4-31B-it-DFlash"), + "gemma-4-26B-A4B-it": (None, "z-lab/gemma-4-26B-A4B-it-DFlash"), + "MiniMax-M2.7": (None, "z-lab/MiniMax-M2.7-DFlash"), + "MiniMax-M2.5": (None, "z-lab/MiniMax-M2.5-DFlash"), + "Kimi-K2.6": (None, "z-lab/Kimi-K2.6-DFlash"), + "Kimi-K2.5": (None, "z-lab/Kimi-K2.5-DFlash"), + "Qwen3.6-27B": (None, "z-lab/Qwen3.6-27B-DFlash"), + "Qwen3.6-35B-A3B": (None, "z-lab/Qwen3.6-35B-A3B-DFlash"), + "Qwen3.5-4B": (None, "z-lab/Qwen3.5-4B-DFlash"), + "Qwen3.5-9B": (None, "z-lab/Qwen3.5-9B-DFlash"), + "Qwen3.5-27B": (None, "z-lab/Qwen3.5-27B-DFlash"), + "Qwen3.5-35B-A3B": (None, "z-lab/Qwen3.5-35B-A3B-DFlash"), + "Qwen3.5-122B-A10B": (None, "z-lab/Qwen3.5-122B-A10B-DFlash"), + "gpt-oss-20b": ("openai/gpt-oss-20b", "z-lab/gpt-oss-20b-DFlash"), + "gpt-oss-120b": ("openai/gpt-oss-120b", "z-lab/gpt-oss-120b-DFlash"), + "Qwen3-Coder-Next": (None, "z-lab/Qwen3-Coder-Next-DFlash"), + "Qwen3-4B": ("Qwen/Qwen3-4B", "z-lab/Qwen3-4B-DFlash-b16"), + "Qwen3-8B": ("Qwen/Qwen3-8B", "z-lab/Qwen3-8B-DFlash-b16"), + "Qwen3-Coder-30B-A3B": ("Qwen/Qwen3-Coder-30B-A3B-Instruct", "z-lab/Qwen3-Coder-30B-A3B-DFlash"), + "Llama-3.1-8B-Instruct": ("meta-llama/Llama-3.1-8B-Instruct", "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat"), +} + + +def _build_aliases(model_map): + aliases = {} + for short, (tlm_repo, _) in model_map.items(): + aliases[short.lower()] = short + if tlm_repo: + aliases[tlm_repo.lower()] = short + aliases[tlm_repo.split("/", 1)[-1].lower()] = short + return aliases + + +MODEL_ALIASES = _build_aliases(MODEL_MAP) + + +def resolve_model_name(name: str) -> str: + """Map a user-supplied model name (short, full HF path, or basename) to + the canonical short name used as a key in MODEL_MAP.""" + canonical = MODEL_ALIASES.get(name.lower()) + if canonical is None: + raise argparse.ArgumentTypeError( + f"unknown model_name '{name}'. Supported: " + ", ".join(sorted(MODEL_MAP.keys())) + ) + return canonical + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int): + if num_draft_layers == 1: + return [(num_target_layers // 2)] + start = 1 + end = num_target_layers - 3 + span = end - start + target_layer_ids = [int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers)] + return target_layer_ids + + +def extract_context_feature( + hidden_states: list[torch.Tensor], + layer_ids: Optional[list[int]], +) -> torch.Tensor: + offset = 1 + selected_states = [] + for layer_id in layer_ids: + selected_states.append(hidden_states[layer_id + offset]) + target_hidden = torch.cat(selected_states, dim=-1) + return target_hidden + + +def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + logits = logits.view(-1, vocab_size) + logits = logits / temperature + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + + +_TARGET_ABSMAX = 128.0 + + +def load_dflash_checkpoint(dflash_model_path: str) -> tuple[dict, dict]: + """Download and load the DFlash safetensors checkpoint and config. + + Returns + ------- + state_dict : dict[str, Tensor] — all tensors in fp32 + cfg : dict — parsed config.json + """ + bin_path = hf_hub_download(repo_id=dflash_model_path, filename="model.safetensors") + config_path = hf_hub_download(repo_id=dflash_model_path, filename="config.json") + + with open(config_path, "r") as f: + cfg = json.load(f) + + state_dict = {} + with safe_open(bin_path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(torch.float32) + + return state_dict, cfg + + +def extract_lm_head(model: AutoModelForCausalLM) -> tuple[torch.Tensor, torch.Tensor | None]: + """Return (lm_head_weight, lm_head_bias) from a HuggingFace causal LM (fp32).""" + sd = model.state_dict() + weight = sd["lm_head.weight"].to(torch.float32) + bias = sd.get("lm_head.bias") + if bias is not None: + bias = bias.to(torch.float32) + return weight, bias + + +def extract_embed(model: AutoModelForCausalLM) -> torch.Tensor: + """Return embed_tokens.weight from a HuggingFace causal LM (fp32).""" + sd = model.state_dict() + return sd["model.embed_tokens.weight"].to(torch.float32) + + +def build_tlm_model( + base_model: AutoModelForCausalLM, + dflash_state_dict: dict, + target_layer_ids: list[int], + target_absmax: float = _TARGET_ABSMAX, +) -> AutoModelForCausalLM: + """Attach fc + hidden_norm to *base_model*, inject DFlash weights, and scale fc. + + Modifies *base_model* in-place and returns it. + """ + inner = base_model.model + hidden_size = base_model.config.hidden_size + model_type = getattr(base_model.config, "model_type", "") + n = len(target_layer_ids) + + # Add fc and hidden_norm + inner.fc = nn.Linear(n * hidden_size, hidden_size, bias=False) + + if "qwen3" in model_type: + from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm + + inner.hidden_norm = Qwen3RMSNorm(hidden_size, eps=base_model.config.rms_norm_eps) + elif "llama" in model_type: + from transformers.models.llama.modeling_llama import LlamaRMSNorm + + inner.hidden_norm = LlamaRMSNorm(hidden_size, eps=base_model.config.rms_norm_eps) + else: + warnings.warn(f"Unknown model_type '{model_type}'; using nn.RMSNorm for hidden_norm.") + inner.hidden_norm = nn.RMSNorm(hidden_size, eps=getattr(base_model.config, "rms_norm_eps", 1e-6)) + + # Inject weights from DFlash checkpoint + fc_tensor = dflash_state_dict["fc.weight"].to(torch.float32) + inner.fc.weight.data.copy_(fc_tensor) + + hn_tensor = dflash_state_dict["hidden_norm.weight"].to(torch.float32) + inner.hidden_norm.weight.data.copy_(hn_tensor) + + # Scale fc weights so activations stay within fp16 range + # RMSNorm(x/s) == RMSNorm(x), so this is zero-accuracy-cost + with torch.no_grad(): + in_feat = inner.fc.in_features + max_row_norm = inner.fc.weight.data.norm(dim=1).max().item() + fc_out_bound = (in_feat**0.5) * max_row_norm + s = max(fc_out_bound / target_absmax, 1.0) + inner.fc.weight.data.div_(s) + logger.info( + f"[TLM] fc scale: in_features={in_feat}, max_row_norm={max_row_norm:.4f}, " + f"fc_out_bound={fc_out_bound:.2f}, s={s:.6f}" + ) + + logger.info(f"[TLM] fc ({n * hidden_size} -> {hidden_size}) and hidden_norm attached and scaled") + return base_model + + +def build_dlm_model( + dflash_model_path: str, + lm_head_weight: torch.Tensor, + lm_head_bias: torch.Tensor | None = None, + embed_weight: torch.Tensor | None = None, +) -> AutoModelForCausalLM: + """Load the DFlash model and inject lm_head (and optionally embed_tokens) weights from the base TLM model. + + Also removes fc / hidden_norm if the DFlash checkpoint has them. + """ + dlm_model = AutoModelForCausalLM.from_pretrained(dflash_model_path, torch_dtype=torch.float32) + + with torch.no_grad(): + dlm_model.lm_head.weight.copy_(lm_head_weight) + if embed_weight is not None: + dlm_model.model.embed_tokens.weight.copy_(embed_weight) + + if lm_head_bias is not None: + if dlm_model.lm_head.bias is None: + dlm_model.lm_head.bias = nn.Parameter(lm_head_bias) + else: + with torch.no_grad(): + dlm_model.lm_head.bias.copy_(lm_head_bias) + + # DFlash checkpoints occasionally carry fc / hidden_norm — strip them + for attr in ("fc", "hidden_norm"): + if hasattr(dlm_model, attr): + delattr(dlm_model, attr) + logger.info(f"[DLM] Removed dlm_model.{attr}") + + logger.info(f"[DLM] lm_head injected (shape: {lm_head_weight.shape})") + if embed_weight is not None: + logger.info(f"[DLM] embed_tokens injected (shape: {embed_weight.shape})") + return dlm_model + + +def read_dlm_meta(dlm_repo: str, hf_token: Optional[str] = None): + """Load a DFlash checkpoint and return (state_dict, target_layer_ids, block_size).""" + state_dict, cfg = load_dflash_checkpoint(dlm_repo) + target_layer_ids = cfg.get("dflash_config", {}).get("target_layer_ids", []) + block_size = cfg.get("block_size", None) + return state_dict, target_layer_ids, block_size + + +def compile_tlm_qpc( + tlm_repo: str, + dlm_repo: str, + *, + prefill_seq_len: int, + ctx_len: int, + num_cores: int, + num_devices: int, + hf_token: Optional[str] = None, +) -> str: + """Build the TLM (base model + fc/hidden_norm from the DFlash checkpoint) and + compile it to a QPC. Returns the qpc directory path.""" + from QEfficient import QEFFAutoModelForCausalLM + + state_dict, target_layer_ids, block_size = read_dlm_meta(dlm_repo, hf_token) + tlm_target_ids = [i + 1 for i in target_layer_ids] + + logger.info(f"[compile_tlm] base={tlm_repo} dlm={dlm_repo} block_size={block_size}") + base_model = AutoModelForCausalLM.from_pretrained(tlm_repo, torch_dtype=torch.float32, token=hf_token) + build_tlm_model(base_model, state_dict, tlm_target_ids) + + tlm_qeff = QEFFAutoModelForCausalLM(base_model, qaic_config={"target_layer_ids": tlm_target_ids}) + qpc = tlm_qeff.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, + dflash_block_size=block_size, + ) + qpc = str(qpc) + logger.info(f"[compile_tlm] qpc={qpc}") + return qpc + + +def compile_dlm_qpc( + tlm_repo: str, + dlm_repo: str, + *, + ctx_len: int, + num_cores: int, + num_devices: int, + hf_token: Optional[str] = None, +) -> str: + """Build the DLM (DFlash model with lm_head/embed_tokens copied from the base + TLM, and fc/hidden_norm stripped) and compile it to a QPC.""" + from QEfficient import QEFFAutoModelForCausalLM + + _, _, block_size = read_dlm_meta(dlm_repo, hf_token) + + logger.info(f"[compile_dlm] dlm={dlm_repo} block_size={block_size}") + base_model = AutoModelForCausalLM.from_pretrained(tlm_repo, torch_dtype=torch.float32, token=hf_token) + lm_head_w, lm_head_b = extract_lm_head(base_model) + embed_w = extract_embed(base_model) + del base_model + + dlm_model = build_dlm_model(dlm_repo, lm_head_w, lm_head_b, embed_w) + dlm_qeff = QEFFAutoModelForCausalLM(dlm_model, qaic_config={"dflash_dlm": True}) + qpc = dlm_qeff.compile( + prefill_seq_len=block_size, + ctx_len=ctx_len, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, + prefill_only=True, + ) + qpc = str(qpc) + logger.info(f"[compile_dlm] qpc={qpc}") + return qpc + + +def load_and_process_dataset(data_name: str): + # Math datasets + if data_name == "gsm8k": + dataset = load_dataset("openai/gsm8k", "main", split="test") + prompt_fmt = "{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "math500": + dataset = load_dataset("HuggingFaceH4/MATH-500", split="test") + prompt_fmt = "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "aime24": + dataset = load_dataset("HuggingFaceH4/aime_2024", split="train") + prompt_fmt = "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "aime25": + dataset = load_dataset("MathArena/aime_2025", split="train") + prompt_fmt = "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + # Chat datasets + elif data_name == "alpaca": + dataset = load_dataset("tatsu-lab/alpaca", split="train") + dataset = dataset.map( + lambda x: { + "formatted_input": (f"{x['instruction']}\n\nInput:\n{x['input']}" if x["input"] else x["instruction"]) + } + ) + dataset = dataset.map(lambda x: {"turns": [x["formatted_input"]]}) + + elif data_name == "mt-bench": + dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train") + dataset = dataset.map(lambda x: {"turns": x["prompt"]}) + + # Coding datasets + elif data_name == "humaneval": + dataset = load_dataset("openai/openai_humaneval", split="test") + prompt_fmt = "Write a solution to the following problem and make sure that it passes the tests:\n```python\n{prompt}\n```" + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "mbpp": + dataset = load_dataset("google-research-datasets/mbpp", "sanitized", split="test") + dataset = dataset.map(lambda x: {"turns": [x["prompt"]]}) + + elif data_name == "lbpp": + LBPP_PY_TEST_URL = "https://huggingface.co/datasets/CohereLabs/lbpp/resolve/main/python/test.parquet" + dataset = load_dataset("parquet", data_files={"test": LBPP_PY_TEST_URL})["test"] + dataset = dataset.map(lambda x: {"turns": [x["instruction"]]}) + + elif data_name == "swe-bench": + dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test") + prompt_fmt = "Problem Statement:\n{problem_statement}\nPlease fix the issue described above." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "livecodebench": + base = "https://huggingface.co/datasets/livecodebench/code_generation_lite/resolve/main/" + allowed_files = ["test.jsonl", "test2.jsonl", "test3.jsonl", "test4.jsonl", "test5.jsonl", "test6.jsonl"] + urls = [base + fn for fn in allowed_files] + dataset = load_dataset("json", data_files={"test": urls})["test"] + + def format_lcb(doc): + system_prompt = ( + "You are an expert Python programmer. You will be given a question (problem specification) " + "and will generate a correct Python program that matches the specification and passes all tests. " + "You will NOT return anything except for the program" + ) + question_block = f"### Question:\n{doc['question_content']}" + if doc.get("starter_code"): + format_message = "### Format: Use the following code structure:" + code_block = f"```python\n{doc['starter_code']}\n```" + else: + format_message = "### Format: Write your code in the following format:" + code_block = "```python\n# YOUR CODE HERE\n```" + answer_footer = "### Answer: (use the provided format with backticks)" + return f"{system_prompt}\n\n{question_block}\n\n{format_message}\n{code_block}\n\n{answer_footer}" + + target_features = Features({"turns": Sequence(Value("large_string"))}) + dataset = dataset.map( + lambda x: {"turns": [format_lcb(x)]}, remove_columns=dataset.column_names, features=target_features + ) + + return dataset + + +_DEFAULT_FMT = "{prompt}\nPlease reason step by step, and put your final answer within \\boxed{{}}." +_CODING_FMT = ( + "Write a solution to the following problem and make sure that it passes the tests:\n```python\n{prompt}\n```" +) + +_CATEGORY_FMT = { + "math": _DEFAULT_FMT, + "math_reasoning": _DEFAULT_FMT, + "coding": _CODING_FMT, + "reasoning": _DEFAULT_FMT, + "stem": _DEFAULT_FMT, + "qa": _DEFAULT_FMT, + "rag": _DEFAULT_FMT, + "extraction": _DEFAULT_FMT, + "humanities": _DEFAULT_FMT, + "writing": _DEFAULT_FMT, + "summarization": _DEFAULT_FMT, + "translation": _DEFAULT_FMT, + "roleplay": _DEFAULT_FMT, +} + + +def format_prompt(prompt: str, category: str = "") -> str: + fmt = _CATEGORY_FMT.get(category, _DEFAULT_FMT) + return fmt.format(prompt=prompt) + + +def reformat_jsonl_by_category(questions: list) -> list: + """Apply instruction prefix to JSONL questions based on their category. + + Mirrors the prompt_fmt logic in load_and_process_dataset: for categories + where a specific instruction is obvious (math, coding) a tailored prefix is + used; for all others the same step-by-step reasoning instruction is applied. + """ + for q in questions: + category = q.get("category", "") + q["turns"][0] = format_prompt(q["turns"][0], category) + return questions