From 99ba515af8fd3fc83449e3e7ea5ca174c5d8d3fd Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Mon, 18 May 2026 09:03:50 -0700 Subject: [PATCH 1/5] Dflash: Block Diffusion Speculative Decoding Co-authored-by: Vahid Janfaza Co-authored-by: fannanya Signed-off-by: Vahid Janfaza Signed-off-by: Vahid Janfaza Signed-off-by: Vahid Janfaza --- .../models/llama/modeling_llama.py | 44 +- .../transformers/models/modeling_auto.py | 45 +- .../transformers/models/pytorch_transforms.py | 144 +++++ .../models/qwen3/modeling_qwen3.py | 43 +- .../qwen3/modeling_qwen3_dflash_draft.py | 544 ++++++++++++++++ QEfficient/utils/constants.py | 2 +- examples/performance/dflash/README.md | 88 +++ .../performance/dflash/basic_inference.py | 224 +++++++ examples/performance/dflash/benchmark.py | 347 ++++++++++ examples/performance/dflash/dbg.log | 0 .../dflash/dflash_spd_benchmark.py | 603 ++++++++++++++++++ .../dflash/dflash_spd_single_prompt.py | 401 ++++++++++++ examples/performance/dflash/make_models.py | 144 +++++ .../Llama-3.1-8B-Instruct_noise_embeds.npy | Bin 0 -> 16512 bytes .../noise_embedding/Qwen3-4B_noise_embeds.npy | Bin 0 -> 10368 bytes .../noise_embedding/Qwen3-8B_noise_embeds.npy | Bin 0 -> 16512 bytes .../gpt-oss-20b_noise_embeds.npy | Bin 0 -> 11648 bytes .../results-Qwen3-4B/humaneval_per_sample.csv | 165 +++++ .../dflash/results-Qwen3-4B/summary.csv | 2 + examples/performance/dflash/utils.py | 312 +++++++++ 20 files changed, 3103 insertions(+), 5 deletions(-) create mode 100644 QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py create mode 100644 examples/performance/dflash/README.md create mode 100644 examples/performance/dflash/basic_inference.py create mode 100644 examples/performance/dflash/benchmark.py create mode 100644 examples/performance/dflash/dbg.log create mode 100644 examples/performance/dflash/dflash_spd_benchmark.py create mode 100644 examples/performance/dflash/dflash_spd_single_prompt.py create mode 100644 examples/performance/dflash/make_models.py create mode 100755 examples/performance/dflash/noise_embedding/Llama-3.1-8B-Instruct_noise_embeds.npy create mode 100755 examples/performance/dflash/noise_embedding/Qwen3-4B_noise_embeds.npy create mode 100755 examples/performance/dflash/noise_embedding/Qwen3-8B_noise_embeds.npy create mode 100755 examples/performance/dflash/noise_embedding/gpt-oss-20b_noise_embeds.npy create mode 100644 examples/performance/dflash/results-Qwen3-4B/humaneval_per_sample.csv create mode 100644 examples/performance/dflash/results-Qwen3-4B/summary.csv create mode 100644 examples/performance/dflash/utils.py diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 811b3f84d5..22f866b8ce 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +from dataclasses import dataclass from typing import List, Optional, Tuple, Type, Union import torch @@ -36,6 +37,11 @@ from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + output_embeds: Optional[torch.FloatTensor] = None + + class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -289,10 +295,14 @@ def forward( sin = self.sin_cached[position_ids].unsqueeze(1) cos = self.cos_cached[position_ids].unsqueeze(1) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + self.target_layer_ids = getattr(self, "target_layer_ids", None) + target_hidden_list = [] + + for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): if output_hidden_states: all_hidden_states += (hidden_states,) - + if self.target_layer_ids and idx in self.target_layer_ids: + target_hidden_list.append(hidden_states) hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, @@ -316,6 +326,16 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() + if self.target_layer_ids: + target_hidden = torch.cat(target_hidden_list, dim=-1) + target_hidden_fc = self.fc(target_hidden) + target_hidden_final = self.hidden_norm(target_hidden_fc) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=target_hidden_final, + ) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, @@ -354,6 +374,10 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + + if getattr(self.model, "target_layer_ids", None): + output_hidden_states = False + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -371,6 +395,22 @@ def forward( # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + + if getattr(self.model, "target_layer_ids", None): + target_hidden = outputs.hidden_states + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + predicted_token_ids = logits.argmax(dim=-1).to(torch.int32) + output_embed = self.model.embed_tokens(predicted_token_ids) + return QEffCausalLMOutputWithPast( + loss=None, + logits=predicted_token_ids, + past_key_values=outputs.past_key_values, + hidden_states=target_hidden, + attentions=outputs.attentions, + output_embeds=output_embed, + ) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] logits = self.lm_head(hidden_states).float() diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2668be8a1e..0263e0b91f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -47,6 +47,8 @@ ) from QEfficient.transformers.models.pytorch_transforms import ( CustomOpsTransform, + DFlashTLMTransform, + DFlashTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, @@ -2878,6 +2880,20 @@ def __init__( if self.is_tlm: self.model.qaic_config["return_pdfs"] = True + self.dflash_dlm = False + self.hidden_size = self.model.config.hidden_size + self.vocab_size = self.model.config.vocab_size + if qaic_config is not None: + self.dflash_dlm = qaic_config.get("dflash_dlm", False) + if self.dflash_dlm: + self.model, _ = DFlashTransform.apply(self.model, qaic_config) + + self.dflash_tlm = False + if qaic_config is not None: + self.dflash_tlm = bool(qaic_config.get("target_layer_ids", None)) + if self.dflash_tlm: + self.model, _ = DFlashTLMTransform.apply(self.model, qaic_config) + def __repr__(self) -> str: return self.__class__.__name__ + "\n" + self.model.__repr__() @@ -3079,7 +3095,7 @@ def export( fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( - self.model.config, fbs if self.continuous_batching else bs, seq_len + self.model.config, fbs if self.continuous_batching else bs, seq_len * 2 ) enable_chunking = kwargs.get("enable_chunking", False) if ( @@ -3133,6 +3149,22 @@ def export( "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, } + + if self.dflash_dlm: + example_inputs = { + "noise_embeds": torch.ones((bs, seq_len, self.hidden_size), dtype=torch.float), + "target_hidden": torch.ones((bs, seq_len, self.hidden_size), dtype=torch.float), + "position_ids": torch.arange(seq_len, 2 * seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), + "position_ids_target": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), + "past_key_values": [[] for _ in range(self.num_layers)], + } + dynamic_axes = { + "noise_embeds": {0: "batch_size", 1: "seq_len"}, + "target_hidden": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids_target": {0: "batch_size", 1: "seq_len"}, + } + if self.ccl_enabled: example_inputs["comp_ctx_lengths"] = torch.randint(0, 127, (512,), dtype=torch.int8) dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} @@ -3250,6 +3282,10 @@ def export( qaic_config=self.model.qaic_config, ) + if self.dflash_tlm: + output_names.append("hidden_states") + output_names.append("output_embeds") + return self._export( example_inputs, output_names=output_names, @@ -3334,6 +3370,7 @@ def build_decode_specialization( kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, num_speculative_tokens: Optional[int] = None, + dflash_block_size: Optional[int] = None, **kwargs, ): """ @@ -3377,6 +3414,9 @@ def build_decode_specialization( spec["num_logits_to_keep"] = (num_speculative_tokens + 1) if self.is_tlm else None + if self.dflash_tlm: + spec["seq_len"] = dflash_block_size + if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -3397,6 +3437,7 @@ def compile( batch_size: int = 1, full_batch_size: Optional[int] = None, kv_cache_batch_size: Optional[int] = None, + dflash_block_size: Optional[int] = None, num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, @@ -3612,6 +3653,7 @@ def compile( kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, + dflash_block_size=dflash_block_size, ) if decode_spec: specializations.append(decode_spec) @@ -3624,6 +3666,7 @@ def compile( kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, + dflash_block_size=dflash_block_size, prefill_only=prefill_only, ) if decode_spec: diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index ec34ebb046..6bb293f6a0 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -458,6 +458,18 @@ QEffQwen3ForCausalLM, QEffQwen3Model, ) +from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import ( + QEffQwen3Attention as QEffQwen3DFlashAttention, +) +from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import ( + QEffQwen3DecoderLayer as QEffQwen3DFlashDecoderLayer, +) +from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import ( + QEffQwen3ForCausalLM as QEffQwen3DFlashForCausalLM, +) +from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import ( + QEffQwen3Model as QEffQwen3DFlashModel, +) from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( QEffPrefillChunkedQwen3MoeSparseMoeBlock, QEffQwen3MoeAttention, @@ -954,6 +966,138 @@ def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) - return model, transformed +class DFlashTransform(ModuleMappingTransform): + """ + Replaces standard QEff Qwen3 modules with DFlash-specialized versions when dflash_dlm=True. + Applied after KVCacheTransform so it operates on already-transformed QEff modules. + """ + + _module_mapping = { + QEffQwen3Attention: QEffQwen3DFlashAttention, + QEffQwen3DecoderLayer: QEffQwen3DFlashDecoderLayer, + QEffQwen3Model: QEffQwen3DFlashModel, + QEffQwen3ForCausalLM: QEffQwen3DFlashForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: + if not (qaic_config and qaic_config.get("dflash_dlm", False)): + return model, False + return super().apply(model) + + +class DFlashTLMTransform: + """ + Adds fc and hidden_norm layers to the inner model, loads their weights from the TLM + checkpoint, and sets target_layer_ids — activating the TLM hidden-state collection + path in QEffQwen3Model / QEffLlamaModel forward. + + Triggered when target_layer_ids is present in qaic_config. + Works without modifying the transformers library: weights that were silently dropped + as unexpected keys during from_pretrained are re-read here from the checkpoint. + """ + + @classmethod + def _load_tlm_weights(cls, checkpoint_path: str) -> dict: + """Read only fc and hidden_norm weights from a local checkpoint directory.""" + from pathlib import Path + + import torch + + path = Path(checkpoint_path) + target_keys = {"model.fc.weight", "model.hidden_norm.weight"} + dlm_model_weights: dict = {} + + # safetensors (preferred modern format) + sf_files = sorted(path.glob("*.safetensors")) + if sf_files: + try: + from safetensors import safe_open + + for sf in sf_files: + with safe_open(str(sf), framework="pt", device="cpu") as f: + for key in f.keys(): + if key in target_keys and key not in dlm_model_weights: + dlm_model_weights[key] = f.get_tensor(key) + if len(dlm_model_weights) == len(target_keys): + break + return dlm_model_weights + except ImportError: + warnings.warn("safetensors not installed; falling back to .bin loading.") + + # pytorch .bin fallback + bin_files = sorted(path.glob("pytorch_model*.bin")) + for bf in bin_files: + sd = torch.load(str(bf), map_location="cpu", weights_only=True) + for key in target_keys: + if key in sd and key not in dlm_model_weights: + dlm_model_weights[key] = sd[key] + if len(dlm_model_weights) == len(target_keys): + break + + return dlm_model_weights + + @classmethod + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: + target_layer_ids = qaic_config.get("target_layer_ids", None) if qaic_config else None + if not target_layer_ids: + return model, False + + n = len(target_layer_ids) + inner = model.model # QEffQwen3Model or QEffLlamaModel + hidden_size = model.config.hidden_size + model_type = getattr(model.config, "model_type", "") + + # --- add fc and hidden_norm only if not already present --- + # tlm_maker.py (and similar scripts) inject weights before calling the + # constructor directly; in that case we must not overwrite them. + layers_already_present = hasattr(inner, "fc") and hasattr(inner, "hidden_norm") + + if not layers_already_present: + # add fc + inner.fc = nn.Linear(n * hidden_size, hidden_size, bias=False) + + # add hidden_norm using the same RMSNorm class as the model + if "qwen3" in model_type: + from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm + + inner.hidden_norm = Qwen3RMSNorm(hidden_size, eps=model.config.rms_norm_eps) + elif "llama" in model_type: + from transformers.models.llama.modeling_llama import LlamaRMSNorm + + inner.hidden_norm = LlamaRMSNorm(hidden_size, eps=model.config.rms_norm_eps) + else: + warnings.warn( + f"DFlashTLMTransform: unknown model_type '{model_type}'. Using nn.RMSNorm as hidden_norm fallback." + ) + inner.hidden_norm = nn.RMSNorm(hidden_size, eps=getattr(model.config, "rms_norm_eps", 1e-6)) + + # load fc / hidden_norm weights from checkpoint + # pretrained_model_name_or_path is stored in qaic_config by from_pretrained + ckpt_path = (qaic_config or {}).get("pretrained_model_name_or_path", None) + if ckpt_path: + weights = cls._load_tlm_weights(ckpt_path) + if "model.fc.weight" in weights: + inner.fc.weight.data.copy_(weights["model.fc.weight"]) + if "model.hidden_norm.weight" in weights: + inner.hidden_norm.weight.data.copy_(weights["model.hidden_norm.weight"]) + if not weights: + warnings.warn( + "DFlashTLMTransform: fc/hidden_norm weights not found in checkpoint " + f"at '{ckpt_path}'. Layers are randomly initialized." + ) + else: + warnings.warn( + "DFlashTLMTransform: no checkpoint path available — " + "use QEFFAutoModelForCausalLM.from_pretrained() so the path is " + "stored automatically, or set qaic_config['pretrained_model_name_or_path']." + ) + + # --- activate TLM collection path in forward --- + inner.target_layer_ids = target_layer_ids + return model, True + + class VlmKVOffloadTransform(ModuleMappingTransform): # supported architectures _module_mapping = { diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index 9844c91016..a13cfa9a3a 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -7,6 +7,7 @@ """PyTorch Qwen3 model.""" +from dataclasses import dataclass from typing import List, Optional, Tuple, Type, Union import torch @@ -39,6 +40,11 @@ from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + output_embeds: Optional[torch.FloatTensor] = None + + # Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology class QEffQwen3RotaryEmbedding(Qwen3RotaryEmbedding): """ @@ -324,10 +330,16 @@ def forward( sin = self.sin_cached[position_ids].unsqueeze(1) cos = self.cos_cached[position_ids].unsqueeze(1) - for decoder_layer in self.layers: + self.target_layer_ids = getattr(self, "target_layer_ids", None) + target_hidden_list = [] + + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) + if self.target_layer_ids and idx in self.target_layer_ids: + target_hidden_list.append(hidden_states) + hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, @@ -350,6 +362,16 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() + if self.target_layer_ids: + target_hidden = torch.cat(target_hidden_list, dim=-1) + target_hidden_fc = self.fc(target_hidden) + target_hidden_final = self.hidden_norm(target_hidden_fc) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=target_hidden_final, + ) + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, @@ -393,6 +415,9 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + if getattr(self.model, "target_layer_ids", None): + output_hidden_states = False + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -408,6 +433,22 @@ def forward( # Cast to INT32 to avoid issue while running in ONNXRT logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + + if getattr(self.model, "target_layer_ids", None): + target_hidden = outputs.hidden_states + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + predicted_token_ids = logits.argmax(dim=-1).to(torch.int32) + output_embed = self.model.embed_tokens(predicted_token_ids) + return QEffCausalLMOutputWithPast( + loss=None, + logits=predicted_token_ids, + past_key_values=outputs.past_key_values, + hidden_states=target_hidden, + attentions=outputs.attentions, + output_embeds=output_embed, + ) + hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] logits = self.lm_head(hidden_states).float() diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py b/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py new file mode 100644 index 0000000000..0f6bc7bb64 --- /dev/null +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py @@ -0,0 +1,544 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""PyTorch Qwen3 model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3Attention, + Qwen3Config, + Qwen3DecoderLayer, + Qwen3ForCausalLM, + Qwen3Model, + Qwen3RotaryEmbedding, + repeat_kv, + rotate_half, +) + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE + + +def _create_mask( + position_ids: torch.Tensor, + target_length: int, + # valid_kv_length: int, + sliding_window: Optional[int] = None, + start_index: Optional[int] = 0, +): + """ + Args: + position_ids: [1, target_length] + target_length: 3 * block_size + valid_kv_length: number of valid KV cache positions + sliding_window: optional local attention window + start_index: offset into KV cache (default 0) + + Returns: + attention_mask: [1, 1, num_queries, target_length] + """ + + device = position_ids.device + + num_queries = position_ids.shape[1] # = block_size (B) + + # ---- Step 1: Create base KV validity mask ---- + # Shape: [target_length] + + kv_positions = torch.arange(start_index, start_index + target_length, device=device) + valid_kv_mask = kv_positions > ( + start_index + position_ids.max() + ) # [position_max is 32( input query podtion_ids), 0 to 31] + + # ---- Step 2: Expand to [num_queries, target_length] ---- + attention_mask = valid_kv_mask.unsqueeze(0).expand(num_queries, target_length) + + # ---- Step 4: Add batch & head dimensions ---- + # Final shape: [1, 1, B, 3B] + attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) + + return attention_mask + + +# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology +class QEffQwen3RotaryEmbedding(Qwen3RotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: Qwen3Config, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + # print_stats(emb, "RotaryEmbedding/emb") + cos_cached = emb.cos().to(dtype) + sin_cached = emb.sin().to(dtype) + + self.register_buffer("cos_cached", cos_cached, persistent=False) + self.register_buffer("sin_cached", sin_cached, persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + cos_out = self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling + sin_out = self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling + + return (cos_out, sin_out) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_len = q.size(-2) + + q_embed = (q * cos[..., :q_len, :]) + (rotate_half(q) * sin[..., :q_len, :]) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def qeff_apply_rope_two_streams(q_noise, k_ctx, k_noise, cos, sin, pos_ctx, pos_noise, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + cos_t = cos[pos_ctx].unsqueeze(unsqueeze_dim) + sin_t = sin[pos_ctx].unsqueeze(unsqueeze_dim) + + rotate_half_k_ctx = rotate_half(k_ctx) + + k_t_embed = k_ctx * cos_t + rotate_half_k_ctx * sin_t + + # ---- NOISE ---- + cos_n = cos[pos_noise].unsqueeze(unsqueeze_dim) + sin_n = sin[pos_noise].unsqueeze(unsqueeze_dim) + + rotate_half_q_noise = rotate_half(q_noise) + q_n_embed = q_noise * cos_n + rotate_half_q_noise * sin_n + + rotate_half_k_noise = rotate_half(k_noise) + + k_n_embed = k_noise * cos_n + rotate_half_k_noise * sin_n + + return q_n_embed, k_t_embed, k_n_embed + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class QEffQwen3Attention(Qwen3Attention): + """ + Copied from Qwen3Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + """ + + def __qeff_init__(self): + self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config) + self.dflash_dlm = True + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor, + position_ids_target: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + bsz, q_len = hidden_states.shape[:-1] + ctx_len = target_hidden.shape[1] + + kwargs.pop("output_attentions", None) + kwargs.pop("return_dict", None) + kwargs.pop("labels", None) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + + k_ctx = self.k_norm(k_ctx.view(bsz, ctx_len, -1, self.head_dim)).transpose(1, 2) + k_noise = self.k_norm(k_noise.view(bsz, q_len, -1, self.head_dim)).transpose(1, 2) + + v_ctx = (v_ctx.view(bsz, ctx_len, -1, self.head_dim)).transpose(1, 2) + v_noise = (v_noise.view(bsz, q_len, -1, self.head_dim)).transpose(1, 2) + + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # Assuming position_id [77,78,79,80, 75,76,-1,-1] first 4 pos id of noise next four position_id for target + + cos, sin = self.rotary_emb(v_ctx, seq_len=kv_seq_len) + query_states, k_ctx, k_noise = qeff_apply_rope_two_streams( + query_states, k_ctx, k_noise, cos, sin, position_ids_target, position_ids + ) + + if past_key_value is not None: + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids_target} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + + # first write for target positon_id + past_key_value.write_only(k_ctx, v_ctx, self.layer_idx, cache_kwargs) + + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(k_noise, v_noise, self.layer_idx, cache_kwargs) + + attention_interface = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class QEffQwen3DecoderLayer(Qwen3DecoderLayer): + """ + Copied from Qwen3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update the hidden_states, and fix for onnx model + """ + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor = None, + position_ids_target: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + target_hidden=target_hidden, + position_ids_target=position_ids_target, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class QEffQwen3Model(Qwen3Model): + """ + Copied from Qwen3Model: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update causal attention mask + """ + + def forward( + self, + target_hidden: torch.Tensor = None, + noise_embeds: torch.FloatTensor = None, + position_ids_target: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (noise_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: ####? + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + noise_embeds.shape[1], device=noise_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze( + 0 + ) ###? no need for this because we input it ( where the tokens will be filled ) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) + + hidden_states = noise_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = decoder_layer( + hidden_states, + target_hidden=target_hidden, + position_ids_target=position_ids_target, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class QEffQwen3ForCausalLM(Qwen3ForCausalLM): + """ + Copied from Qwen3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update the hidden_states, and fix for onnx model + """ + + def forward( + self, + target_hidden: torch.Tensor = None, + noise_embeds: torch.FloatTensor = None, + position_ids_target: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # block_size: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + target_hidden=target_hidden, + noise_embeds=noise_embeds, + position_ids_target=position_ids_target, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + # block_size=block_size, + output_hidden_states=output_hidden_states, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states).float() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 339e4f4dac..58cef2caa9 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -14,7 +14,7 @@ QEFF_CACHE_DIR_NAME = "qeff_cache" ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 -ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 +ONNX_EXPORT_EXAMPLE_SEQ_LEN = 16 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep ONNX_EXPORT_MAX_NUM_IMAGES = 1 diff --git a/examples/performance/dflash/README.md b/examples/performance/dflash/README.md new file mode 100644 index 0000000000..8b0181d307 --- /dev/null +++ b/examples/performance/dflash/README.md @@ -0,0 +1,88 @@ +# DFlash SPD Examples + +Two entry points wrap the SPD compile + run pipeline. + +#### basic_inference.py +Basic DFlash usage with dense language models. + +**Supported Models:** +- Llama3.1-8B-Instruct +- Qwen3-4B +- Qwen3-8B + +## Single prompt + +```bash +python basic_inference.py --model_name Qwen3-4B \ + --prompt "Explain speculative decoding in two sentences." +``` + +## Benchmark (dataset) + +```bash +python benchmark.py --model_name Qwen3-4B --dataset humaneval +``` + +## `--model_name` + +Accepts either the short key or the full HF repo path (case-insensitive): + +``` +Qwen3-4B +Qwen/Qwen3-4B +qwen3-4b +``` + +Run either script with `--help` to see the full supported list. + +## Skipping compile (reuse QPCs) + +Pass either or both. Whichever side is supplied skips its compile step; the +other side still compiles. + +```bash +python basic_inference.py --model_name Qwen3-4B \ + --tlm_qpc /path/to/tlm/qpc \ + --dlm_qpc /path/to/dlm/qpc \ + --prompt "Hello" +``` + +## Common flags + +| Flag | Default | Notes | +|---|---|---| +| `--tlm_devices` | `0 1 2 3` | TLM device IDs | +| `--dlm_devices` | `0 1 2 3` | DLM device IDs | +| `--tlm_cores` / `--dlm_cores` | `8` | per-side core count | +| `--ctx_len` | `4096` | | +| `--prefill_seq_len` | `128` | | +| `--generation_len` | `1024` (benchmark) / `256` (single) | | +| `--noise_embed_path` | `noise_embedding/_noise_embeds.npy` | override if needed | +| `--hf_token` | `$HF_TOKEN` | required for gated repos | +| `--tlm_hf_path` | from `MODEL_MAP` | required when the map entry has `None` | + +`benchmark.py` only: + +| Flag | Default | +|---|---| +| `--dataset` | `humaneval` (also: `gsm8k`, `math500`) | +| `--num_samples` | `0` (= all) | +| `--iteration` | `300` | +| `--output_dir` | `./results-` | + +`basic_inference.py` only: + +| Flag | Default | +|---|---| +| `--prompt` | *(required)* | +| `--category` | `""` (math / coding / reasoning / …) | + +## Adding a new model + +Edit `MODEL_MAP` in `benchmark.py`: + +```python +"": ("", ""), +``` + +`basic_inference.py` reuses the same map automatically. diff --git a/examples/performance/dflash/basic_inference.py b/examples/performance/dflash/basic_inference.py new file mode 100644 index 0000000000..ff69165492 --- /dev/null +++ b/examples/performance/dflash/basic_inference.py @@ -0,0 +1,224 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Single-entry SPD single-prompt runner. + +Given a TLM model_name (short name OR full HF repo path) and a prompt, this +script: + 1. Looks up the matching DFlash DLM repo on Hugging Face. + 2. Reads hidden_size and block_size from the DLM config. + 3. Compiles TLM + DLM QPCs (only the side(s) not provided via + --tlm_qpc / --dlm_qpc). + 4. Runs the SPD single-prompt inference script. + +Examples: + # Compile + run with all defaults + python basic_inference.py --model_name Qwen3-4B \ + --prompt "Explain speculative decoding in two sentences." + + # Full HF path also accepted + python basic_inference.py --model_name Qwen/Qwen3-4B \ + --prompt "Hello" + + # Reuse pre-compiled QPCs + python basic_inference.py --model_name Qwen3-4B \ + --tlm_qpc /path/to/tlm/qpc --dlm_qpc /path/to/dlm/qpc \ + --prompt "What is 17 * 23?" +""" + +import argparse +import os +import subprocess +import sys + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.abspath(os.path.join(THIS_DIR, "..", "..")) +sys.path.insert(0, REPO_ROOT) +sys.path.insert(0, THIS_DIR) + +from benchmark import MODEL_MAP, resolve_model_name # noqa: E402 # reuse the alias table + + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── +def parse_args(): + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument( + "--model_name", + required=True, + type=resolve_model_name, + help="TLM name — either the short key (e.g. 'Qwen3-4B') or " + "the full HF repo path (e.g. 'Qwen/Qwen3-4B'). " + f"Supported: {', '.join(MODEL_MAP.keys())}", + ) + p.add_argument("--prompt", required=True, help="Input prompt text.") + p.add_argument("--category", default="", help="Prompt category for formatting (math, coding, reasoning, …).") + p.add_argument("--tlm_hf_path", default=None, help="Override TLM HF repo (required if mapping has None).") + + # Optional pre-built QPCs (skip compilation) + p.add_argument("--tlm_qpc", default=None, help="Pre-compiled TLM qpc dir (skip TLM compile).") + p.add_argument("--dlm_qpc", default=None, help="Pre-compiled DLM qpc dir (skip DLM compile).") + + # Devices / cores + p.add_argument("--tlm_devices", nargs="+", type=int, default=[0, 1, 2, 3]) + p.add_argument("--dlm_devices", nargs="+", type=int, default=[0, 1, 2, 3]) + p.add_argument("--tlm_cores", type=int, default=8) + p.add_argument("--dlm_cores", type=int, default=8) + + # Compile / run knobs + p.add_argument("--ctx_len", type=int, default=4096) + p.add_argument("--prefill_seq_len", type=int, default=128) + p.add_argument("--generation_len", type=int, default=256) + p.add_argument("--iteration", type=int, default=300) + + p.add_argument("--noise_embed_path", default=None, help="Defaults to noise_embedding/_noise_embeds.npy") + p.add_argument("--hf_token", default=os.environ.get("HF_TOKEN")) + + # Internal modes used by self-spawned compile subprocesses + p.add_argument("--_build", choices=["tlm", "dlm"], default=None, help=argparse.SUPPRESS) + return p.parse_args() + + +# ───────────────────────────────────────────────────────────────────────────── +# Main +# ───────────────────────────────────────────────────────────────────────────── +def main(): + args = parse_args() + + tlm_repo_default, dlm_repo = MODEL_MAP[args.model_name] + tlm_repo = args.tlm_hf_path or tlm_repo_default + if tlm_repo is None: + raise SystemExit(f"No default TLM HF path for '{args.model_name}'. Pass --tlm_hf_path.") + + # Sub-mode: spawned compile subprocess. Reuse benchmark.py's builders so we + # don't duplicate the compile pipeline. + if args._build is not None: + from benchmark import _build_dlm, _build_tlm + + if args._build == "tlm": + _build_tlm(args, tlm_repo, dlm_repo) + else: + _build_dlm(args, tlm_repo, dlm_repo) + return + + # ── Resolve / discover hidden_size + block_size from DLM config ──────── + import transformers + + config = transformers.AutoConfig.from_pretrained(dlm_repo, token=args.hf_token, trust_remote_code=True) + hidden_size = config.hidden_size + block_size = getattr(config, "block_size", None) + print(f"DLM repo : {dlm_repo}") + print(f"hidden_size : {hidden_size}") + print(f"block_size : {block_size}") + + # ── Resolve QPC paths (compile only the side that wasn't pre-supplied) ─ + forwarded = [ + "--model_name", + args.model_name, + "--prompt", + args.prompt, + "--ctx_len", + str(args.ctx_len), + "--prefill_seq_len", + str(args.prefill_seq_len), + "--tlm_cores", + str(args.tlm_cores), + "--dlm_cores", + str(args.dlm_cores), + "--tlm_devices", + *[str(d) for d in args.tlm_devices], + "--dlm_devices", + *[str(d) for d in args.dlm_devices], + ] + if args.tlm_hf_path: + forwarded += ["--tlm_hf_path", args.tlm_hf_path] + if args.hf_token: + forwarded += ["--hf_token", args.hf_token] + + if args.tlm_qpc: + print(f"[skip compile] using provided TLM qpc: {args.tlm_qpc}") + tlm_qpc = args.tlm_qpc + else: + tlm_qpc = _spawn_compile("tlm", forwarded) + + if args.dlm_qpc: + print(f"[skip compile] using provided DLM qpc: {args.dlm_qpc}") + dlm_qpc = args.dlm_qpc + else: + dlm_qpc = _spawn_compile("dlm", forwarded) + print(f"TLM qpc : {tlm_qpc}") + print(f"DLM qpc : {dlm_qpc}") + + # ── Resolve noise embed path ─────────────────────────────────────────── + noise_embed = args.noise_embed_path or os.path.join( + THIS_DIR, "noise_embedding", f"{args.model_name}_noise_embeds.npy" + ) + if not os.path.exists(noise_embed): + raise SystemExit(f"noise embedding not found: {noise_embed}\nPass --noise_embed_path explicitly.") + + # ── Run the existing single-prompt inference script ──────────────────── + eval_script = os.path.join(THIS_DIR, "dflash_spd_single_prompt.py") + cmd = [ + sys.executable, + eval_script, + "--prompt", + args.prompt, + "--tlm_qpc", + tlm_qpc, + "--dlm_qpc", + dlm_qpc, + "--tlm_model_name", + tlm_repo, + "--dlm_model_name", + dlm_repo, + "--noise_embed_path", + noise_embed, + "--iteration", + str(args.iteration), + "--ctx_len", + str(args.ctx_len), + "--generation_len", + str(args.generation_len), + "--tlm_devices", + *[str(d) for d in args.tlm_devices], + "--dlm_devices", + *[str(d) for d in args.dlm_devices], + ] + if args.hf_token: + cmd += ["--hf_token", args.hf_token] + if args.category: + cmd += ["--category", args.category] + + print("\n>>> launching SPD single-prompt inference:") + print(" ".join(cmd)) + rc = subprocess.run(cmd, check=False).returncode + if rc != 0: + raise SystemExit(f"single-prompt inference exited with rc={rc}") + + +def _spawn_compile(mode, argv_template): + """Run this same script with --_build {mode} in a fresh process and return + the qpc path printed on the line starting with TLM_QPC= or DLM_QPC=.""" + cmd = [sys.executable, os.path.abspath(__file__), "--_build", mode] + argv_template + print(f"\n>>> spawning compile subprocess: {' '.join(cmd)}") + proc = subprocess.run(cmd, check=False, capture_output=True, text=True) + sys.stdout.write(proc.stdout) + sys.stderr.write(proc.stderr) + if proc.returncode != 0: + raise SystemExit(f"compile subprocess (--_build {mode}) failed (rc={proc.returncode})") + + tag = "TLM_QPC=" if mode == "tlm" else "DLM_QPC=" + qpc_line = next((ln for ln in reversed(proc.stdout.splitlines()) if ln.startswith(tag)), None) + if qpc_line is None: + raise SystemExit(f"could not find {tag} line in compile output") + return qpc_line.split("=", 1)[1].strip() + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/benchmark.py b/examples/performance/dflash/benchmark.py new file mode 100644 index 0000000000..c28b942ee1 --- /dev/null +++ b/examples/performance/dflash/benchmark.py @@ -0,0 +1,347 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Single-entry SPD runner. + +Given just a TLM model_name (short name from the supported table), this script: + 1. Looks up the matching DFlash DLM repo on Hugging Face. + 2. Reads hidden_size and block_size from the DLM config. + 3. Compiles TLM + DLM QPCs (if --tlm_qpc / --dlm_qpc are not supplied). + 4. Runs the SPD benchmark on the chosen dataset (default: humaneval). + +Examples: + # Compile + run with all defaults + python run_spd.py --model_name Qwen3-4B + + # Reuse pre-compiled QPCs (no compilation step) + python run_spd.py --model_name Qwen3-4B \ + --tlm_qpc /path/to/tlm/qpc --dlm_qpc /path/to/dlm/qpc + + # Custom devices / cores / dataset + python run_spd.py --model_name Llama-3.1-8B-Instruct \ + --tlm_devices 0 1 2 3 --dlm_devices 4 5 6 7 \ + --tlm_cores 8 --dlm_cores 8 --dataset gsm8k +""" + +import argparse +import os +import subprocess +import sys + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.abspath(os.path.join(THIS_DIR, "..", "..")) +sys.path.insert(0, REPO_ROOT) +sys.path.insert(0, THIS_DIR) + + +# ───────────────────────────────────────────────────────────────────────────── +# model_name (TLM short) → (TLM HF repo, DLM HF repo) +# DLM column comes verbatim from the user's supported list. +# TLM column is the standard HF repo when known; otherwise None and must be +# supplied via --tlm_hf_path on the command line. +# ───────────────────────────────────────────────────────────────────────────── +MODEL_MAP = { + "gemma-4-31B-it": (None, "z-lab/gemma-4-31B-it-DFlash"), + "gemma-4-26B-A4B-it": (None, "z-lab/gemma-4-26B-A4B-it-DFlash"), + "MiniMax-M2.7": (None, "z-lab/MiniMax-M2.7-DFlash"), + "MiniMax-M2.5": (None, "z-lab/MiniMax-M2.5-DFlash"), + "Kimi-K2.6": (None, "z-lab/Kimi-K2.6-DFlash"), + "Kimi-K2.5": (None, "z-lab/Kimi-K2.5-DFlash"), + "Qwen3.6-27B": (None, "z-lab/Qwen3.6-27B-DFlash"), + "Qwen3.6-35B-A3B": (None, "z-lab/Qwen3.6-35B-A3B-DFlash"), + "Qwen3.5-4B": (None, "z-lab/Qwen3.5-4B-DFlash"), + "Qwen3.5-9B": (None, "z-lab/Qwen3.5-9B-DFlash"), + "Qwen3.5-27B": (None, "z-lab/Qwen3.5-27B-DFlash"), + "Qwen3.5-35B-A3B": (None, "z-lab/Qwen3.5-35B-A3B-DFlash"), + "Qwen3.5-122B-A10B": (None, "z-lab/Qwen3.5-122B-A10B-DFlash"), + "gpt-oss-20b": ("openai/gpt-oss-20b", "z-lab/gpt-oss-20b-DFlash"), + "gpt-oss-120b": ("openai/gpt-oss-120b", "z-lab/gpt-oss-120b-DFlash"), + "Qwen3-Coder-Next": (None, "z-lab/Qwen3-Coder-Next-DFlash"), + "Qwen3-4B": ("Qwen/Qwen3-4B", "z-lab/Qwen3-4B-DFlash-b16"), + "Qwen3-8B": ("Qwen/Qwen3-8B", "z-lab/Qwen3-8B-DFlash-b16"), + "Qwen3-Coder-30B-A3B": ("Qwen/Qwen3-Coder-30B-A3B-Instruct", "z-lab/Qwen3-Coder-30B-A3B-DFlash"), + "Llama-3.1-8B-Instruct": ("meta-llama/Llama-3.1-8B-Instruct", "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat"), +} + + +# Build alias table: full HF repo path (e.g. "Qwen/Qwen3-4B") and basename +# (case-insensitive) → canonical short name. Lets users pass either form. +def _build_aliases(model_map): + aliases = {} + for short, (tlm_repo, _) in model_map.items(): + aliases[short.lower()] = short + if tlm_repo: + aliases[tlm_repo.lower()] = short + aliases[tlm_repo.split("/", 1)[-1].lower()] = short + return aliases + + +MODEL_ALIASES = _build_aliases(MODEL_MAP) + + +def resolve_model_name(name): + """Map a user-supplied model name (short, full HF path, or basename) to + the canonical short name used as a key in MODEL_MAP.""" + canonical = MODEL_ALIASES.get(name.lower()) + if canonical is None: + raise argparse.ArgumentTypeError( + f"unknown model_name '{name}'. Supported: " + ", ".join(sorted(MODEL_MAP.keys())) + ) + return canonical + + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── +def parse_args(): + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument( + "--model_name", + required=True, + type=resolve_model_name, + help="TLM name — either the short key (e.g. 'Qwen3-4B') or " + "the full HF repo path (e.g. 'Qwen/Qwen3-4B'). " + f"Supported: {', '.join(MODEL_MAP.keys())}", + ) + p.add_argument("--tlm_hf_path", default=None, help="Override TLM HF repo (required if mapping has None).") + + # Optional pre-built QPCs (skip compilation) + p.add_argument("--tlm_qpc", default=None, help="Pre-compiled TLM qpc dir (skip TLM compile).") + p.add_argument("--dlm_qpc", default=None, help="Pre-compiled DLM qpc dir (skip DLM compile).") + + # Devices / cores + p.add_argument("--tlm_devices", nargs="+", type=int, default=[0, 1, 2, 3]) + p.add_argument("--dlm_devices", nargs="+", type=int, default=[0, 1, 2, 3]) + p.add_argument("--tlm_cores", type=int, default=8) + p.add_argument("--dlm_cores", type=int, default=8) + + # Compile / run knobs + p.add_argument("--ctx_len", type=int, default=4096) + p.add_argument("--prefill_seq_len", type=int, default=128) + p.add_argument("--generation_len", type=int, default=1024) + p.add_argument("--iteration", type=int, default=300) + + # Dataset / output + p.add_argument("--dataset", default="humaneval", choices=["humaneval", "gsm8k", "math500"]) + p.add_argument("--num_samples", type=int, default=0, help="0 = all samples") + p.add_argument("--output_dir", default=None, help="Default: ./results-") + p.add_argument("--noise_embed_path", default=None, help="Defaults to noise_embedding/_noise_embeds.npy") + p.add_argument("--hf_token", default=os.environ.get("HF_TOKEN")) + + # Internal modes used by self-spawned compile subprocesses + p.add_argument("--_build", choices=["tlm", "dlm"], default=None, help=argparse.SUPPRESS) + return p.parse_args() + + +# ───────────────────────────────────────────────────────────────────────────── +# Compilation helpers — mirror make_models.py but parameterised +# ───────────────────────────────────────────────────────────────────────────── +def _read_dlm_meta(dlm_repo, hf_token): + from utils import load_dflash_checkpoint + + state_dict, cfg = load_dflash_checkpoint(dlm_repo) + target_layer_ids = cfg.get("dflash_config", {}).get("target_layer_ids", []) + block_size = cfg.get("block_size", None) + return state_dict, target_layer_ids, block_size + + +def _build_tlm(args, tlm_repo, dlm_repo): + import torch + from transformers import AutoModelForCausalLM + from utils import build_tlm_model + + from QEfficient import QEFFAutoModelForCausalLM + + state_dict, target_layer_ids, block_size = _read_dlm_meta(dlm_repo, args.hf_token) + tlm_target_ids = [i + 1 for i in target_layer_ids] + + print(f"[build_tlm] base={tlm_repo} dlm={dlm_repo} block_size={block_size}") + base_model = AutoModelForCausalLM.from_pretrained(tlm_repo, torch_dtype=torch.float32, token=args.hf_token) + build_tlm_model(base_model, state_dict, tlm_target_ids) + + tlm_qeff = QEFFAutoModelForCausalLM(base_model, qaic_config={"target_layer_ids": tlm_target_ids}) + qpc = tlm_qeff.compile( + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.tlm_cores, + num_devices=len(args.tlm_devices), + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, + dflash_block_size=block_size, + ) + print(f"TLM_QPC={qpc}") + return qpc + + +def _build_dlm(args, tlm_repo, dlm_repo): + import torch + from transformers import AutoModelForCausalLM + from utils import build_dlm_model, extract_lm_head + + from QEfficient import QEFFAutoModelForCausalLM + + _, _, block_size = _read_dlm_meta(dlm_repo, args.hf_token) + + print(f"[build_dlm] dlm={dlm_repo} block_size={block_size}") + base_model = AutoModelForCausalLM.from_pretrained(tlm_repo, torch_dtype=torch.float32, token=args.hf_token) + lm_head_w, lm_head_b = extract_lm_head(base_model) + del base_model + + dlm_model = build_dlm_model(dlm_repo, lm_head_w, lm_head_b) + dlm_qeff = QEFFAutoModelForCausalLM(dlm_model, qaic_config={"dflash_dlm": True}) + qpc = dlm_qeff.compile( + prefill_seq_len=block_size, + ctx_len=args.ctx_len, + num_cores=args.dlm_cores, + num_devices=len(args.dlm_devices), + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, + prefill_only=True, + ) + print(f"DLM_QPC={qpc}") + return qpc + + +def _spawn_compile(mode, argv_template): + """Run this same script with --_build {mode} in a fresh process and return + the qpc path printed on the line starting with TLM_QPC= or DLM_QPC=.""" + cmd = [sys.executable, os.path.abspath(__file__), "--_build", mode] + argv_template + print(f"\n>>> spawning compile subprocess: {' '.join(cmd)}") + proc = subprocess.run(cmd, check=False, capture_output=True, text=True) + sys.stdout.write(proc.stdout) + sys.stderr.write(proc.stderr) + if proc.returncode != 0: + raise SystemExit(f"compile subprocess (--_build {mode}) failed (rc={proc.returncode})") + + tag = "TLM_QPC=" if mode == "tlm" else "DLM_QPC=" + qpc_line = next((ln for ln in reversed(proc.stdout.splitlines()) if ln.startswith(tag)), None) + if qpc_line is None: + raise SystemExit(f"could not find {tag} line in compile output") + return qpc_line.split("=", 1)[1].strip() + + +# ───────────────────────────────────────────────────────────────────────────── +# Main +# ───────────────────────────────────────────────────────────────────────────── +def main(): + args = parse_args() + + tlm_repo_default, dlm_repo = MODEL_MAP[args.model_name] + tlm_repo = args.tlm_hf_path or tlm_repo_default + if tlm_repo is None: + raise SystemExit(f"No default TLM HF path for '{args.model_name}'. Pass --tlm_hf_path.") + + # ── Sub-mode: this process exists only to compile one model ───────────── + if args._build == "tlm": + _build_tlm(args, tlm_repo, dlm_repo) + return + if args._build == "dlm": + _build_dlm(args, tlm_repo, dlm_repo) + return + + # ── Resolve / discover hidden_size + block_size from DLM config ──────── + import transformers + + config = transformers.AutoConfig.from_pretrained(dlm_repo, token=args.hf_token, trust_remote_code=True) + hidden_size = config.hidden_size + block_size = getattr(config, "block_size", None) + print(f"DLM repo : {dlm_repo}") + print(f"hidden_size : {hidden_size}") + print(f"block_size : {block_size}") + + # ── Resolve QPC paths (compile only the side that wasn't pre-supplied) ─ + forwarded = [ + "--model_name", + args.model_name, + "--ctx_len", + str(args.ctx_len), + "--prefill_seq_len", + str(args.prefill_seq_len), + "--tlm_cores", + str(args.tlm_cores), + "--dlm_cores", + str(args.dlm_cores), + "--tlm_devices", + *[str(d) for d in args.tlm_devices], + "--dlm_devices", + *[str(d) for d in args.dlm_devices], + ] + if args.tlm_hf_path: + forwarded += ["--tlm_hf_path", args.tlm_hf_path] + if args.hf_token: + forwarded += ["--hf_token", args.hf_token] + + if args.tlm_qpc: + print(f"[skip compile] using provided TLM qpc: {args.tlm_qpc}") + tlm_qpc = args.tlm_qpc + else: + tlm_qpc = _spawn_compile("tlm", forwarded) + + if args.dlm_qpc: + print(f"[skip compile] using provided DLM qpc: {args.dlm_qpc}") + dlm_qpc = args.dlm_qpc + else: + dlm_qpc = _spawn_compile("dlm", forwarded) + print(f"TLM qpc : {tlm_qpc}") + print(f"DLM qpc : {dlm_qpc}") + + # ── Resolve noise embed path ─────────────────────────────────────────── + noise_embed = args.noise_embed_path or os.path.join( + THIS_DIR, "noise_embedding", f"{args.model_name}_noise_embeds.npy" + ) + if not os.path.exists(noise_embed): + raise SystemExit(f"noise embedding not found: {noise_embed}\nPass --noise_embed_path explicitly.") + + output_dir = args.output_dir or os.path.join(THIS_DIR, f"results-{args.model_name}") + + # ── Run the existing SPD eval script ─────────────────────────────────── + eval_script = os.path.join(THIS_DIR, "dflash_spd_benchmark.py") + cmd = [ + sys.executable, + eval_script, + "--dataset", + args.dataset, + "--tlm_qpc", + tlm_qpc, + "--dlm_qpc", + dlm_qpc, + "--tlm_model_name", + tlm_repo, + "--dlm_model_name", + dlm_repo, + "--noise_embed_path", + noise_embed, + "--iteration", + str(args.iteration), + "--ctx_len", + str(args.ctx_len), + "--generation_len", + str(args.generation_len), + "--tlm_devices", + *[str(d) for d in args.tlm_devices], + "--dlm_devices", + *[str(d) for d in args.dlm_devices], + "--output_dir", + output_dir, + ] + if args.hf_token: + cmd += ["--hf_token", args.hf_token] + if args.num_samples and args.num_samples > 0: + cmd += ["--num_samples", str(args.num_samples)] + + print("\n>>> launching SPD eval:") + print(" ".join(cmd)) + rc = subprocess.run(cmd, check=False).returncode + if rc != 0: + raise SystemExit(f"SPD eval exited with rc={rc}") + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/dbg.log b/examples/performance/dflash/dbg.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/performance/dflash/dflash_spd_benchmark.py b/examples/performance/dflash/dflash_spd_benchmark.py new file mode 100644 index 0000000000..0644bffc9b --- /dev/null +++ b/examples/performance/dflash/dflash_spd_benchmark.py @@ -0,0 +1,603 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +import csv +import os +import time +from typing import Optional + +import numpy as np +import torch +import transformers +from rich.console import Console +from rich.markup import escape +from utils import load_and_process_dataset + +from QEfficient.generation.cloud_infer import QAICInferenceSession + +torch.manual_seed(42) +np.random.seed(42) + +console = Console() + + +# ===== METRICS ===== + + +class SpecDecodingMetrics: + def __init__(self, block_size: int = 10): + self.block_size = block_size + self.total_prefill_time = 0.0 + self.tlm_decode_time = 0.0 + self.dlm_decode_time = 0.0 + self.total_accepted_tokens = 0 + self.total_rejected_tokens = 0 + self.total_generated_tokens = 0 + self.num_total_iters = 0 + self.acceptance_history = [] + self.generated_ids: list = [] + self.generated_sources: list = [] # "dlm" or "tlm" per token + + def acceptance_rate(self) -> float: + if self.num_total_iters == 0: + return 0.0 + return self.total_generated_tokens / self.num_total_iters + + def dlm_tok_rate(self) -> float: + if self.dlm_decode_time <= 0: + return 0.0 + num_tok_drafted = self.block_size * self.num_total_iters + return num_tok_drafted / self.dlm_decode_time + + def tlm_tok_rate(self) -> float: + if self.tlm_decode_time <= 0: + return 0.0 + ar = self.acceptance_rate() + num_tok_tlm = self.total_generated_tokens / (1 + ar) if (1 + ar) > 0 else 0.0 + return num_tok_tlm / self.tlm_decode_time + + def spd_tok_rate(self) -> float: + total_decode_s = self.tlm_decode_time + self.dlm_decode_time + if total_decode_s <= 0: + return 0.0 + return self.total_generated_tokens / total_decode_s + + +# ===== INFERENCE ===== + + +def run_spd_inference_single( + prompt_text: str, + tokenizer, + dlm_session: QAICInferenceSession, + tlm_session: QAICInferenceSession, + mask_token_embed, + vocab_size: int, + prompt_chunk_size: int, + ctx_len: int = 4096, + block_size: int = 16, + max_iterations: int = 300, + hidden_size: int = 4096, + generation_len: int = 256, +) -> SpecDecodingMetrics: + eos_token_ids = {tokenizer.eos_token_id} if tokenizer.eos_token_id is not None else set() + + prompt = [prompt_text] + batch_size = 1 + metrics = SpecDecodingMetrics(block_size=block_size) + + # Tokenize + tlm_inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = tlm_inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -prompt_chunk_size) # ceil divide without float + padded_len = num_chunks * prompt_chunk_size # Convert to a multiple of padded_len + tlm_inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + tlm_inputs["position_ids"] = np.where(tlm_inputs.pop("attention_mask"), np.arange(padded_len), -1) + + tlm_inputs.pop("token_type_ids", None) + tlm_inputs = {k: torch.from_numpy(v) for k, v in tlm_inputs.items()} + tlm_inputs.pop("past_key_values", None) + tlm_inputs = {k: v.detach().numpy() for k, v in tlm_inputs.items()} + prompt_len = padded_len + + generated_ids = np.full((batch_size, ctx_len - prompt_len), tokenizer.pad_token_id) + + # Set output buffers + tlm_session.set_buffers({"logits": np.zeros((batch_size, prompt_chunk_size), dtype=np.int32)}) + tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) + tlm_session.set_buffers({"output_embeds": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) + dlm_session.set_buffers({"logits": np.zeros((batch_size, block_size, vocab_size), dtype=np.float32)}) + + tlm_cache_index = np.array([0]) + dlm_cache_index = np.array([0]) + dlm_inputs = {} + + # ===== PREFILL ===== + prefill_start = time.time() + num_sub_blocks = prompt_chunk_size // block_size + remainder = prompt_chunk_size % block_size + + for pi in range(num_chunks - 1): + chunk_inputs = { + "input_ids": tlm_inputs["input_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + "position_ids": tlm_inputs["position_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + } + tlm_prefill_outputs = tlm_session.run(chunk_inputs) + ## Add support for when the prefill_seq_len is more than block_size + for sub_i in range(num_sub_blocks): + sub_start = sub_i * block_size + dlm_inputs["target_hidden"] = tlm_prefill_outputs["hidden_states"][:, sub_start : sub_start + block_size, :] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size + dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_session.run(dlm_inputs) + + ## Add support when prefill_seq_len is not a multiple of block_size + if remainder > 0: + sub_start = num_sub_blocks * block_size + target_hidden_rem = np.zeros((1, block_size, hidden_size), dtype=np.float32) + target_hidden_rem[:, :remainder, :] = tlm_prefill_outputs["hidden_states"][:, sub_start:, :] + pos_ids_target_rem = np.full((1, block_size), -1, dtype=tlm_inputs["position_ids"].dtype) + pos_ids_target_rem[:, :remainder] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + remainder + ] + dlm_inputs["target_hidden"] = target_hidden_rem + dlm_inputs["position_ids_target"] = pos_ids_target_rem + dlm_inputs["position_ids"] = pos_ids_target_rem + block_size + dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_session.run(dlm_inputs) + tlm_cache_index[0] += prompt_chunk_size + dlm_cache_index[0] += prompt_chunk_size + + # Last prefill chunk + chunk_inputs = { + "input_ids": tlm_inputs["input_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + "position_ids": tlm_inputs["position_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + } + tlm_last_prefill_outputs = tlm_session.run(chunk_inputs) + last_prefill_pos_in_chunk = chunk_inputs["position_ids"].argmax() + new_tlm_token = tlm_last_prefill_outputs["logits"][:, last_prefill_pos_in_chunk] + + ## Add support for when the prefill_seq_len is more than block_size + last_sub = last_prefill_pos_in_chunk // block_size + for sub_i in range(last_sub): + sub_start = sub_i * block_size + dlm_inputs["target_hidden"] = tlm_last_prefill_outputs["hidden_states"][ + :, sub_start : sub_start + block_size, : + ] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size + dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_session.run(dlm_inputs) + + noise_embeds = np.tile(mask_token_embed, (1, block_size, 1)) + noise_embeds[:, 0, :] = tlm_last_prefill_outputs["output_embeds"][:, last_prefill_pos_in_chunk, :] + sub_start = last_sub * block_size + + ## Add support when prefill_seq_len is not a multiple of block_size + if last_sub < num_sub_blocks: + target_hidden = tlm_last_prefill_outputs["hidden_states"][:, sub_start : sub_start + block_size, :] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + else: + target_hidden = np.zeros((1, block_size, hidden_size), dtype=np.float32) + target_hidden[:, :remainder, :] = tlm_last_prefill_outputs["hidden_states"][:, sub_start:, :] + pos_ids_target = np.full((1, block_size), -1, dtype=tlm_inputs["position_ids"].dtype) + pos_ids_target[:, :remainder] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + remainder + ] + dlm_inputs["position_ids_target"] = pos_ids_target + + dlm_inputs["position_ids"] = np.arange( + tlm_cache_index[0] + last_prefill_pos_in_chunk + 1, + tlm_cache_index[0] + last_prefill_pos_in_chunk + 1 + block_size, + ).reshape(1, -1) + dlm_inputs["noise_embeds"] = noise_embeds + dlm_inputs["target_hidden"] = target_hidden + dlm_outputs = dlm_session.run(dlm_inputs) + + metrics.total_prefill_time += time.time() - prefill_start + dlm_candidates = dlm_outputs["logits"].argmax(axis=-1) + + # ===== DECODE ===== + spd_counter_idx = tlm_cache_index[0] + last_prefill_pos_in_chunk + gen_idx = 0 + iteration_count = 0 + continue_generation = True + + tlm_session.set_buffers({"logits": np.zeros((batch_size, block_size), dtype=np.int32)}) + tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) + tlm_session.set_buffers({"output_embeds": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) + + while gen_idx < generation_len and iteration_count < max_iterations and continue_generation: + iteration_count += 1 + dlm_candidates[:, 0] = new_tlm_token + + tlm_decode_start = time.time() + tlm_decode_outputs = tlm_session.run( + { + "input_ids": dlm_candidates, + "position_ids": dlm_inputs["position_ids"], + } + ) + metrics.tlm_decode_time += time.time() - tlm_decode_start + + tlm_logits = tlm_decode_outputs["logits"] + target_hidden = tlm_decode_outputs["hidden_states"] + + accepted_length = 0 + rejected_flag = False + + for spec_idx in range(block_size - 1): + tlm_token = tlm_logits[:, spec_idx] + dlm_token = dlm_candidates[:, spec_idx + 1] + if tlm_token == dlm_token: + accepted_length += 1 + metrics.total_accepted_tokens += 1 + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = dlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(dlm_token[0])) + metrics.generated_sources.append("dlm") + else: + metrics.total_rejected_tokens += block_size - spec_idx - 1 + rejected_flag = True + new_tlm_token = tlm_token + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = tlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(tlm_token[0])) + metrics.generated_sources.append("tlm") + break + + metrics.acceptance_history.append(accepted_length) + metrics.total_generated_tokens += accepted_length + 1 + + if not rejected_flag: + new_tlm_token = tlm_logits[:, block_size - 1] + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = new_tlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(new_tlm_token[0])) + metrics.generated_sources.append("tlm") + + # EOS check + dlm_candidate_ids = list(dlm_candidates[0, 1 : accepted_length + 1]) + this_iter_gen_ids = dlm_candidate_ids + [new_tlm_token[0]] + for tok_id in this_iter_gen_ids: + if tok_id in eos_token_ids: + continue_generation = False + break + + if not continue_generation: + break + + # Next DLM iteration + dlm_decode_start = time.time() + dlm_inputs["position_ids_target"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape( + 1, -1 + ) + spd_counter_idx += accepted_length + 1 + dlm_inputs["position_ids_target"][:, accepted_length + 1 :] = -1 + dlm_inputs["position_ids"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape(1, -1) + noise_embeds[:, 0, :] = tlm_decode_outputs["output_embeds"][:, accepted_length, :] + dlm_inputs["noise_embeds"] = noise_embeds + dlm_inputs["target_hidden"] = target_hidden + dlm_outputs = dlm_session.run(dlm_inputs) + metrics.dlm_decode_time += time.time() - dlm_decode_start + + dlm_candidates = dlm_outputs["logits"].argmax(axis=-1) + + metrics.num_total_iters = iteration_count + return metrics + + +# ===== DATASET CONFIG ===== + +DATASET_CONFIG = { + "gsm8k": { + "hf_path": "openai/gsm8k", + "hf_name": "main", + "split": "test", + "prompt_field": "question", + "label": "GSM8K", + }, + "math500": { + "hf_path": "HuggingFaceH4/MATH-500", + "hf_name": None, + "split": "test", + "prompt_field": "problem", + "label": "MATH-500", + }, + "humaneval": { + "hf_path": "openai/openai_humaneval", + "hf_name": None, + "split": "test", + "prompt_field": "prompt", + "label": "HumanEval", + }, +} + + +# ===== EVALUATION LOOP ===== + +PER_SAMPLE_FIELDS = [ + "dataset", + "sample_idx", + "acceptance_rate", + "dlm_tps", + "tlm_tps", + "spd_tps", + "total_generated_tokens", + "num_iters", + "prefill_time_s", + "tlm_decode_time_s", + "dlm_decode_time_s", +] + +SUMMARY_FIELDS = [ + "dataset", + "num_evaluated", + "num_total", + "avg_acceptance_rate", + "min_acceptance_rate", + "max_acceptance_rate", + "avg_dlm_tps", + "min_dlm_tps", + "max_dlm_tps", + "avg_tlm_tps", + "min_tlm_tps", + "max_tlm_tps", + "avg_spd_tps", + "min_spd_tps", + "max_spd_tps", +] + + +def _write_per_sample_csv(rows: list, path: str): + with open(path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=PER_SAMPLE_FIELDS) + writer.writeheader() + writer.writerows(rows) + console.print(f"[green]Per-sample CSV → {path}[/green]") + + +def _append_summary_csv(row: dict, path: str): + write_header = not os.path.exists(path) + with open(path, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=SUMMARY_FIELDS) + if write_header: + writer.writeheader() + writer.writerow(row) + console.print(f"[green]Summary CSV → {path}[/green]") + + +def evaluate_dataset( + dataset_name: str, + tokenizer, + dlm_session, + tlm_session, + mask_token_embed, + vocab_size: int, + prompt_chunk_size: int, + ctx_len: int = 4096, + block_size: int = 10, + max_iterations: int = 300, + hidden_size: int = 4096, + generation_len: int = 1024, + num_samples: Optional[int] = None, + output_dir: str = "./results", +): + cfg = DATASET_CONFIG[dataset_name] + console.print(f"[bold blue]Loading {cfg['label']} dataset...[/bold blue]") + dataset = load_and_process_dataset(dataset_name) + + if num_samples is not None: + # dataset = dataset.shuffle(seed=0).select(range(min(num_samples, len(dataset)))) + dataset = dataset.select(range(min(num_samples, len(dataset)))) + console.print(f"[green]✓ Loaded {len(dataset)} {cfg['label']} problems[/green]") + + all_ar, all_dlm_tps, all_tlm_tps, all_spd_tps = [], [], [], [] + per_sample_rows = [] + + for i, sample in enumerate(dataset): + user_content = sample["turns"][0] + messages = [{"role": "user", "content": user_content}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False + ) + console.print(f"[cyan]({i + 1}/{len(dataset)})[/cyan] Input: {user_content[:80].strip()}") + + # try: + metrics = run_spd_inference_single( + prompt_text=prompt_text, + tokenizer=tokenizer, + dlm_session=dlm_session, + tlm_session=tlm_session, + vocab_size=vocab_size, + prompt_chunk_size=prompt_chunk_size, + ctx_len=ctx_len, + block_size=block_size, + max_iterations=max_iterations, + hidden_size=hidden_size, + generation_len=generation_len, + mask_token_embed=mask_token_embed, + ) + + ar = metrics.acceptance_rate() + dlm_tps = metrics.dlm_tok_rate() + tlm_tps = metrics.tlm_tok_rate() + spd_tps = metrics.spd_tok_rate() + + all_ar.append(ar) + all_dlm_tps.append(dlm_tps) + all_tlm_tps.append(tlm_tps) + all_spd_tps.append(spd_tps) + + per_sample_rows.append( + { + "dataset": dataset_name, + "sample_idx": i, + "acceptance_rate": round(ar, 4), + "dlm_tps": round(dlm_tps, 2), + "tlm_tps": round(tlm_tps, 2), + "spd_tps": round(spd_tps, 2), + "total_generated_tokens": metrics.total_generated_tokens, + "num_iters": metrics.num_total_iters, + "prefill_time_s": round(metrics.total_prefill_time, 4), + "tlm_decode_time_s": round(metrics.tlm_decode_time, 4), + "dlm_decode_time_s": round(metrics.dlm_decode_time, 4), + } + ) + + console.print(f" AR={ar:.2f} DLM={dlm_tps:.1f} tok/s TLM={tlm_tps:.1f} tok/s SPD={spd_tps:.1f} tok/s") + + output_parts = ["Output: "] + for tok_id, source in zip(metrics.generated_ids, metrics.generated_sources): + text = escape(tokenizer.decode([tok_id], skip_special_tokens=True)) + if source == "dlm": + output_parts.append(f"[blue]{text}[/blue]") + else: + output_parts.append(f"[white]{text}[/white]") + console.print("".join(output_parts)) + + # except Exception as e: + # console.print(f"[red] ✗ Error on sample {i}: {e}[/red]") + + # ===== SUMMARY ===== + if all_ar: + w = 46 + print("\n" + "=" * w) + print(f" {cfg['label']} SPD Evaluation — Averages") + print("=" * w) + print(f" {'Metric':<30} {'Avg':>6} {'Min':>6} {'Max':>6}") + print("-" * w) + for name, vals in [ + ("Acceptance Rate (tok/iter)", all_ar), + ("DLM Throughput (tok/s)", all_dlm_tps), + ("TLM Throughput (tok/s)", all_tlm_tps), + ("SPD Decode Speed (tok/s)", all_spd_tps), + ]: + print(f" {name:<30} {np.mean(vals):>6.2f} {np.min(vals):>6.2f} {np.max(vals):>6.2f}") + print("=" * w) + print(f" Evaluated {len(all_ar)} / {len(dataset)} samples successfully.") + print("=" * w + "\n") + + # ===== SAVE CSV ===== + os.makedirs(output_dir, exist_ok=True) + _write_per_sample_csv( + per_sample_rows, + os.path.join(output_dir, f"{dataset_name}_per_sample.csv"), + ) + _append_summary_csv( + { + "dataset": dataset_name, + "num_evaluated": len(all_ar), + "num_total": len(dataset), + "avg_acceptance_rate": round(float(np.mean(all_ar)), 4), + "min_acceptance_rate": round(float(np.min(all_ar)), 4), + "max_acceptance_rate": round(float(np.max(all_ar)), 4), + "avg_dlm_tps": round(float(np.mean(all_dlm_tps)), 2), + "min_dlm_tps": round(float(np.min(all_dlm_tps)), 2), + "max_dlm_tps": round(float(np.max(all_dlm_tps)), 2), + "avg_tlm_tps": round(float(np.mean(all_tlm_tps)), 2), + "min_tlm_tps": round(float(np.min(all_tlm_tps)), 2), + "max_tlm_tps": round(float(np.max(all_tlm_tps)), 2), + "avg_spd_tps": round(float(np.mean(all_spd_tps)), 2), + "min_spd_tps": round(float(np.min(all_spd_tps)), 2), + "max_spd_tps": round(float(np.max(all_spd_tps)), 2), + }, + os.path.join(output_dir, "summary.csv"), + ) + else: + print("No successful results.") + + +# ===== ARGUMENT PARSING ===== + + +def parse_args(): + parser = argparse.ArgumentParser(description="SPD benchmark — gsm8k / math500 / humaneval") + parser.add_argument("--dataset", required=True, choices=list(DATASET_CONFIG.keys())) + parser.add_argument("--tlm_qpc", required=True) + parser.add_argument("--dlm_qpc", required=True) + parser.add_argument("--tlm_model_name", required=True) + parser.add_argument("--dlm_model_name", required=True) + parser.add_argument("--noise_embed_path", required=True) + parser.add_argument("--iteration", type=int, default=300) + parser.add_argument("--ctx_len", type=int, default=4096) + parser.add_argument("--generation_len", type=int, default=1024) + parser.add_argument("--tlm_devices", nargs="+", type=int, required=True) + parser.add_argument("--dlm_devices", nargs="+", type=int, required=True) + parser.add_argument("--hf_token", default=None) + parser.add_argument("--num_samples", type=int, default=0, help="Number of samples to run (0 = all)") + parser.add_argument("--output_dir", default="./results", help="Directory for CSV output (default: ./results)") + return parser.parse_args() + + +# ===== MAIN ===== + + +def main(): + args = parse_args() + num_samples = args.num_samples if args.num_samples > 0 else None + + console.print("[bold blue]Loading tokenizer and config...[/bold blue]") + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.tlm_model_name, token=args.hf_token, trust_remote_code=True + ) + config = transformers.AutoConfig.from_pretrained(args.dlm_model_name, token=args.hf_token, trust_remote_code=True) + vocab_size = config.vocab_size + hidden_size = config.hidden_size + block_size = config.block_size + mask_token_embed = np.load(args.noise_embed_path) + + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + console.print("[bold blue]Loading QAIC inference sessions...[/bold blue]") + dlm_session = QAICInferenceSession(args.dlm_qpc, args.dlm_devices) + tlm_session = QAICInferenceSession(args.tlm_qpc, args.tlm_devices) + dlm_session.skip_buffers( + set([x for x in dlm_session.input_names + dlm_session.output_names if x.startswith("past_")]) + ) + tlm_session.skip_buffers( + set([x for x in tlm_session.input_names + tlm_session.output_names if x.startswith("past_")]) + ) + + prompt_chunk_size = max( + [x[tlm_session.binding_index_map["input_ids"]][1][1] for x in tlm_session.allowed_shapes] + + [tlm_session.bindings[tlm_session.binding_index_map["input_ids"]].dims[1]] + ) + console.print(f"prompt_chunk_size = {prompt_chunk_size}") + + evaluate_dataset( + dataset_name=args.dataset, + tokenizer=tokenizer, + dlm_session=dlm_session, + tlm_session=tlm_session, + vocab_size=vocab_size, + prompt_chunk_size=prompt_chunk_size, + mask_token_embed=mask_token_embed, + ctx_len=args.ctx_len, + block_size=block_size, + max_iterations=args.iteration, + hidden_size=hidden_size, + generation_len=args.generation_len, + num_samples=num_samples, + output_dir=args.output_dir, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/dflash_spd_single_prompt.py b/examples/performance/dflash/dflash_spd_single_prompt.py new file mode 100644 index 0000000000..5cff09c9cc --- /dev/null +++ b/examples/performance/dflash/dflash_spd_single_prompt.py @@ -0,0 +1,401 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +import time + +import numpy as np +import torch +import transformers +from rich.console import Console +from rich.markup import escape +from utils import format_prompt + +from QEfficient.generation.cloud_infer import QAICInferenceSession + +torch.manual_seed(42) +np.random.seed(42) + +console = Console() + + +# ===== METRICS ===== + + +class SpecDecodingMetrics: + def __init__(self, block_size: int = 10): + self.block_size = block_size + self.total_prefill_time = 0.0 + self.tlm_decode_time = 0.0 + self.dlm_decode_time = 0.0 + self.total_accepted_tokens = 0 + self.total_rejected_tokens = 0 + self.total_generated_tokens = 0 + self.num_total_iters = 0 + self.acceptance_history = [] + self.generated_ids: list = [] + self.generated_sources: list = [] + + def acceptance_rate(self) -> float: + if self.num_total_iters == 0: + return 0.0 + return self.total_generated_tokens / self.num_total_iters + + def dlm_tok_rate(self) -> float: + if self.dlm_decode_time <= 0: + return 0.0 + return (self.block_size * self.num_total_iters) / self.dlm_decode_time + + def tlm_tok_rate(self) -> float: + if self.tlm_decode_time <= 0: + return 0.0 + ar = self.acceptance_rate() + num_tok_tlm = self.total_generated_tokens / (1 + ar) if (1 + ar) > 0 else 0.0 + return num_tok_tlm / self.tlm_decode_time + + def spd_tok_rate(self) -> float: + total_decode_s = self.tlm_decode_time + self.dlm_decode_time + if total_decode_s <= 0: + return 0.0 + return self.total_generated_tokens / total_decode_s + + +# ===== INFERENCE ===== + + +def run_spd_inference_single( + prompt_text: str, + tokenizer, + dlm_session: QAICInferenceSession, + tlm_session: QAICInferenceSession, + mask_token_embed, + vocab_size: int, + prompt_chunk_size: int, + ctx_len: int = 4096, + block_size: int = 16, + max_iterations: int = 300, + hidden_size: int = 4096, + generation_len: int = 256, +) -> SpecDecodingMetrics: + eos_token_ids = {tokenizer.eos_token_id} if tokenizer.eos_token_id is not None else set() + + prompt = [prompt_text] + batch_size = 1 + metrics = SpecDecodingMetrics(block_size=block_size) + + # Tokenize + tlm_inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = tlm_inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -prompt_chunk_size) + padded_len = num_chunks * prompt_chunk_size + tlm_inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + tlm_inputs["position_ids"] = np.where(tlm_inputs.pop("attention_mask"), np.arange(padded_len), -1) + + tlm_inputs.pop("token_type_ids", None) + tlm_inputs = {k: torch.from_numpy(v) for k, v in tlm_inputs.items()} + tlm_inputs.pop("past_key_values", None) + tlm_inputs = {k: v.detach().numpy() for k, v in tlm_inputs.items()} + + generated_ids = np.full((batch_size, ctx_len - padded_len), tokenizer.pad_token_id) + + # Set output buffers + tlm_session.set_buffers({"logits": np.zeros((batch_size, prompt_chunk_size), dtype=np.int32)}) + tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) + tlm_session.set_buffers({"output_embeds": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) + dlm_session.set_buffers({"logits": np.zeros((batch_size, block_size, vocab_size), dtype=np.float32)}) + + tlm_cache_index = np.array([0]) + dlm_cache_index = np.array([0]) + dlm_inputs = {} + + # ===== PREFILL ===== + prefill_start = time.time() + num_sub_blocks = prompt_chunk_size // block_size + remainder = prompt_chunk_size % block_size + + for pi in range(num_chunks - 1): + chunk_inputs = { + "input_ids": tlm_inputs["input_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + "position_ids": tlm_inputs["position_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + } + tlm_prefill_outputs = tlm_session.run(chunk_inputs) + for sub_i in range(num_sub_blocks): + sub_start = sub_i * block_size + dlm_inputs["target_hidden"] = tlm_prefill_outputs["hidden_states"][:, sub_start : sub_start + block_size, :] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size + dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_session.run(dlm_inputs) + if remainder > 0: + sub_start = num_sub_blocks * block_size + target_hidden_rem = np.zeros((1, block_size, hidden_size), dtype=np.float32) + target_hidden_rem[:, :remainder, :] = tlm_prefill_outputs["hidden_states"][:, sub_start:, :] + pos_ids_target_rem = np.full((1, block_size), -1, dtype=tlm_inputs["position_ids"].dtype) + pos_ids_target_rem[:, :remainder] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + remainder + ] + dlm_inputs["target_hidden"] = target_hidden_rem + dlm_inputs["position_ids_target"] = pos_ids_target_rem + dlm_inputs["position_ids"] = pos_ids_target_rem + block_size + dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_session.run(dlm_inputs) + tlm_cache_index[0] += prompt_chunk_size + dlm_cache_index[0] += prompt_chunk_size + + # Last prefill chunk + chunk_inputs = { + "input_ids": tlm_inputs["input_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + "position_ids": tlm_inputs["position_ids"][:, tlm_cache_index[0] : tlm_cache_index[0] + prompt_chunk_size], + } + tlm_last_prefill_outputs = tlm_session.run(chunk_inputs) + last_prefill_pos_in_chunk = chunk_inputs["position_ids"].argmax() + new_tlm_token = tlm_last_prefill_outputs["logits"][:, last_prefill_pos_in_chunk] + + last_sub = last_prefill_pos_in_chunk // block_size + for sub_i in range(last_sub): + sub_start = sub_i * block_size + dlm_inputs["target_hidden"] = tlm_last_prefill_outputs["hidden_states"][ + :, sub_start : sub_start + block_size, : + ] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size + dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_session.run(dlm_inputs) + + noise_embeds = np.tile(mask_token_embed, (1, block_size, 1)) + noise_embeds[:, 0, :] = tlm_last_prefill_outputs["output_embeds"][:, last_prefill_pos_in_chunk, :] + sub_start = last_sub * block_size + if last_sub < num_sub_blocks: + target_hidden = tlm_last_prefill_outputs["hidden_states"][:, sub_start : sub_start + block_size, :] + dlm_inputs["position_ids_target"] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size + ] + else: + target_hidden = np.zeros((1, block_size, hidden_size), dtype=np.float32) + target_hidden[:, :remainder, :] = tlm_last_prefill_outputs["hidden_states"][:, sub_start:, :] + pos_ids_target = np.full((1, block_size), -1, dtype=tlm_inputs["position_ids"].dtype) + pos_ids_target[:, :remainder] = tlm_inputs["position_ids"][ + :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + remainder + ] + dlm_inputs["position_ids_target"] = pos_ids_target + dlm_inputs["position_ids"] = np.arange( + tlm_cache_index[0] + last_prefill_pos_in_chunk + 1, + tlm_cache_index[0] + last_prefill_pos_in_chunk + 1 + block_size, + ).reshape(1, -1) + dlm_inputs["noise_embeds"] = noise_embeds + dlm_inputs["target_hidden"] = target_hidden + dlm_outputs = dlm_session.run(dlm_inputs) + + metrics.total_prefill_time += time.time() - prefill_start + dlm_candidates = dlm_outputs["logits"].argmax(axis=-1) + + # ===== DECODE ===== + spd_counter_idx = tlm_cache_index[0] + last_prefill_pos_in_chunk + gen_idx = 0 + iteration_count = 0 + continue_generation = True + + tlm_session.set_buffers({"logits": np.zeros((batch_size, block_size), dtype=np.int32)}) + tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) + tlm_session.set_buffers({"output_embeds": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) + + while gen_idx < generation_len and iteration_count < max_iterations and continue_generation: + iteration_count += 1 + dlm_candidates[:, 0] = new_tlm_token + + tlm_decode_start = time.time() + tlm_decode_outputs = tlm_session.run( + { + "input_ids": dlm_candidates, + "position_ids": dlm_inputs["position_ids"], + } + ) + metrics.tlm_decode_time += time.time() - tlm_decode_start + + tlm_logits = tlm_decode_outputs["logits"] + target_hidden = tlm_decode_outputs["hidden_states"] + + accepted_length = 0 + rejected_flag = False + + for spec_idx in range(block_size - 1): + tlm_token = tlm_logits[:, spec_idx] + dlm_token = dlm_candidates[:, spec_idx + 1] + if tlm_token == dlm_token: + accepted_length += 1 + metrics.total_accepted_tokens += 1 + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = dlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(dlm_token[0])) + metrics.generated_sources.append("dlm") + else: + metrics.total_rejected_tokens += block_size - spec_idx - 1 + rejected_flag = True + new_tlm_token = tlm_token + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = tlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(tlm_token[0])) + metrics.generated_sources.append("tlm") + break + + metrics.acceptance_history.append(accepted_length) + metrics.total_generated_tokens += accepted_length + 1 + + if not rejected_flag: + new_tlm_token = tlm_logits[:, block_size - 1] + if gen_idx < len(generated_ids[0]): + generated_ids[0, gen_idx] = new_tlm_token[0] + gen_idx += 1 + metrics.generated_ids.append(int(new_tlm_token[0])) + metrics.generated_sources.append("tlm") + + dlm_candidate_ids = list(dlm_candidates[0, 1 : accepted_length + 1]) + this_iter_gen_ids = dlm_candidate_ids + [new_tlm_token[0]] + for tok_id in this_iter_gen_ids: + if tok_id in eos_token_ids: + continue_generation = False + break + + if not continue_generation: + break + + dlm_decode_start = time.time() + dlm_inputs["position_ids_target"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape( + 1, -1 + ) + spd_counter_idx += accepted_length + 1 + dlm_inputs["position_ids_target"][:, accepted_length + 1 :] = -1 + dlm_inputs["position_ids"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape(1, -1) + noise_embeds[:, 0, :] = tlm_decode_outputs["output_embeds"][:, accepted_length, :] + dlm_inputs["noise_embeds"] = noise_embeds + dlm_inputs["target_hidden"] = target_hidden + dlm_outputs = dlm_session.run(dlm_inputs) + metrics.dlm_decode_time += time.time() - dlm_decode_start + + dlm_candidates = dlm_outputs["logits"].argmax(axis=-1) + + metrics.num_total_iters = iteration_count + return metrics + + +# ===== ARGUMENT PARSING ===== + + +def parse_args(): + parser = argparse.ArgumentParser(description="SPD single-prompt inference") + parser.add_argument("--prompt", required=True, help="Input prompt text") + parser.add_argument("--tlm_qpc", required=True) + parser.add_argument("--dlm_qpc", required=True) + parser.add_argument("--tlm_model_name", required=True) + parser.add_argument("--dlm_model_name", required=True) + parser.add_argument("--noise_embed_path", required=True) + parser.add_argument("--iteration", type=int, default=300) + parser.add_argument("--ctx_len", type=int, default=4096) + parser.add_argument("--generation_len", type=int, default=256) + parser.add_argument("--tlm_devices", nargs="+", type=int, required=True) + parser.add_argument("--dlm_devices", nargs="+", type=int, required=True) + parser.add_argument("--hf_token", default=None) + parser.add_argument( + "--category", + default="", + help="Prompt category for formatting (math, coding, reasoning, …). Defaults to the general reasoning format.", + ) + return parser.parse_args() + + +# ===== MAIN ===== + + +def main(): + args = parse_args() + + console.print("[bold blue]Loading tokenizer and config...[/bold blue]") + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.tlm_model_name, token=args.hf_token, trust_remote_code=True + ) + config = transformers.AutoConfig.from_pretrained(args.dlm_model_name, token=args.hf_token, trust_remote_code=True) + vocab_size = config.vocab_size + hidden_size = config.hidden_size + block_size = config.block_size + mask_token_embed = np.load(args.noise_embed_path) + + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + console.print("[bold blue]Loading QAIC inference sessions...[/bold blue]") + dlm_session = QAICInferenceSession(args.dlm_qpc, args.dlm_devices) + tlm_session = QAICInferenceSession(args.tlm_qpc, args.tlm_devices) + dlm_session.skip_buffers( + set([x for x in dlm_session.input_names + dlm_session.output_names if x.startswith("past_")]) + ) + tlm_session.skip_buffers( + set([x for x in tlm_session.input_names + tlm_session.output_names if x.startswith("past_")]) + ) + + prompt_chunk_size = max( + [x[tlm_session.binding_index_map["input_ids"]][1][1] for x in tlm_session.allowed_shapes] + + [tlm_session.bindings[tlm_session.binding_index_map["input_ids"]].dims[1]] + ) + console.print(f"prompt_chunk_size = {prompt_chunk_size}") + + messages = [{"role": "user", "content": format_prompt(args.prompt, args.category)}] + prompt_text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False + ) + console.print(f"[cyan]Input:[/cyan] {args.prompt[:120].strip()}") + + metrics = run_spd_inference_single( + prompt_text=prompt_text, + tokenizer=tokenizer, + dlm_session=dlm_session, + tlm_session=tlm_session, + vocab_size=vocab_size, + prompt_chunk_size=prompt_chunk_size, + ctx_len=args.ctx_len, + block_size=block_size, + max_iterations=args.iteration, + hidden_size=hidden_size, + generation_len=args.generation_len, + mask_token_embed=mask_token_embed, + ) + + output_parts = ["Output: "] + for tok_id, source in zip(metrics.generated_ids, metrics.generated_sources): + text = escape(tokenizer.decode([tok_id], skip_special_tokens=True)) + if source == "dlm": + output_parts.append(f"[blue]{text}[/blue]") + else: + output_parts.append(f"[white]{text}[/white]") + console.print("".join(output_parts)) + + ar = metrics.acceptance_rate() + dlm_tps = metrics.dlm_tok_rate() + tlm_tps = metrics.tlm_tok_rate() + spd_tps = metrics.spd_tok_rate() + + w = 46 + print("\n" + "=" * w) + print(" SPD Inference — Metrics") + print("=" * w) + print(f" {'Acceptance Rate (tok/iter)':<30} {ar:>6.2f}") + print(f" {'DLM Throughput (tok/s)':<30} {dlm_tps:>6.1f}") + print(f" {'TLM Throughput (tok/s)':<30} {tlm_tps:>6.1f}") + print(f" {'SPD Decode Speed (tok/s)':<30} {spd_tps:>6.1f}") + print(f" {'Generated tokens':<30} {metrics.total_generated_tokens:>6}") + print(f" {'Iterations':<30} {metrics.num_total_iters:>6}") + print(f" {'Prefill time (s)':<30} {metrics.total_prefill_time:>6.3f}") + print("=" * w + "\n") + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/make_models.py b/examples/performance/dflash/make_models.py new file mode 100644 index 0000000000..8cbeeae4e4 --- /dev/null +++ b/examples/performance/dflash/make_models.py @@ -0,0 +1,144 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Build and compile TLM + DLM QPC packages for DFlash speculative decoding. + +Usage: + python make_models.py # build both (TLM + DLM) in separate subprocesses + python make_models.py --mode tlm # build TLM only (single process) + python make_models.py --mode dlm # build DLM only (single process) + python make_models.py --mode both # alias for default + +The default 'both' mode launches one subprocess per model so that compiler +state from the TLM build cannot affect the DLM build (back-to-back compiles +in the same process have been observed to segfault). +""" + +import argparse +import os +import subprocess +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +import torch +from transformers import AutoModelForCausalLM +from utils import build_dlm_model, build_tlm_model, extract_lm_head, load_dflash_checkpoint + +from QEfficient import QEFFAutoModelForCausalLM + +# ── Paths ───────────────────────────────────────────────────────────────────── +TLM_MODEL_PATH = "Qwen/Qwen3-4B" +# TLM_MODEL_PATH = "Qwen/Qwen3-8B" +# TLM_MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct" + +DFLASH_MODEL_PATH = "z-lab/Qwen3-4B-DFlash-b16" +# DFLASH_MODEL_PATH = "z-lab/Qwen3-8B-DFlash-b16" +# DFLASH_MODEL_PATH = "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat" + +# ── Compile options ─────────────────────────────────────────────────────────── +COMPILE_KWARGS = dict( + ctx_len=4096, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=4, + mos=1, +) + +TLM_NUM_CORES = 8 +DLM_NUM_CORES = 16 + + +def _load_dflash_meta(): + dflash_state_dict, cfg = load_dflash_checkpoint(DFLASH_MODEL_PATH) + target_layer_ids = cfg.get("dflash_config", {}).get("target_layer_ids", []) + mask_token_id = cfg.get("dflash_config", {}).get("mask_token_id", []) + block_size = cfg.get("block_size", None) + print(f" target_layer_ids : {target_layer_ids}") + print(f" mask_token_id : {mask_token_id}") + print(f" block_size : {block_size}") + return dflash_state_dict, target_layer_ids, block_size + + +def build_tlm(): + print(f"Loading DFlash checkpoint: {DFLASH_MODEL_PATH}") + dflash_state_dict, target_layer_ids, block_size = _load_dflash_meta() + + print(f"\nLoading base model: {TLM_MODEL_PATH}") + base_model = AutoModelForCausalLM.from_pretrained(TLM_MODEL_PATH, torch_dtype=torch.float32) + + print("\n=== TLM ===") + tlm_target_ids = [i + 1 for i in target_layer_ids] + build_tlm_model(base_model, dflash_state_dict, tlm_target_ids) + + tlm_qeff = QEFFAutoModelForCausalLM(base_model, qaic_config={"target_layer_ids": tlm_target_ids}) + tlm_qpc_path = tlm_qeff.compile( + prefill_seq_len=128, + num_cores=TLM_NUM_CORES, + dflash_block_size=block_size, + **COMPILE_KWARGS, + ) + print(f"tlm_qpc_path: {tlm_qpc_path}") + return tlm_qpc_path + + +def build_dlm(): + print(f"Loading DFlash checkpoint: {DFLASH_MODEL_PATH}") + _, _, block_size = _load_dflash_meta() + + print(f"\nLoading base model (for lm_head): {TLM_MODEL_PATH}") + base_model = AutoModelForCausalLM.from_pretrained(TLM_MODEL_PATH, torch_dtype=torch.float32) + lm_head_weight, lm_head_bias = extract_lm_head(base_model) + del base_model + + print("\n=== DLM ===") + dlm_model = build_dlm_model(DFLASH_MODEL_PATH, lm_head_weight, lm_head_bias) + + dlm_qeff = QEFFAutoModelForCausalLM(dlm_model, qaic_config={"dflash_dlm": True}) + dlm_qpc_path = dlm_qeff.compile( + prefill_seq_len=block_size, + num_cores=DLM_NUM_CORES, + prefill_only=True, + **COMPILE_KWARGS, + ) + print(f"dlm_qpc_path: {dlm_qpc_path}") + return dlm_qpc_path + + +def _run_subprocess(mode: str): + print(f"\n>>> Spawning subprocess: --mode {mode}") + result = subprocess.run( + [sys.executable, os.path.abspath(__file__), "--mode", mode], + check=False, + ) + if result.returncode != 0: + raise SystemExit(f"Subprocess for --mode {mode} exited with code {result.returncode}") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--mode", + choices=["tlm", "dlm", "both"], + default="both", + help="Which model(s) to build. 'both' (default) runs TLM and DLM in separate subprocesses.", + ) + args = parser.parse_args() + + if args.mode == "tlm": + build_tlm() + elif args.mode == "dlm": + build_dlm() + else: + _run_subprocess("tlm") + _run_subprocess("dlm") + print("\n=== Done ===") + + +if __name__ == "__main__": + main() diff --git a/examples/performance/dflash/noise_embedding/Llama-3.1-8B-Instruct_noise_embeds.npy b/examples/performance/dflash/noise_embedding/Llama-3.1-8B-Instruct_noise_embeds.npy new file mode 100755 index 0000000000000000000000000000000000000000..d81c2348d6046718e449e9c7896a4f15c61fcc8a GIT binary patch literal 16512 zcmbW;jqlHO`#10pDU~9WzMoI0x)`a8X}EHet}JG#i_)*0vDCWZl>3`R z^7)kHW*F0Im||KbV_G#TwOU#0#!^`|LvdGJbKOl-~TQymF};UN)xN4(!1qSsWp2U zOot}?kJpq+Uy6MX$5xk0ool7i5ZW^Sf5dNw2jrGuF>M4MEoTYrk=GR;q~-dL8UH5x zP4=p~Qt2DcaygqhQlD* zQ?v?hhrckU7d{|11<%I!&^q7`;i7V>)Kl*L>^9Dz~F+w5pa$m-E-*fABw( z_aMGgyb<`4( zREwI}f~V3C;~v&wg!{k5{ zFaM0()aPM#1NL+n1x>`i!P)cuo^b&lMw?hJmF|;26t^|zGdv%sH*}KMoZVIKJjk3jQh%4Rvo@!UY0f@Pdzk$Xn9jb9cCGkP*Xe3Z=CcRFSBXs$8*GoVk30E! zuC;I{eGh*SY%c33{2Sq7_6(dpvanJrjdb0b8eGU9!=E9(81Ci&r)n)|=~WH!DbM<( zSkBwq@SONZ?!8EC3Vce-ULAJLqgNF3xlYdA^u%AM|G<8YwhYdvjjNSP{XJJZ=t4Uv zzK@o9H>Z8XUmRQR5_)<>>ii5WgnCGw?hs2|jl@ry=L?XzJ_`d}=N=d;o-^DZ-yp97 z*YgKf^+9Wpb6lpSeqN$Yfz;1Qd=4L`uf;3iecH>gftG7`trTN&rZcDC;W@aTzevt( z$g}K*3)ny5@y0|yWUhU1YG)lB&$l6S{+2e3om#nxU#;3B+Af^)pZ&^yuY}BP zh1~VH8N6H8NAPkvOW3=_-Z%Ct{$TbEFbkH8rAOsl=XtuF6zjpCiI>qb_pId&xQ}>0 z$n!R%jfa;#%QgH#;th<+I<rtVS)pW@~)9!iy>U!7sMVb8`(!$|9ig6I zC%%<7pZ_DjCI4FdI6u$w|N05exjjbD8hqe6rqCur>ZQHdJ@A!hS|X_q;frZ);8kNXuP^YuuzZgbMT*uFxSJUpq57B>EtuD$%k2tr+8K7^0 z%x4|HPR=F#%W%%mD{}XV&4t)I%;{I0GkO#3rH|%cCI3GD6!slB_s$vWj4z_4KV{8R zQ!OjbC_GKOf_8_wE_K}r?E7fn!FF?4-X3>uA03i*gQIi>%xC&c% z`Q5}C;}x|1c%OSUqgGBo%M{C^~#>Fz+LfF+!CjT zzJ%;;&P|@7%FbH%W>13d)#APxw-WwM>j)pRmqDKWE@OHavx5B@tY9)99z7`3R>^<=p&--2vZ%Q@`iK&T{eF> z_s19Ux6${qFNC?CVT@-QP0KkJYND-p=6OH<&ND1z zr(UxbS-(f|L9z6m!D3m*xB2yO5Gr!z_6+Bn7MsSq2TI>e%OaE&p_BCxXExoP{ zJGFPc^%{ax2Xo5BJMA5Or@ZH35xmH6C#N4C<9g4s^PbAHjKjC`C$Mvde{sDxTr2Ow ztm|0#9p=hAjt7aYE*E?I7f%1Jh%IFg7vBgaTGseg_qvRIH~Rs+jNjDtvv;Rxb&zkH zC#%KxSFYcZpXcf-C-1;VjLA8lZjb8mJus4*X*oQclO@D^*qBb)#4p}DZWN5 zXKID`D%ZM znd`-J@=TYoe}He{)(a=G*G&a`CsI7Pt|lpU%a{tEJL!a>ldMqjIKFbDi-~2>&eSx{CLWyh;4?*#mG3 z*Z70BO#CWZfA)v;)NN`fHF!pBhdqcrf%Z8}6;EAU00-eOvHbn!eK;sT^*ELN9Hejc zrDYy@*JLg0E5*C(_Qcq2#Zw<|8#g>YJKwzb@}FT}4mY#Yt9Mt6_t-w#C+y52efn+J z`Zr!n%d?D+zGuO~*g+a|Y(|vj=DJaA*SAgFNe1>_PZv_h^T^!rfQ8 zk8AEXrU^ZJJ;ik<%Dszzj(xqncX56{r#5C()f&u()L%US-77*RuPtqm9NnH`#-n$M$d=^ygRUw`OC`v4l#ABT~& z*emg^^wrRjKLN5Q10ZX+2xpHczzI3e($a4}$E$tMeM3th+<+J3hs6eY&e$)~ni-eA z5*$*e`Ht8_-zz6;n(x1?MIAfOU53;Ab#n8rUoZbj=uK+^+1sP=v{>F#sgH+ouKOK3 z>vV;@!L&p0JwNrFKJ*G??ec6_(yqqKao$(&%N@p_V_xZ{<6sBvNAbMl)8hwLig#e1 zH}}haH>9O6Zsn)HXFhrVTuZx#-3#}DY3%m+1)Q}wD1QdKDZ4X!DEk%|2oD(lHGO5Z zxcD-hj40tB;*LTG@{>$bKMD$Iy)?W1MOCJ>USZY&%cGe z3ExxGyFFuazv4d*&1gf#^PSh1f0q9gyh}SK?=jkTyiYu}ay@+(tqZ=HUt%}Kslk=< zp636o=32?mJ;$vndTsjkbbh{nviEbv*Rs#D?`BUE`w;SuyS-ZUuk^F4VLmP2z^&L- zd`^6FtvJ)?)Jiw@Y&a?}eSIs8fG+$o_G^>aMf`906`cBhrdsr-oSS8^#vF1kU!tep zvVU!9IcGKZJVHMpzJQ&&`bO*{+O?2#dm6ut-;(o{obBwqA3N}~PIWkEA=m2wGv&ws z73X~RWzU7=U5r=L(_a_6PXqqV?3|JR@O$7VXxLYsS6DZX%Xm`5SwYob;7Z z&db-vT?l!O(Q?vX@?PC#{Ti0lB|H0f5HG`*i>Dr1!}aV(X=~XZ;r94Z*uwtOy>l+p zKWDSIRExU%!gZUmH~5{pOMHj?qx^yF_whw?U$!?D{z;f%{DsgT&WPt3`{C^OUfM70 zypzsnzs+8cyFdl*aj)#f6YRSyYPMW_Yvo;Xaqc5F4i3?uur_%IAA`}+tWA64rm}a# z=5le3-NvP6QtxAE^X0sbSK{nzzJGGAlaoHOjh}Pfh<`ck;9tsL!Op!>m-R5a>>Tl% z;Hx3uMCluMx?lFDbxj|}JzO)jc8a|QvaYF#Mmgl*~O}K@zv)Ngv>`&%- zyjuKbXd0FiZ3DYcO-KO~klT|mpbH|KUE{Xu&=!kGKS^1jO2 z=6nvfUz72}@^Tgy$^D+6{mNP$pjGVgi|`b`6@8i9S;jBN>1!=%&+sS9-Nk=~p7+Zl zc50+1_B+n`NzeP3z7uYgmwD$bj)lAv6VLoNil@hoWH*9gm7;(5#r+`9eWCa%*beLX zhds-y)#4qq6*rf6TrB(1hW|ExQQit1J>yzbi?g|j{v-QR+OP06tc5o`Th8)C*BXnn z*E{ID_&HbaitlhHKccO1jqF=XvGk~I;(1U0X)SKWLuf<9=0o;n96SV_Tq8YXUd_JA zTY}Gdmea=1p=WQ-gU0;v_*>WOCogN7K9av9X3eVX0gyBC1@x{KXTF121LHpDKL?+( zZ=hxF5AbKRtN48S^R$t8E-a%h7aN9K@Uvz=;+?eSv^lVueU|<`yE!{G)kb~gtX(X3 zpt+?E(u>khJ{7--z729-8q%^JS-<@5$~oV!*1p8EXj$vbBei%AlD9;D&cH7Jp52ta zmo^A;-g@Dkw8`wB@hvb*-Z%WzT54_>PHo(ePr(plrt`mb-`Dwjrkt;w?C9nf0FuudzGgoTJrZ%j(P@?~t3{!d=CV zLI-(ImSWaKy4CSo163;i)K=x5~?wg))wVcc=HSn-Ce}KKF zVx94m@R2ds;;i?|YH{DZbI+&kz&USMuv2fjPER>`*6i8q(1X4TXARbr^ZUyr-XxsLrS;I1U*`FD;;HgW_$ZzRAHpVMp0dwd z>Di0D)xV#OpT2Wdt@!<(Gd+xdFLYveG;RPp>z&$qn3kGM9rvJp06Q`tIOG25BdLq5 zPv-rB+%0mB!e;g#a?+bl$nPMho1DJ#8?*DC%ek7%ZvnNFa8a1 z2km-!^IUHny(4=QPVIgwmL8jT!8Y0=c53-twRi{C+_RI|Becw8HGKlU684+t9k@?T z55*(!8F~HLkFv9;_u+-Kq4+xfcH=v-4b|>|g^=|Lp>1!_-vkSVh^Gr93%@LcA(~Eo3 zr$Y8Nf190!Q*&9no9LsB$vgNoZVfBS#lHT9zr3=j#oqFAjokYpob#T#x`clo|3|s+ z^1s1zY3WxvYctv1an41~de(oaIby~HB_w{a^@70{A*Kk*37s|=+^P6EftPuM* zyBQoT7kz0gEobyY+8Xv?`U!h}Av-lNk^UWYXJ^eD(4XV4XJ5Ok3v!A40Aa)Jy2tRA`GA;GE%)0F|r@Ys)j#uFu@mSa+FYD1AR>}PhKThB5z1tI~ z9XjX98>mbrI%X=^1rO(T|TK;VQezARW2J*MiSF%%oPqB0VW#U=G z{r(P-=lP1CdOigWVT#xWYm^@TvTLVT=jvq$+KLSyz&>s1KX({+Eo~$}@2I`})$-=S#r&L$)JdMPCNFEU6py5B zpv{KV^v(QZ_$A01HN=_YNBAAwPP~qucg8d2;tZxf@2vCt0KYA75afN)iQm!OvImbC zy9MVt^DQ`xy@{Rkdk;>}n_Vluu@hg;9tGXS^LwWZdCv6ieHG`9{?nS`EH0_)&-~QH zboMcMfATYzKD77P-?P&n`a*i+4E_eFz>EBR3#`Di`8gY3)Qb1al(Kqo@2pAozArm{ zsyDy4G3h%8aOyI1OJ8|Y&L;X?Ib-Mcuw6AFyKac$?dprEZ9}C&vi7=V{Ieh}ocj$xV z;(NS_*bM$G*E=qD6Eqgfb3ZDUeli8}?487RuopqD+pb)^FPgDmq2+JVO`!pOn%vp; zJ@d@G57WM)En)9wA7n3LkAgrj1AK+fGgkiPLw;+5iC<2;EzYjVYlXH;2ZQdsD zD)t(@fxpq1x$K;c2k3L<9$}Ajowm3G&U3G%=l|z+C4Ygu)b}&cmA_Ve1mtYh;8EAT zUrx^UcvuQI@zaO$j%;B}H#yIqmbulUMc>@(U-lL-2&>Nx3P=mwx&179caf~XEi_l zBj@M@zX9$KdG@>`a-FPk`hEk*8DCv4_I-t%Ns#$nYV2j=sg-fKE8Ys(hm&=lNzO`k zo@=e#T15>(fBJj!awc*HGrv54Q#q~0hqJR+Z{W<~3tWX0>@hg|^g&twbdOK*1Y@$k zsh!NHxA-7>TQxI4>?VHJ@?x=i_ZY@ME1p_SE#}{SKNH`~?pQ5;W8J`RNW0XS)b(WJ zwz3zn7mLji8;@JidcY`t`bhRZbunJvwK!+#2;2=_Am{&OIoGf=t`+}H{^PhK>^1HI z`oGy7*q7oV{7&#a|26YCN=xs*6E39phr8J6(cgHMtVQ(f_ZO~R!q32b zYQ-KrN$=$p#h}kI~v!I{kfdAcsstF{-JB#$6g4z-UhKn@DY7P+4~hn^2b4A_KaHbu04(C zc=q&xTrYh+|G&M_uG5y^UhW+BMfCKMWzlJGi7kcv{qU4}c>?bj>x-{9_o4FZ;YoJR zR?h9!{A1#Ik3AuG0{dN@HOrjugEp{De97N->SRKl{vdC={1IZ`!t#nb=kKHUXOD&F z_~EWKy^DRDakj0Cwal5G2zkCB&$p4DzgMrTs5QtQof4a6Tnn-E zvgEd6-yqf#zg{iAOK;@o9l8M?VZTSudiS$N@3SWu_cA@-bg%I57F)n?4Fklo=HJkU z`CXCUeZN2(div5k)-unQv-GRE^ujG6XJi$gLhB@NJ-);5!pyBKmVLObTs;3L@^+fb zY}ftW_1f^KiEo03AU$Zc@!6YkVvQ@sTDNAuWnB8wI(B+Y>Mm<~wp`5d1KKqCncH<@ z-_d`E2kNUw3cp4bFRHa=DmC{x0$kJL}wp-B~QplYWst^@=gg z*qP7paxwM=TCUfIe;*#oZz%o%J8L}NbxyN0CUudz&oiZ_=JPYJvvqnCzY8qor>|#k z@8joNszc2_^Lyg-puCH-ukC1uU`5$EqSfK-S$f0-$amx*d0Dd_@L09zUz5dV@!wTn zHFo|TCFgaLczVqkoM%W)+8-xhC$r;R%+{{Vha{3hC6;>@aqIV|-NmzwnkE AQ~&?~ literal 0 HcmV?d00001 diff --git a/examples/performance/dflash/noise_embedding/Qwen3-4B_noise_embeds.npy b/examples/performance/dflash/noise_embedding/Qwen3-4B_noise_embeds.npy new file mode 100755 index 0000000000000000000000000000000000000000..31c70ee77c37848870e813fadb1f72d0186758ba GIT binary patch literal 10368 zcmbW-iT{>U{s!>7$WqACMApiZNQ=_)eToP%gtC(?31f>g1`U-;h)U6>lqMupNUG4Fj%s%~l z_iIwsut~?W+B9j_u*q3{`(4uS^b1ez+waUi@gIEp`ThIw_dnM&}8Mf+F6#v9~RTssR{6Ek$=^YjqMP2$ZzH8Yn*tg=2 zd;_bBVkCcSZM*Qz;k%db0DK4k12}!%B356{0{jfrE*HfH9g1R+?B0j@?e~x9pkaqUh`$4&kdQHlE!sT6`n+hHx8u8s4Yc^YwQhdmDL&^Ud^L zeQ~8VWLzbjG4zJo&;-U;>)#lv@cQh9`W`0t9JrS5&v!lk z8}x*2VSK454&*;on~glLK3^?5bH6SBrf?B`Apd;1TP-Y#$LY16v#Yt;T5LJJnY@ho ze6gc&D?D0%w`j8xHj;CuobBkm(<5cI0_(|b$KMOGUR~hEQc+BVG4LB$&x5SrOl{WZJ7Ymn+(2K8&k);!?^E_*`dH(+n0;TlDE7y{;2HeS;|thb zqik;&00vK)=TSG=FMjHUCTadF?(E%e=o#Z)wiolK%}~_D#lD zPn+B6Ik1ZT6|{!BupDkEsWG?%ma!kl7wYS1>$o3Y!u}doie-KehZ>&I)fk$KXJ0IW zcIC=BF~jpFi}zt?ji-x^VecDeKf(VUdlzHvg{Q!$);l$Hm3UwAjq$teN&IiKH>fU( zKjb~6O;7d~u)0dE%R5lsZS42ix3Sk_Ph{_eFT{DjoDoZL);4=~9Bu$_z!L2q)PK%_ z?4J!_qjFJ95X+op9#3TN0>8`69PB~o{cGTU?3Wi6#gTM#xRZYd9uAr39%5;8D4xXL zgWj>)Jkr0>dG4<48O(*|u=?Smc)hGfp}9WWWMAOBS4 zP2xY6ok8rU#CF8h>~H9=+Z9D)@ti-k_4jf|^)H^iUlUK~A1ikZojsZJDD$_K_=x!7 zR_!0czvJ}rs&~JHeKD-l{?3wpMK|-jjm5KWv-l3ina>*VJA0{|hIoN~2GULWGjFfa z>%p)3eU*JPbkbM$Z+-rzkQ%#+{}p;M&N=jVXsM6a78S)o{6F$9NEyYKQ% zrVpX3;O9T%v*{6Z_S@EF?~NzJCZ2mXdyIGcLjNVc)aSPJ`TYArau37Rcr|_-FT_3Y zGoH7eKK^??(5Z>d*{Q`_@GE@(hW+K9>&$uxkA{uC{|Y$`78S)1@n-bL&`0bH?VhHm z)7e))%9$=_7_Q0xIOI%QA+|tGT#2{gdtLl~@3D?Ok3C77oF(tmAK-k0ccYJ^A6=wB zoV}CzJG4|;%gpPq{CA7z`}``-x4Z+q#Xr2G{lk|!>VmJr4``F`SR?uox|W<##&x@Q zJb~SpJya}vtcARZ{2lPaP!>N!tPM^LG?w!k^pLZd|9bdC>~_4R{9WPx?6ZDfc+!Bvq57(|cdo#KB%gcVtI%gdQ(T~YVzf<8zZ8H90^z(R)w&#hD zV&^RB1bdp}tkX#GiDDUB6P$HCmdEAKYw$-fzUsNBJL314S;&U?8e+6Yt)ka-P4(cX_EYp7(L; zcz@42h5ieSl-HB*N_sE+ptjrTV+@@&>x)lg?`_TVTjVyr7ho=5&dRzt_o;399@=AJ)J)com*2d4IV%gAZWOGKL9qC(|9p&VoE|x}5C2tW(bCtbJc?hO#s7>#?_j zoC}}wWqt;-XW-Jp%AI89%rGkG=3({sDZ;VI#i!a^`El0sSK^g0%KR=23ly7ecN8% z9{eZ6@qCNLmhs-?4+}dh(8k+%=bpzezuc^P4C2l(Q}CyrKALI^WRT zEpyJMKC_Sa@%_j*@lJWIv}r2#BAw^9lyeV080Yur_Rx;KG2ANmUOZK7J|2a~zzF+d()L$cdG0s_WjNh?O>w8nd@0R>VXD<+s@A|~G&t2;cv5o03@U}3G{s+$clwk~gg0{8t7Nxy(~)CgUpP-J_q0 zbS=CqPVOpg?kiWm*Yof~{)6zXa62r3tle-ttAw>(LD$iLC;T)#2Kk+s_sczDiayul zt1G95=iU}SeI4IA$lmMC^k+A~I~&hO zcsLB;tL=Gr%Uf#ATk@?G&-~`u@6n_1fu5P~Xlu9(@@-uNz3EXn`!K)zJ}Xzwu|JF_ z;~XrW?|1g)WU*!Va-2OgivMJ&$!PT(3@}Pl68ZAo-v<&Kb`UA z9Zo5$8@@Ym&hX?uMQ;vY8(;44=|AuJ6wW#g!h2R#zWEQzI|GJ5C;s&PEIkVU?OqA?suA48{B^P6!B`*6tl_dZ|h z>PEa9tfJ@O-{4kAO>R@J?EUui$HtTI+ogQXpbh-0%~SAYsZ!r7@d@nLd35e8JMcdt{>(ymA)MbxrSv2IwY>k{zuBp~ zc{nxLOU_jJm(WvToV>R9g>q%TX5U=mZ&0bjto0)NZ}FBeQ0|?0DS!6B$M_At?TzJn zdD&l$+1a1zGvEKQa`LY0!~0?fvycAMZsgoOimt~u8Rxx9Fb1;s-o}&QI{I$h6*h-I z%9Y=OD<^xOZuB7j?1h|1W!M6;5B7mJ`uG~ZSFW7xqv`F;MK%9s^s|tC)u~*` zuT5t!H<2?8uA|?SyD`1WxR1o`_I2bY|ou#&!=FXO!#7RcS3e+nMS zUPh;;GUqpWhyD0^nY(fLasE83cE?JcOysKtd+>cK){1|yb~!`0fx3ME5WfjNfXx3& zoPG0yc;@9SzGkr4GqR_%uA|^F`Y@cDSO*#Ni8yCR{>J(gdlsF&QG@*nUJV=aEh$yL ziP_^f8r#qKaXha#jt+hw?-~zw8@>RBVW$h zkNCRD*-!T&pdHGYM!GyVi-(lhb7iz@e@+4L=t{WJ?Q-i+-n@kw;%CUctk z9uM90xy*B_`M;FY5~qC^dPlXL-*3OudH3US_VkhT+^Wj&@~qWJ`a_)G|7~#UFJtNr z8?v|W;Ecy}*l*+KA@lNycJ1(~?0ipCTUB%!j^occnR{yL^Pl=oy?rj0JvT$!&*(vL zo3{B)@I2p*+Vq67+>gZO!wB}%?7T}$I)4kz{r+2eCENg6oBQxcSPVIn`@<@@9^QiO zs+@V{%J<+2cE;8X8qhz z6@0mG52O2tt;Rzj{iTl|`1Y%^CUo}VaD670+M9^a!Rb5uXr0_f#``_)%g&zK6tdpE zoKI(s=g?H_@=)YTh! zO{sE*r^Yf@kFoEEZ?)+vCuc@#D|L4*)D)YepB^~ht#|RxxWD`j=@E45;0kv7NqtU% z-tap6M4Y-^qK`3jYx+R^J2axl!Go}~XLiAp&E0SM`l?*PBihZN4;Rb4PL9T*Gk#5Slh ze)?UpA#$Ff-=I6gb#MWDv3~1|{UNs%-;p?LoN;_Tzj9yAx!aZRL!AA73H=>(hj;XG zFW+7C9on|XHTjnE4Tk;TS-y_AF+LX-(3e8a?g!ZG$|L52{L;4+#Gp3gzV_j4BU51>4 z&$3gaw?f|G0r|(kDE%H^uAJv>a9?Ao79R&+u~R?C$-97^`kpQS9eOU_3a_v4rsjKp z`uTEYJxll_b}Pu=DF(8q(=+k;+C{IWSLk;b?##}+k7vJtn?d~d!_M+&;}&=_{6LQ> zIVW&q&%VsrF_-?7?+Ca;PWIKq^g_s(Qr8*hZtTaZd<)qpL+0X0$QsPX8BgjWwfV4q zp5cEK65E5FddT1Z{>5I(9^u}azb&t2Z!dN?-T_j_+3z{C?u5^^$=~))GS(rb%H3%) z9>xE^-^i!14}nYcn{~XbT$zir>51~PS99L(iEEkD_3?}BUGby(z7ppv$M8?;Ayx>{bZd}^WE?#Fi>CjFHN{ScitC%J(hboCoNlqUnRDQv8|$4(2Ma8xEjx}o-^pL%*SuI24A85MBId3X!i@=o$mx(hy5D9 z3@+Dy?l77E51}iabLAU2vRt{J?j+~0?8D(s=)u3ay!_kT2c%Z6)qg#l@7gl` zWK6BJTf?3M)A&+zL!p42q4(hS`nrow+q_HHg_Yk*smJ%=WxiG7bIa+Y-Ux!x}3xzKD7__b`6gKi# zvm3Ebtt=Ggvd?6HfxGehh>vDJLc50D0?)=*bnf_|<`@n9wJ?Xjb)BKC! zcGyXuiw9N~3U9M7#>;Rcyip&S!#(WG^$0kts!;d|AG)}F{TA%p`(aoN^D10l&K25h z=PxUn2ma^mLt+PT)A+EG{R7mLdoC<=y_)oDr9z<?u#!E3)_{F2 zyQKXj{>}V}kp0q5Up?{b_#VifJmy}O%T4?b*LV&$$D>Mx!pAtd55vtcSF8av7R&vt zqvd*uuZ0$3CvdT3k7~Px-A_FGc_*!gcwPRLV(;=-)BC#jLG%&)7W8TOa-1=z{T*=U zE^C8_!}gLGu8oJ`o8)C2`}ixMi~HHe&m5twn$E@=}b{}@`G4-_BrJ5Cwlhh*ZN#b%UB=fZ($E$ z4`Dy$+HKhx?>X`gi8W!jfe(xJYGwJH3>AA{EHz~f{Tu#d_G!4aoDal4g{Ror124%* zE$zzB{`^5K`|dgV32mB*?c`5~+{-bTSyd=(mUj@g$myFr`Pa}dsIa!|>F)gk@n6^( zV+VFiZBs`U<304$pi9_WX}^p2<&Wd%J{z%9cdx`P`Qyafvm3)~_B;3!$hH227fa?$ z-nH82dabp+l%4wiAb+-TJuN4-ZV_Z%e;3PooGM;RZgoiQo53GuFP*KA8T{w?`=N{Y z5`IIRJ#>b6_I{B)OFY+Y2U}_LX{p6e7u8SK>yCHRvM%qlo4~#7UHC_NnUkLU2XSi7 zAZ@Zwrj*M2XC6EEKZ8Dqz1#RkilzND{+*CBxe?xwGaWyI58x%0<+F8j(KAwB1wZ5Z zQEVVC@f)%?vj?zS$!W#@khT_g#MNmN#k#Qf(z@#VA^v0R1~3$A?DCwGm%sI;LZLT% z7&M1mJL{VH>xL#$Jmz;mb;DiM)1|_OO2c=C@&whCFAE%Db30S1fy{ChZCN z4~P%Loy5*y|42UzUxDk`@6G7b*?ErD(KdTKwWWeKxXL+XA7rNa;(tJ| z0mtQL?jOdv{`cZ@`5)tU+8n@n4rVW1$j)5%#IL!}%jJDkDnIu>!KaI@!L`Nf!Rg{z z>v!$ZA?&~PF@oM%PDgqzcG_QqPu6Yme*EA0sa$GR^Oq}&^fa~*n z=<8bJdyD;o>zu$R|HQAXEI*e<)7lox{T`R|5@g+Pp{>Lp;`OxDtW9#8R(O`v@8f5T zci=wsz4pr#c6YcDn!s3kf1J8CQ@fn!)Tm|rJfkM*_kQ++_*8z*c6&(OnpSkis?4o6 z|6zZteQNOnNUhIaJ6k++^eX=?*Ux-kq-}Hh3Gtk#*I+VapJYv9FDg0XP)p8ic!+kd zSk8V6cE(b_N?pTqb`=V*nET#hU1=Yrt(@BIe(a}RcM1C<@gmNCUsPGH$@OsF6UK-w z!h6N$XuFI)7$-g(o}_=qKjK<-qIG2NgvR^}`KbY~LgqYc(Gq{b zKewbW?em_WeQ`E^8Rr@=;F|Jg(5_`v@*G=?b0!m8VE(iJ=lZrtd<8u< z|2Xc%|1B|gVtJHv`!adg z@N=dQ^V3J_;%t5mycF_Yo$H-N+hzaM#NWuvzFUv8FD_#bkdvAa`zP1Tz3*|&%yaJd zfLIsyICf&s$jR8JYnQV<2G8VY3@6t{-hJ%cZ+o13$iB$g&ig|4esANM=o$lY+AeWl z)y3ase+sFIdA2=A%Q}>BEu3+0Vs9;}3H)|=XjS=JF86vmTwPS-_4_${idfEQf4Mv8 zS&vg`&)_`key1hpJ9d!RnPRCWsbQJZ2K;8w118A1h}J=Dr7^#SXG8XU0WZWa;UM1( z-{I7U8uZ(=8wTy+J^poKr-frJCL9yIdU3TX9ZP&Voow_v< zZjKEFIS0gY{=Q&$h2iu!v^nfr^VlnC?O}zSzvxF`wLSKbSoTWJWY#{{UXSxV(GQQt z)o8!4KXr{u+3y$2XCe3Z|MqKr_CftM6CVgU*V&s{+tjF5a0qg(oa3CQP5cJx++5dK z&9BeTJJ7@C{aOA3s7Bjc`e*H`%ID`?+5xy-yZU0;lS}bD+6fp)&pp0F%lB6L&-+=P z^MA0{iS31)hg;=j{+^askADZA$p2P+Jl@8?gWeIccT$rZiOsJne{+wPlk=6e%HBB# zMncY8iIy{P3hoK3^|icM-q+a&vuT<4srnp_bH3gYpDr)&d#QE#dq%8eZddbffKMTH z>MHs9%N%vnCTH;7QhDw#(dT-81F_VbydUKEzv?*8_w(r4M}Nv6MthIH1hTg~(R<@* z&<=8@e-_K07=UNeQa`RJmFM_0@k#nleY{01?*t!-{Um1sKmCn%em3&=LF(9XSX?qD z$bPAp9LS!^9CpW9=f_-g4Za8J$?cB&&|0%ga+bK}VEGrYzo2zy55iseP2^5xx1;Tn z_qRFh$Icw&{5Qp^k4vG*Zz*S~HrZ3_AkVFgH}9dnwKx$aBsMwQM391_c!8AYGZ?gXg^lVQ7lelL1o(8p$W9oiPT8F%(* zV{P)znCEr+xLiEvr#&7_`xPD*pGaHCu1?Q&PSB>}?D-#Hj#v+V*6&=n+_k1Sk8Rnv zh))vFx&6i3<_zAe{ayUbQRX{yeI88JF7FIE-%sGP#quuJn_ruqYk#eewPHEzYex5f?*q@2l!6W(ED|hi%7(;4Z_FCrlPJYhC zFxbP-zU%`H`Ki}A|2dNf@MrLXc75OleVjvYF85mRs`ukh_!r8_{NFEU6|8qZ$7q+z zPtEHuwh?}$?~vb*eJ3<#=UI3yEMrfs^1X=X!X(<5V!1A5k5ATDYGCHFmss{rYQiA; zMq2c=dlK?DQtV^b%d;!*z5VE^Uzv}r;qT^eA%8{D9)r}|=3;;2)TLFhMqXEbo?{vJ z3FyVYj-U0qU3@pcQAK$Vw`8Xl=HAc3-{G9K#`ukr^^yNNUO?*v+vr!|Hng5%TaD?c zzBWS*eue8S!foJHu^Z&4&(!Ly&lvH`_*Yk!_iOsdeWYGbVqZxg&fhQhDt7u<0jc}B z$JDL-Thv7HvFx$ z&mHA=hwS0(%MJ2&@z>$rRpsxL@7Z_4EVxNdPwi44`xz`eG zE|uS(7O*cA%Y0?5c@Cw|?7#o)GzLig2yH3MEP1EEdB4j&r@w+WS-(f*^`d7S+wl>2 z2A0rA(e9C#+WZT9AT8s#PW%Ge~0} z&2wiq`#ycllXHN7G330zOn+Ycrnmz?-=e7}*^8TLTjeEZj<)&hBAzkKp=Yn8%}$&; zp8b;R&Y(@fb+tL2U5{Q3*3q)2ec9C^^WGLNVULA;w?CszYTdX>dlnjL(@ULfNNWuf z#b3gCRy6ttUl~(>+QCX^3^J!V8-Hkf4ZF6yuKY>-_uxMMa(uhkGHW_PSy^b6Uk zYe)GRckX+h+?QSBG+Ism4u0062ou?PkH|GMpQ9n`K8PO9+~&+?JySpbgEL3(;`a1C z`kT$44^#Q6JDu-V63a}R@DJ2k5%`(v>W;2L&kdTVXg@LyzSt*h}nuyfWjZ#jP@_NTO- zyUOEF9eJO>NW10g`4W7yywr=yVhbVrK6B8Fp7Gv-^IXdu_JFMMrTB_c`8^_QnCma6 zEyZWbUm&Lw|5^UYXCmH6JB0IWoeCLCYIQT(moUDve9voXXBRy?A@`knlKUDe{uOO8 zUI}&CBg8Uy*-P^v&y<|;TI?6`0I}5KS~z=SA#D(UnD}7+6Z}VDzWM%%o^PCPu9bSw z5~ud<#!YsWpW}7p=6-SpQnP1>XP#@rWAsJvf}D(hEuN#@H1?B_eLYfMzI{@gXUKhy zUW5J_WUl+rcj9(Xq)mYA`!nTcf32Wzvp;syF4j+KYOeW**jRR3_A}aa#7{vFcAn=S zS6UPKsVQ^m+25b@k81aqcn|T(@CK~~&bfbI&hV=8KKqLOqy5sD{~$YScCEHG<(&m* z@+aU=isd!0!~T?(b5e&^qsn;j4rA|(kI?JVa(?TwN8=CVXAj&6sRijHdoc4IKYO}1 zPHoI{@{oS|@vG5?;KZ}e9bu{1HMHc`=ilVMa#nI*f8rM*ex3`bLgqblle7CZp3bj} zYx47svY4HA`S-T9u!?qqKi>UhkECwooog-aF1gidGa=Xil9q4L1N^g%Bm1Qf{Z94+ z^!)9WQ;mHG{eXV1##iw(uV0C^XFuSZVISi_j9iEw_PlC`w`ejs|VfKBP+}CSNMbZsYBTVZ{Y0RpUib?>lE5voOSJo z`$KADFMdy$0wegT2iY$JGzqr?Cgye^e(h_*aLTY-phTSUqRo@&UvaqAA@s-t`lpDPr*Yg%fJ8q#6Dgs zukU(1nZ5xpfI~%ROMdD}p8vg}4!;xs4ECw)Q{)_1m<&bNQ7LT9r@|@WzaAQ$i z?0Z`N=E@mMe*hY=GuPQ)PvGYK3VQa{2>H*^s{NzCNzZp=1K3L&1s}tO&RJ)q9$OQGO51oaH;^FT6o4H7LKo<_z^`Z^1LfXW)sI<@p`J?kx9VeiJyV|03;C$hdMI zdg1*2u5G(wxt2AdrQMz4Z#$0%<>#y%kdx=y6_9oA4v(@k)|~4%wZ8!8Y*d5H{VTKr z`{db{-;teX$6xGuVx8n&vMjR>Qhx?oX9_gZNT*##^NSm-ZCDj@+f}?3*Ewvlf36 z)a74{`~VHd!-R=Al|P1Vc5=JR#l#(3-pm%^n`Pl^|?_j zd;MN^D>=*fc{b$Pcm^DxXU}K9Wqp(LGduI1-&Tq^HTZPM{_h57R+=BNPPFIQsh@Yk zM6u-NdfB(z<-BMt7xNdvX)ubOb;uk)TUGu&Gc|BJe+A_Fuk-Jdo3VDl$M}~)WBpxB z%l>%*dWr49dtm_kp{nvbLSuHuotpI+WbN~poFV+qcoOueGXL79j^!`!&3R7LupT+P zm+|xLsEZr%lfMOa^7}&8?6=ZR(gvzBYvivweSo;JuhQvXh4AE3Q2 zmUA|dzrpx~yz6n!P4-&8C2Ptpuv1U#Lh}9#9ZPDa{CQ$?*aK-< zw})_T-{z07uhDjucxprDaS?wjo&#NIr||E96LR~|vMzPBKg!R#JTIppo_)H?cssfN zcsaS=lj7CsX}gp?0Plc&JKZcN_nH0q3!aO!RwLjW*d^YtsJ6j9V$Eq;himZ5w9lZr ze!k=X+@rk5kMcAB`8R~rluhhE!nWTEOnNj+PNR{&h)i@n&7sT#(>|a z=e}|lZ^6CcP3;!&7n;|-cn?m!snAyo{IU2m+8w~zS?mCp%ipA3 zo~uXXEoQHy?U(lzzb-xFUk|CBPvh~l8gLpr@yzQC{JFOC*#p>H*vIvAn4MbmQ^~#K zpXpm;!-sM&qrHv4fSl<9&Um-tlW|&6PWI>Dw9WWu7(vhYuVBxjAB8+m@=nx>J=L6K zPpsqT+^;H?$C~q=^VX7fKCL!nj@Ho5fsgoE!?7@wy%^>d%m3c7nYNAJkv|UFYnT7t zC3EpUKkt1@*%{Xev7xjJVKYBgGd^oN3rA}M-|1sFrI6dRb^C#<=Ip~73mf1U5 zgSEIJ|0j0pYx-CPncwR4x+Uk-Sh7zSv8T`$v(Luavzf>2lg9X8^4^eh9s4vWRQiUL zdp+%IcoNn^`pEg633*mO44W&RjbeGPoX*dj-NV0AejWB)$i7HjYQVk*dWz?L;2k_% zEZ+gShi)*MKTPZ<$nQf}8^xnfgj_tE~t&c83^z3P5`{xbguT=N9%re&V$YM1-{nmq}Q!TWG6d?`QAzY9v{ zliig)0PYietjhPh=iCDJYVpaq8SMjh#@+<#%RL1?qUSv!>vkIqr9D|RFT2z-W4v0t z4SXYZBhKF2DxRA28=hEct=U82PWEQncu37#hPTnPpZnnC*2fk249M9UET^ydH0S|C zUAH$Z5Wi0DG4^NBoxgxTMg9`p5igXV{jp0w(c7{!rgrp<^{tBXeNU%-&CYLl^Zgc> zI+5{Z{c2R2XYnnrF`nH37D4vR?&wvXyV@MnCVL@u_$!?Kvq<|@_y_*?>{Hk;$SYKq z|E(o^AbWK{n3i>UfuD1ceYYP^hSc{5a1Zd8kqYj6R36Z-_aCLEHV=g1Cr>fq(#7vj|F1^7Gqe~abk<|VXyU@PrPedn4LkbUtJ z&b?=U&*o>}T*e-&kNo$5C;z>QcDw7{njC2Ax|hIe_7HhF3wi(OgB!su^z4maan5{a zcD|=p@$;-&%HPL+4!W5C)T43iH((50sjnAto|)lhr81|_eXtT1!^Qfp&|l6+{_T1| zv0QiO^6xPA2gQ1L&&+z~9Ly6>ja=Y*r{g@Y8oO53`*waU`egVVuAq4X5sOV`trR7V`W!$WLFHtE0c z3#Ib!Hl4)Vy7njk_%p82C>ojcY6NY`PJ;y+-3af^3IhvmNrx@=VX<2 z9Z&z6-%s9T_6PLryUqA){==}lhrWHs6zO&Ncf&04!;tfp=jc_iU(L&Yxu2fe*qD~H zIsyMm%UWj)=h52YVbD0a^pVhkJ&Cqfd^`JT_T%iX?4IK5*nMf|v8TYj{NZ9XXgNoX zAbro^AEBquN8DHDYC63xJ_irNkK%K1+Rdr3{w4dBo%_z%53`R$_IA$SO|-0WJ-M&r z9{gYETky$oRXH=Rb%J(<_-0&>|0Z4{mS;fr@)mP_BwDHb+dw1SN*}2~`KIbBmUVc{ zb$a8IYp0({?Cjegq1Lid=KA+vo~_aa<;F6>hy0Qd%i6{^R!o9S9aFAvwm0M+~@5O zZ6D%b*Fn}L z>$?~>%WsE&mh-LL#$vM}{pP!*K%W9TjrTbI9;Y@u1=*jOkBP3QZRzn*yBooj^K=8O(XM~d a{K|VCzr+7@r{6{Rx#ypEmEXN*RsA2Ns`=pn literal 0 HcmV?d00001 diff --git a/examples/performance/dflash/noise_embedding/gpt-oss-20b_noise_embeds.npy b/examples/performance/dflash/noise_embedding/gpt-oss-20b_noise_embeds.npy new file mode 100755 index 0000000000000000000000000000000000000000..3c0852ffc0223181e7fbba2b14ca5cbed53dd307 GIT binary patch literal 11648 zcmbW-iQCs>+6M4nQCbv5N}@%il%=fI^E|CG5eA7DZ&D(%l|4(lw5yb&EZG`M$r5^& zYQ#_?Gug`0j3w)2Uxw!WoX77!cpb-iT-SA8%YA?E{r6<{N&Wksa#pp%HHA@ahFmam z@Q5~@TeRsq^oTa?TeKN^`G~7V47zOKNBD=uIkba?v~9gjKY zu(s`6+|=U#_g|etVSZ(y@N8+J@Hx4hY|^bz_;gjFaA`%M@JvaeFqz)FtWc<;Hz3dE z9}heD|1AD$^7ryW;cr!i!hmD>#VioJne5xOPjsP1d@Ida}3TZ@@M4wth6w}8Gwt(Rfw z3xz`4%0i(t%!9FFc2*eG2pSNnBw`-rQ>K9X-GJ8geSM&I7We5@=K zPUu`H+@`%NwSNfz814N3?C1D4k&WeKeCxtCI0L@`Ucld}ofcwFlrxd9+KNJ9 z5BzERR<71c{7m{6V&0)=jqDGp`;PcWp&9*EawB;n%ocwaS*Nm47>2(ZrX{b!_uzK& zjp^@y@%7AoPkcdJkCS8hn^zSIhlwA^KZ$)hKI`&nG0V;U!*Yk)si#?eu06mu)zv5+-&GyYoqoAgI;U#QiKT#MU>uNwId-`nc$B==fe z!S@UxzZW-%>>E2I^r6NxgTE*W@>j0 z`@iA`^Y0?IH~xRfvA7fYvJbz4y6R=mWb7X%KP0mrcNM#hd>J=XAAcbasw%F_h32_J zy`#lUDse_ZeRc1)eqOW}9^~5{-qVl6)u=7*KytFUX6jtS9?oAAI{Su$$uayJ5~Hnl zYQ02XF8=RgC(!4~`#}4n*q`F3(yME4IDM)7diXE3-A5)zB;VCt8e@%t)a(5S(`LBVmVBaAo z`(;1=!7vtXz`rJDKXsdmTOucG=tlTmyXTR6TW90MJ;ir#X>rdCA}?lVZLKlCpUOR5 zou&9TVtbO=Pu<}?agSFTL;7p%w9q{4fOtgfv=WCOm2 zWEXsWZKUSYYM)DQgh0mz5ONq2(p%(-g2`h@@D$es^Z=`2R~6e z`-&S%PWO$ssxd{}mwcJCoS`>&)hBr?U>&}L`rq(9S5hc^OwLm;_g`<{If}kS%-!<3 zi_f0iP265;=ZroOpL1!e_AR7=^|LX5`k!-g92_TaCE1y8viJ_#{h7@B z?jU^i(&D@NT;DYce;DNKe~E8QmGfh@KH(0<)upc%UxvR7SKmC=Syg;H4aPsEW`1kt zyw1AXh|9j2BDO=hyAEFq_YFO7f!_2R)Xn*t->28J&lkU)J^(JJ=T7*HY|ZzC+AUy) zoXk~CsL$TO?ki>;Os{kgm-q(XSphTkG29})oMEl?w>LRY?(yQ<^1a0OxIVRDcc4$k zb>Pq4*oplZ`7zvES)9w|G<@uOx*7H?1_BO z4)%5QJ=MyY^Q->0)VFK!Ti6@q9?d`5xOXFGle@~>if^ROOZ1Lity%He^ZAYOXTC{% z2e9kGKwSE=T>dC>t2T0PH-qokM;O1?<=+mQ=xfz`PHad1?{WRf3&i!H|Cem4f2ZRv zW#21)cgVS!_TO|4#XbY}Qez+dljJ94YHpUZvds7CW7d5IK5IVbYTBDi<}B@qn%kPh&a-WcQF@G~$yYM}{H^tl!HTjNIqXC>npCf;)yvJdtv3N%ON^#TV zXP!E@Do9nZu1|UG^gOB6YHVpNF^T z>&2gmkA0q)SIzS*c5`+g?frzCFYaJ+9PT~H`lzqX=gF@4{`^@RSK_9~{|f)TxU9R8 zP}7^MhPeE8zD-{?tulw=K2?7z+y@=ixSKy~ZG_mozYm1%(1bs4t(=uR@#hV=f&LIZ z@64>Jc4W@C)6}?I{TswQ4o|}_>gCM8Ltj?V_lK;7XXMYsE#(~}?ga7K&qw2SVc&-TsB7`gt|_O6I!pL26VpfRH|(BbvUZ2U6Jqb; zy9wI!-vMc_xmr`@-ESUV6@Qjm*{9F52jcTv?KJs^$p4J|QBL;x)qE$2&HDP0J%T+R zPOLD_xUp)q7Jqy<>qDD6%ksN#CFkS&PGP(_x(4VffwXXNY-@JyPyNWySmY3%)MmcjeETo(1Dz zXZteuLO=DsCyx=caHakFLb3K$xWl2(>f&4FGWj{1=d+ure>u#M^M{;Uj6=mL<3p|& zbBaFnmYcEP2cE?@mcKvb9rhqu2E8iGNoBG9PU0s(2l?Bzv5UR)V0poXTs)^<6t`Ae zU$Zms`Hen7?0?An_?{@SugTo=8|5v7cJu}O)$}dz#c}-EJFj#%hH{(WKf*P_wTcfn zmOh@mgngWxjO#eao8k|Be+a)u%;moENZdNUjNxuo_6FZ=YOWLa2fi&k=T+V~OCjHz zGw@GjL-N~F@1*CoL*@)!LS~+SC+~rr?K?rv=wVPLuN`?coMVgE|Dl{(<~i^4W%QgWxx0>#KOJ8yx$Md0 z!~B<%%^>fLTI^-?#if4Fkh_wZ-8?FjIf#;a(MUKUr>0FA;Ye zU+%5%@%!Q*#Lc7sN!}Iotk>oI$BNrbUt;|JM!%15D1N@!|7AY|59r(KZpK<$r|{=4 z?1}#oB4>3iuG6J_bJ-uOccGZm$s@{)JAJgV$s9h;e_N$D17v^i#y$j}zHb)uEqMWZ zIQf=(Js{`kp|D!5->s{W`m?9JY2rS?pQqOA?8Y#M-cy}*`a4H|yZWxz#B}6auhxa+ z<*-riWMlOQxxc#U|2XwV;4AUJ!aV)SJvkj$n>-!26?Vd}lT#1+vB&GfHDYFx zGs&&?XJ7jLtNd+1-HcCumvrP?Opb(qR21Jk^KrHKvlm;!1nX=FdA69%Vy0?0e~0Ld zJDI%)*<1V)-*60lUwD{5=frIK2J&Gt=jhJ#ylw7f&w~6L${2NXZnR@J2w&>{!eyQ^ z9tYt@$jP}iSlmy>dl&7UMn4}n4@R-CD|L?JYRJvnA4M+iY8{Xn;~9K|jBRys2ji-= z-CNzIWcru;Zmall_w)`&R*a6930NHn|A?M!h6~*5~%lTW-%kb0b+v$U~b2xo5J^!wkzxB>$SHdm0rsnh% zdV9!x9Z2RreSw~_enI;$;xY#f*pJZ<^S$|7!jua2#pQ46HR%2M&(P=GtIfoHM~=f^ zDmUMLs<^yU9uZdqCW(EC>qYxT+O>xb3)1b-O22`pSytW{IYdiF_h0evTJ{8O!8$@$iI)@9E8yw`t} za~|0T@_y)2S$qo=*sbZ4dd|JLt=)XJXR#-Fq3P5#$$ACPCOR}X)HarlOO z0lt-UrMQdfSuX{?J8?tk$N7#WeEC~Qb#{4a@i)(1*2PHnu`rl_hPb?I7V&q7ciAuT zmC1XWyu3m_t~+^fS@GVUtIz$}8HY)7vj&EXnT0E-cb3jNik)t%KOx**TBWugUCA;(rr6j^5E* z`WOD~^!F=@=j@-w^~F6S<|y1x^q$>~k=ncB#`@kP)yw+2hd+0~&5--&L-s$&yR`9V z{;bn3^54V1M$g~gPp!1paQl%Ni%aP(wK>^XzL^|+_T?C8DKGo_YjVBzvyR^p|AFxy zgIgjebM-ZQ1-S=g-2Q28R7*eUd2?=nm-sr8FT$g6y*9^HI$!WvA75&FI=M|djZ5^! z_&tex6b9p$LH2U?-)cEek`3WdeaL!lj+>?*w~`wndpp0w8}VfwHI_RJHqmcT^Zb(H zZ@s_L^Sd;EJKc$Xni_BE`}^YR<1_!`;WBnpXiHyC9!O3pbB<^~=gM5j-+3w^<9a6K z9lNJ?nv2Qaxl9|mXNI#AUjwT14<)Y?yHvmTVCTF0>ie@|Z-fQ#8f@dso8%7bYo&hH zlGEMVsExm!K1ocz<#Kjg`u~!{@z-iM=SSwIdxhHMNqjjobJvb%&k!?!-mXjW+|KXZ z%=@|IZv2^>yzi#+zo}+JW0X5-n3$aZv*_FTpH^oXJ9ox0rTQrDCvtyr+5f-D&%Rux z?Yq_73!mTJE#&8ZTZPLS><0PULHN<7+E%X-nSG!0IpdghGzhjr-tvbyYgeiDu9}}$ z;mMwSQ_0*-Ewz0eS;zRz#$BziZ;J#g({U*c-(4*Va;8-b|0-v)^|mbKZ1SJ96bvv)Ut(<4tMk5D7`?{fK9(LbbTKb%13y__}{sCkyyoPpcb z%)jeYtFRB@Ay}$U*YnMz-y{AT@~JZ8qfRw-b4TW_b0s_TTtzNZFK^!A@}{V{JNr1k zQe5_R?%=P~Xe~c;^osoKt6qFP$Se4&)5pNc`uP@_J-r5hlUg;%tk<{6G5Xs|PTsM{ z;aZ99P4ybWztVoAXDwF~d#$`fE38Lyw)~!abHt6{AHv@mKZk!Hdx|=(`JN>| z7rP7pP#Dj*6u#oCBX+a)+u<+7tzoaGXMaCbS^VwPt=!q9#`EGQ@~6&oeCOh_j`w5F zrTwUHmA>*^<4M_eS2hxpRK-jcderV!y(k3G?MXh07Yg*|?1> z^)~}(q(;ua+|LcfWuND}3swFtfG_iucikbzeqa0$dhW#h?Podv@JjQGTP*G#{L|KB z_V<1CEnTfG`px!J#`;9|#rUhVUyW>~X6ASveHwWPShcF zYrCGfJNZ&$1>`>Z9hdRD4D#mNN&LlpIr|&Q*;BiTIhg-`e9pA_@w3PWe677J^9Jk#Gx(~so%zZ*H;}it z`nQVBIXi@%JN{TWkN+I;Tk&=2ts&=f&iix49SuM5Pb5olS?gz&IorrZ_}<1P>w77^ zjK4iu1s{~S#~^F(2{<0I*Ei|c9Tm>S#P%qj3v=i>le3N+iTQ}!T_3VW^Uld{m0z_n z!?$(Q-htL&YiCCG%jx|2{S&<@|C{2n-e#(~mTw_UhuheXX!~cmKk+{*CueHLejI torch.Tensor: + offset = 1 + selected_states = [] + for layer_id in layer_ids: + selected_states.append(hidden_states[layer_id + offset]) + target_hidden = torch.cat(selected_states, dim=-1) + return target_hidden + + +def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor: + if temperature < 1e-5: + return torch.argmax(logits, dim=-1) + bsz, seq_len, vocab_size = logits.shape + logits = logits.view(-1, vocab_size) + logits = logits / temperature + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) + + +def load_and_process_dataset(data_name: str): + # Math datasets + if data_name == "gsm8k": + dataset = load_dataset("openai/gsm8k", "main", split="test") + prompt_fmt = "{question}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "math500": + dataset = load_dataset("HuggingFaceH4/MATH-500", split="test") + prompt_fmt = "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "aime24": + dataset = load_dataset("HuggingFaceH4/aime_2024", split="train") + prompt_fmt = "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "aime25": + dataset = load_dataset("MathArena/aime_2025", split="train") + prompt_fmt = "{problem}\nPlease reason step by step, and put your final answer within \\boxed{{}}." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + # Chat datasets + elif data_name == "alpaca": + dataset = load_dataset("tatsu-lab/alpaca", split="train") + dataset = dataset.map( + lambda x: { + "formatted_input": (f"{x['instruction']}\n\nInput:\n{x['input']}" if x["input"] else x["instruction"]) + } + ) + dataset = dataset.map(lambda x: {"turns": [x["formatted_input"]]}) + + elif data_name == "mt-bench": + dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train") + dataset = dataset.map(lambda x: {"turns": x["prompt"]}) + + # Coding datasets + elif data_name == "humaneval": + dataset = load_dataset("openai/openai_humaneval", split="test") + prompt_fmt = "Write a solution to the following problem and make sure that it passes the tests:\n```python\n{prompt}\n```" + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "mbpp": + dataset = load_dataset("google-research-datasets/mbpp", "sanitized", split="test") + dataset = dataset.map(lambda x: {"turns": [x["prompt"]]}) + + elif data_name == "lbpp": + LBPP_PY_TEST_URL = "https://huggingface.co/datasets/CohereLabs/lbpp/resolve/main/python/test.parquet" + dataset = load_dataset("parquet", data_files={"test": LBPP_PY_TEST_URL})["test"] + dataset = dataset.map(lambda x: {"turns": [x["instruction"]]}) + + elif data_name == "swe-bench": + dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test") + prompt_fmt = "Problem Statement:\n{problem_statement}\nPlease fix the issue described above." + dataset = dataset.map(lambda x: {"turns": [prompt_fmt.format(**x)]}) + + elif data_name == "livecodebench": + base = "https://huggingface.co/datasets/livecodebench/code_generation_lite/resolve/main/" + allowed_files = ["test.jsonl", "test2.jsonl", "test3.jsonl", "test4.jsonl", "test5.jsonl", "test6.jsonl"] + urls = [base + fn for fn in allowed_files] + dataset = load_dataset("json", data_files={"test": urls})["test"] + + def format_lcb(doc): + system_prompt = ( + "You are an expert Python programmer. You will be given a question (problem specification) " + "and will generate a correct Python program that matches the specification and passes all tests. " + "You will NOT return anything except for the program" + ) + question_block = f"### Question:\n{doc['question_content']}" + if doc.get("starter_code"): + format_message = "### Format: Use the following code structure:" + code_block = f"```python\n{doc['starter_code']}\n```" + else: + format_message = "### Format: Write your code in the following format:" + code_block = "```python\n# YOUR CODE HERE\n```" + answer_footer = "### Answer: (use the provided format with backticks)" + return f"{system_prompt}\n\n{question_block}\n\n{format_message}\n{code_block}\n\n{answer_footer}" + + target_features = Features({"turns": Sequence(Value("large_string"))}) + dataset = dataset.map( + lambda x: {"turns": [format_lcb(x)]}, remove_columns=dataset.column_names, features=target_features + ) + + return dataset + + +_DEFAULT_FMT = "{prompt}\nPlease reason step by step, and put your final answer within \\boxed{{}}." +_CODING_FMT = ( + "Write a solution to the following problem and make sure that it passes the tests:\n```python\n{prompt}\n```" +) + +_CATEGORY_FMT = { + "math": _DEFAULT_FMT, + "math_reasoning": _DEFAULT_FMT, + "coding": _CODING_FMT, + "reasoning": _DEFAULT_FMT, + "stem": _DEFAULT_FMT, + "qa": _DEFAULT_FMT, + "rag": _DEFAULT_FMT, + "extraction": _DEFAULT_FMT, + "humanities": _DEFAULT_FMT, + "writing": _DEFAULT_FMT, + "summarization": _DEFAULT_FMT, + "translation": _DEFAULT_FMT, + "roleplay": _DEFAULT_FMT, +} + + +def format_prompt(prompt: str, category: str = "") -> str: + fmt = _CATEGORY_FMT.get(category, _DEFAULT_FMT) + return fmt.format(prompt=prompt) + + +def reformat_jsonl_by_category(questions: list) -> list: + """Apply instruction prefix to JSONL questions based on their category. + + Mirrors the prompt_fmt logic in load_and_process_dataset: for categories + where a specific instruction is obvious (math, coding) a tailored prefix is + used; for all others the same step-by-step reasoning instruction is applied. + """ + for q in questions: + category = q.get("category", "") + q["turns"][0] = format_prompt(q["turns"][0], category) + return questions + + +_TARGET_ABSMAX = 128.0 + + +def print_stats(x, name: str) -> None: + if isinstance(x, torch.Tensor): + x_np = x.detach().cpu().to(torch.float32).numpy() + elif isinstance(x, np.ndarray): + x_np = x.astype(np.float32) + else: + raise TypeError("Input must be a torch.Tensor or numpy.ndarray") + print(f"[STATS] {name}") + print(f" Shape : {x_np.shape}") + print(f" Min : {x_np.min():.6f}") + print(f" Max : {x_np.max():.6f}") + print(f" Mean : {x_np.mean():.6f}") + print(f" Median: {np.median(x_np):.6f}") + print(f" Std : {x_np.std():.6f}") + + +def load_dflash_checkpoint(dflash_model_path: str) -> tuple[dict, dict]: + """Download and load the DFlash safetensors checkpoint and config. + + Returns + ------- + state_dict : dict[str, Tensor] — all tensors in fp32 + cfg : dict — parsed config.json + """ + bin_path = hf_hub_download(repo_id=dflash_model_path, filename="model.safetensors") + config_path = hf_hub_download(repo_id=dflash_model_path, filename="config.json") + + with open(config_path, "r") as f: + cfg = json.load(f) + + state_dict = {} + with safe_open(bin_path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(torch.float32) + + return state_dict, cfg + + +def extract_lm_head(model: AutoModelForCausalLM) -> tuple[torch.Tensor, torch.Tensor | None]: + """Return (lm_head_weight, lm_head_bias) from a HuggingFace causal LM (fp32).""" + sd = model.state_dict() + weight = sd["lm_head.weight"].to(torch.float32) + bias = sd.get("lm_head.bias") + if bias is not None: + bias = bias.to(torch.float32) + return weight, bias + + +def build_tlm_model( + base_model: AutoModelForCausalLM, + dflash_state_dict: dict, + target_layer_ids: list[int], + target_absmax: float = _TARGET_ABSMAX, +) -> AutoModelForCausalLM: + """Attach fc + hidden_norm to *base_model*, inject DFlash weights, and scale fc. + + Modifies *base_model* in-place and returns it. + """ + inner = base_model.model + hidden_size = base_model.config.hidden_size + model_type = getattr(base_model.config, "model_type", "") + n = len(target_layer_ids) + + # Add fc and hidden_norm + inner.fc = nn.Linear(n * hidden_size, hidden_size, bias=False) + + if "qwen3" in model_type: + from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm + + inner.hidden_norm = Qwen3RMSNorm(hidden_size, eps=base_model.config.rms_norm_eps) + elif "llama" in model_type: + from transformers.models.llama.modeling_llama import LlamaRMSNorm + + inner.hidden_norm = LlamaRMSNorm(hidden_size, eps=base_model.config.rms_norm_eps) + else: + warnings.warn(f"Unknown model_type '{model_type}'; using nn.RMSNorm for hidden_norm.") + inner.hidden_norm = nn.RMSNorm(hidden_size, eps=getattr(base_model.config, "rms_norm_eps", 1e-6)) + + # Inject weights from DFlash checkpoint + fc_tensor = dflash_state_dict["fc.weight"].to(torch.float32) + inner.fc.weight.data.copy_(fc_tensor) + + hn_tensor = dflash_state_dict["hidden_norm.weight"].to(torch.float32) + inner.hidden_norm.weight.data.copy_(hn_tensor) + + # Scale fc weights so activations stay within fp16 range + # RMSNorm(x/s) == RMSNorm(x), so this is zero-accuracy-cost + with torch.no_grad(): + in_feat = inner.fc.in_features + max_row_norm = inner.fc.weight.data.norm(dim=1).max().item() + fc_out_bound = (in_feat**0.5) * max_row_norm + s = max(fc_out_bound / target_absmax, 1.0) + inner.fc.weight.data.div_(s) + print( + f"[TLM] fc scale: in_features={in_feat}, max_row_norm={max_row_norm:.4f}, " + f"fc_out_bound={fc_out_bound:.2f}, s={s:.6f}" + ) + + print(f"[TLM] fc ({n * hidden_size} -> {hidden_size}) and hidden_norm attached and scaled") + return base_model + + +def build_dlm_model( + dflash_model_path: str, + lm_head_weight: torch.Tensor, + lm_head_bias: torch.Tensor | None = None, +) -> AutoModelForCausalLM: + """Load the DFlash model and inject lm_head weights from the base TLM model. + + Also removes fc / hidden_norm if the DFlash checkpoint has them. + """ + dlm_model = AutoModelForCausalLM.from_pretrained(dflash_model_path, torch_dtype=torch.float32) + + with torch.no_grad(): + dlm_model.lm_head.weight.copy_(lm_head_weight) + + if lm_head_bias is not None: + if dlm_model.lm_head.bias is None: + dlm_model.lm_head.bias = nn.Parameter(lm_head_bias) + else: + with torch.no_grad(): + dlm_model.lm_head.bias.copy_(lm_head_bias) + + # DFlash checkpoints occasionally carry fc / hidden_norm — strip them + for attr in ("fc", "hidden_norm"): + if hasattr(dlm_model, attr): + delattr(dlm_model, attr) + print(f"[DLM] Removed dlm_model.{attr}") + + print(f"[DLM] lm_head injected (shape: {lm_head_weight.shape})") + return dlm_model From 9a477e87d1f05c0ef28eaaf91592cd58ae3923bb Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Fri, 22 May 2026 13:34:33 -0700 Subject: [PATCH 2/5] DFlash: add embed_tokens to DLM to remove noise embedding requirement Co-authored-by: Vahid Janfaza Co-authored-by: fannanya Signed-off-by: Vahid Janfaza Signed-off-by: Vahid Janfaza --- .../models/llama/modeling_llama.py | 10 +- .../transformers/models/modeling_auto.py | 9 +- .../models/qwen3/modeling_qwen3.py | 10 +- .../qwen3/modeling_qwen3_dflash_draft.py | 3 + .../performance/dflash/basic_inference.py | 143 ++---- examples/performance/dflash/benchmark.py | 232 +--------- examples/performance/dflash/dbg.log | 0 .../dflash/dflash_spd_benchmark.py | 30 +- .../dflash/dflash_spd_single_prompt.py | 35 +- examples/performance/dflash/make_models.py | 32 +- .../Llama-3.1-8B-Instruct_noise_embeds.npy | Bin 16512 -> 0 bytes .../noise_embedding/Qwen3-4B_noise_embeds.npy | Bin 10368 -> 0 bytes .../noise_embedding/Qwen3-8B_noise_embeds.npy | Bin 16512 -> 0 bytes .../gpt-oss-20b_noise_embeds.npy | Bin 11648 -> 0 bytes .../results-Qwen3-4B/humaneval_per_sample.csv | 165 ------- .../dflash/results-Qwen3-4B/summary.csv | 2 - examples/performance/dflash/utils.py | 406 ++++++++++++------ 17 files changed, 394 insertions(+), 683 deletions(-) delete mode 100644 examples/performance/dflash/dbg.log delete mode 100755 examples/performance/dflash/noise_embedding/Llama-3.1-8B-Instruct_noise_embeds.npy delete mode 100755 examples/performance/dflash/noise_embedding/Qwen3-4B_noise_embeds.npy delete mode 100755 examples/performance/dflash/noise_embedding/Qwen3-8B_noise_embeds.npy delete mode 100755 examples/performance/dflash/noise_embedding/gpt-oss-20b_noise_embeds.npy delete mode 100644 examples/performance/dflash/results-Qwen3-4B/humaneval_per_sample.csv delete mode 100644 examples/performance/dflash/results-Qwen3-4B/summary.csv diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 22f866b8ce..347b466ed0 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,7 +5,6 @@ # # ----------------------------------------------------------------------------- -from dataclasses import dataclass from typing import List, Optional, Tuple, Type, Union import torch @@ -37,11 +36,6 @@ from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE -@dataclass -class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): - output_embeds: Optional[torch.FloatTensor] = None - - class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -401,14 +395,12 @@ def forward( hidden_states = outputs.last_hidden_state logits = self.lm_head(hidden_states).float() predicted_token_ids = logits.argmax(dim=-1).to(torch.int32) - output_embed = self.model.embed_tokens(predicted_token_ids) - return QEffCausalLMOutputWithPast( + return CausalLMOutputWithPast( loss=None, logits=predicted_token_ids, past_key_values=outputs.past_key_values, hidden_states=target_hidden, attentions=outputs.attentions, - output_embeds=output_embed, ) hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0263e0b91f..f6837d2e42 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3152,14 +3152,16 @@ def export( if self.dflash_dlm: example_inputs = { - "noise_embeds": torch.ones((bs, seq_len, self.hidden_size), dtype=torch.float), + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + # "noise_embeds": torch.ones((bs, seq_len, self.hidden_size), dtype=torch.float), "target_hidden": torch.ones((bs, seq_len, self.hidden_size), dtype=torch.float), "position_ids": torch.arange(seq_len, 2 * seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), "position_ids_target": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), "past_key_values": [[] for _ in range(self.num_layers)], } dynamic_axes = { - "noise_embeds": {0: "batch_size", 1: "seq_len"}, + # "noise_embeds": {0: "batch_size", 1: "seq_len"}, + "input_ids": {0: "batch_size", 1: "seq_len"}, "target_hidden": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, "position_ids_target": {0: "batch_size", 1: "seq_len"}, @@ -3284,7 +3286,6 @@ def export( if self.dflash_tlm: output_names.append("hidden_states") - output_names.append("output_embeds") return self._export( example_inputs, @@ -3414,7 +3415,7 @@ def build_decode_specialization( spec["num_logits_to_keep"] = (num_speculative_tokens + 1) if self.is_tlm else None - if self.dflash_tlm: + if self.dflash_tlm or self.dflash_dlm: spec["seq_len"] = dflash_block_size if self.continuous_batching: diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index a13cfa9a3a..c2b0ab668d 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -7,7 +7,6 @@ """PyTorch Qwen3 model.""" -from dataclasses import dataclass from typing import List, Optional, Tuple, Type, Union import torch @@ -40,11 +39,6 @@ from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE -@dataclass -class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): - output_embeds: Optional[torch.FloatTensor] = None - - # Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology class QEffQwen3RotaryEmbedding(Qwen3RotaryEmbedding): """ @@ -439,14 +433,12 @@ def forward( hidden_states = outputs.last_hidden_state logits = self.lm_head(hidden_states).float() predicted_token_ids = logits.argmax(dim=-1).to(torch.int32) - output_embed = self.model.embed_tokens(predicted_token_ids) - return QEffCausalLMOutputWithPast( + return CausalLMOutputWithPast( loss=None, logits=predicted_token_ids, past_key_values=outputs.past_key_values, hidden_states=target_hidden, attentions=outputs.attentions, - output_embeds=output_embed, ) hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py b/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py index 0f6bc7bb64..11801490c1 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3_dflash_draft.py @@ -425,6 +425,9 @@ def forward( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) + if input_ids is not None and noise_embeds is None: + noise_embeds = self.embed_tokens(input_ids) + return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True diff --git a/examples/performance/dflash/basic_inference.py b/examples/performance/dflash/basic_inference.py index ff69165492..5fbb91d983 100644 --- a/examples/performance/dflash/basic_inference.py +++ b/examples/performance/dflash/basic_inference.py @@ -11,23 +11,22 @@ Given a TLM model_name (short name OR full HF repo path) and a prompt, this script: 1. Looks up the matching DFlash DLM repo on Hugging Face. - 2. Reads hidden_size and block_size from the DLM config. - 3. Compiles TLM + DLM QPCs (only the side(s) not provided via + 2. Compiles TLM + DLM QPCs (only the side(s) not provided via --tlm_qpc / --dlm_qpc). - 4. Runs the SPD single-prompt inference script. + 3. Runs the SPD single-prompt inference script. Examples: # Compile + run with all defaults - python basic_inference.py --model_name Qwen3-4B \ + python basic_inference.py --model_name Qwen3-4B \\ --prompt "Explain speculative decoding in two sentences." # Full HF path also accepted - python basic_inference.py --model_name Qwen/Qwen3-4B \ + python basic_inference.py --model_name Qwen/Qwen3-4B \\ --prompt "Hello" # Reuse pre-compiled QPCs - python basic_inference.py --model_name Qwen3-4B \ - --tlm_qpc /path/to/tlm/qpc --dlm_qpc /path/to/dlm/qpc \ + python basic_inference.py --model_name Qwen3-4B \\ + --tlm_qpc /path/to/tlm/qpc --dlm_qpc /path/to/dlm/qpc \\ --prompt "What is 17 * 23?" """ @@ -41,12 +40,11 @@ sys.path.insert(0, REPO_ROOT) sys.path.insert(0, THIS_DIR) -from benchmark import MODEL_MAP, resolve_model_name # noqa: E402 # reuse the alias table +from utils import MODEL_MAP, compile_dlm_qpc, compile_tlm_qpc, resolve_model_name # noqa: E402 + +from QEfficient.utils.logging_utils import logger # noqa: E402 -# ───────────────────────────────────────────────────────────────────────────── -# Argument parsing -# ───────────────────────────────────────────────────────────────────────────── def parse_args(): p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) p.add_argument( @@ -59,6 +57,12 @@ def parse_args(): ) p.add_argument("--prompt", required=True, help="Input prompt text.") p.add_argument("--category", default="", help="Prompt category for formatting (math, coding, reasoning, …).") + p.add_argument( + "--format_prompt", + action="store_true", + help="If set, wrap the prompt with the category-specific reasoning/coding template before sending to the model. " + "Off by default — the prompt is used verbatim.", + ) p.add_argument("--tlm_hf_path", default=None, help="Override TLM HF repo (required if mapping has None).") # Optional pre-built QPCs (skip compilation) @@ -66,8 +70,8 @@ def parse_args(): p.add_argument("--dlm_qpc", default=None, help="Pre-compiled DLM qpc dir (skip DLM compile).") # Devices / cores - p.add_argument("--tlm_devices", nargs="+", type=int, default=[0, 1, 2, 3]) - p.add_argument("--dlm_devices", nargs="+", type=int, default=[0, 1, 2, 3]) + p.add_argument("--tlm_devices", nargs="+", type=int, default=[60, 61, 62, 63]) + p.add_argument("--dlm_devices", nargs="+", type=int, default=[60, 61, 62, 63]) p.add_argument("--tlm_cores", type=int, default=8) p.add_argument("--dlm_cores", type=int, default=8) @@ -77,17 +81,10 @@ def parse_args(): p.add_argument("--generation_len", type=int, default=256) p.add_argument("--iteration", type=int, default=300) - p.add_argument("--noise_embed_path", default=None, help="Defaults to noise_embedding/_noise_embeds.npy") p.add_argument("--hf_token", default=os.environ.get("HF_TOKEN")) - - # Internal modes used by self-spawned compile subprocesses - p.add_argument("--_build", choices=["tlm", "dlm"], default=None, help=argparse.SUPPRESS) return p.parse_args() -# ───────────────────────────────────────────────────────────────────────────── -# Main -# ───────────────────────────────────────────────────────────────────────────── def main(): args = parse_args() @@ -96,73 +93,35 @@ def main(): if tlm_repo is None: raise SystemExit(f"No default TLM HF path for '{args.model_name}'. Pass --tlm_hf_path.") - # Sub-mode: spawned compile subprocess. Reuse benchmark.py's builders so we - # don't duplicate the compile pipeline. - if args._build is not None: - from benchmark import _build_dlm, _build_tlm - - if args._build == "tlm": - _build_tlm(args, tlm_repo, dlm_repo) - else: - _build_dlm(args, tlm_repo, dlm_repo) - return - - # ── Resolve / discover hidden_size + block_size from DLM config ──────── - import transformers - - config = transformers.AutoConfig.from_pretrained(dlm_repo, token=args.hf_token, trust_remote_code=True) - hidden_size = config.hidden_size - block_size = getattr(config, "block_size", None) - print(f"DLM repo : {dlm_repo}") - print(f"hidden_size : {hidden_size}") - print(f"block_size : {block_size}") - - # ── Resolve QPC paths (compile only the side that wasn't pre-supplied) ─ - forwarded = [ - "--model_name", - args.model_name, - "--prompt", - args.prompt, - "--ctx_len", - str(args.ctx_len), - "--prefill_seq_len", - str(args.prefill_seq_len), - "--tlm_cores", - str(args.tlm_cores), - "--dlm_cores", - str(args.dlm_cores), - "--tlm_devices", - *[str(d) for d in args.tlm_devices], - "--dlm_devices", - *[str(d) for d in args.dlm_devices], - ] - if args.tlm_hf_path: - forwarded += ["--tlm_hf_path", args.tlm_hf_path] - if args.hf_token: - forwarded += ["--hf_token", args.hf_token] - if args.tlm_qpc: - print(f"[skip compile] using provided TLM qpc: {args.tlm_qpc}") + logger.info(f"[skip compile] using provided TLM qpc: {args.tlm_qpc}") tlm_qpc = args.tlm_qpc else: - tlm_qpc = _spawn_compile("tlm", forwarded) + tlm_qpc = compile_tlm_qpc( + tlm_repo, + dlm_repo, + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.tlm_cores, + num_devices=len(args.tlm_devices), + hf_token=args.hf_token, + ) if args.dlm_qpc: - print(f"[skip compile] using provided DLM qpc: {args.dlm_qpc}") + logger.info(f"[skip compile] using provided DLM qpc: {args.dlm_qpc}") dlm_qpc = args.dlm_qpc else: - dlm_qpc = _spawn_compile("dlm", forwarded) - print(f"TLM qpc : {tlm_qpc}") - print(f"DLM qpc : {dlm_qpc}") - - # ── Resolve noise embed path ─────────────────────────────────────────── - noise_embed = args.noise_embed_path or os.path.join( - THIS_DIR, "noise_embedding", f"{args.model_name}_noise_embeds.npy" - ) - if not os.path.exists(noise_embed): - raise SystemExit(f"noise embedding not found: {noise_embed}\nPass --noise_embed_path explicitly.") + dlm_qpc = compile_dlm_qpc( + tlm_repo, + dlm_repo, + ctx_len=args.ctx_len, + num_cores=args.dlm_cores, + num_devices=len(args.dlm_devices), + hf_token=args.hf_token, + ) + logger.info(f"TLM qpc : {tlm_qpc}") + logger.info(f"DLM qpc : {dlm_qpc}") - # ── Run the existing single-prompt inference script ──────────────────── eval_script = os.path.join(THIS_DIR, "dflash_spd_single_prompt.py") cmd = [ sys.executable, @@ -177,8 +136,6 @@ def main(): tlm_repo, "--dlm_model_name", dlm_repo, - "--noise_embed_path", - noise_embed, "--iteration", str(args.iteration), "--ctx_len", @@ -194,31 +151,15 @@ def main(): cmd += ["--hf_token", args.hf_token] if args.category: cmd += ["--category", args.category] + if args.format_prompt: + cmd += ["--format_prompt"] - print("\n>>> launching SPD single-prompt inference:") - print(" ".join(cmd)) + logger.info("\n>>> launching SPD single-prompt inference:") + logger.info(" ".join(cmd)) rc = subprocess.run(cmd, check=False).returncode if rc != 0: raise SystemExit(f"single-prompt inference exited with rc={rc}") -def _spawn_compile(mode, argv_template): - """Run this same script with --_build {mode} in a fresh process and return - the qpc path printed on the line starting with TLM_QPC= or DLM_QPC=.""" - cmd = [sys.executable, os.path.abspath(__file__), "--_build", mode] + argv_template - print(f"\n>>> spawning compile subprocess: {' '.join(cmd)}") - proc = subprocess.run(cmd, check=False, capture_output=True, text=True) - sys.stdout.write(proc.stdout) - sys.stderr.write(proc.stderr) - if proc.returncode != 0: - raise SystemExit(f"compile subprocess (--_build {mode}) failed (rc={proc.returncode})") - - tag = "TLM_QPC=" if mode == "tlm" else "DLM_QPC=" - qpc_line = next((ln for ln in reversed(proc.stdout.splitlines()) if ln.startswith(tag)), None) - if qpc_line is None: - raise SystemExit(f"could not find {tag} line in compile output") - return qpc_line.split("=", 1)[1].strip() - - if __name__ == "__main__": main() diff --git a/examples/performance/dflash/benchmark.py b/examples/performance/dflash/benchmark.py index c28b942ee1..1a7601d83b 100644 --- a/examples/performance/dflash/benchmark.py +++ b/examples/performance/dflash/benchmark.py @@ -16,15 +16,15 @@ Examples: # Compile + run with all defaults - python run_spd.py --model_name Qwen3-4B + python benchmark.py --model_name Qwen3-4B # Reuse pre-compiled QPCs (no compilation step) - python run_spd.py --model_name Qwen3-4B \ + python benchmark.py --model_name Qwen3-4B \\ --tlm_qpc /path/to/tlm/qpc --dlm_qpc /path/to/dlm/qpc # Custom devices / cores / dataset - python run_spd.py --model_name Llama-3.1-8B-Instruct \ - --tlm_devices 0 1 2 3 --dlm_devices 4 5 6 7 \ + python benchmark.py --model_name Llama-3.1-8B-Instruct \\ + --tlm_devices 0 1 2 3 --dlm_devices 4 5 6 7 \\ --tlm_cores 8 --dlm_cores 8 --dataset gsm8k """ @@ -38,66 +38,9 @@ sys.path.insert(0, REPO_ROOT) sys.path.insert(0, THIS_DIR) +from utils import MODEL_MAP, compile_dlm_qpc, compile_tlm_qpc, resolve_model_name # noqa: E402 -# ───────────────────────────────────────────────────────────────────────────── -# model_name (TLM short) → (TLM HF repo, DLM HF repo) -# DLM column comes verbatim from the user's supported list. -# TLM column is the standard HF repo when known; otherwise None and must be -# supplied via --tlm_hf_path on the command line. -# ───────────────────────────────────────────────────────────────────────────── -MODEL_MAP = { - "gemma-4-31B-it": (None, "z-lab/gemma-4-31B-it-DFlash"), - "gemma-4-26B-A4B-it": (None, "z-lab/gemma-4-26B-A4B-it-DFlash"), - "MiniMax-M2.7": (None, "z-lab/MiniMax-M2.7-DFlash"), - "MiniMax-M2.5": (None, "z-lab/MiniMax-M2.5-DFlash"), - "Kimi-K2.6": (None, "z-lab/Kimi-K2.6-DFlash"), - "Kimi-K2.5": (None, "z-lab/Kimi-K2.5-DFlash"), - "Qwen3.6-27B": (None, "z-lab/Qwen3.6-27B-DFlash"), - "Qwen3.6-35B-A3B": (None, "z-lab/Qwen3.6-35B-A3B-DFlash"), - "Qwen3.5-4B": (None, "z-lab/Qwen3.5-4B-DFlash"), - "Qwen3.5-9B": (None, "z-lab/Qwen3.5-9B-DFlash"), - "Qwen3.5-27B": (None, "z-lab/Qwen3.5-27B-DFlash"), - "Qwen3.5-35B-A3B": (None, "z-lab/Qwen3.5-35B-A3B-DFlash"), - "Qwen3.5-122B-A10B": (None, "z-lab/Qwen3.5-122B-A10B-DFlash"), - "gpt-oss-20b": ("openai/gpt-oss-20b", "z-lab/gpt-oss-20b-DFlash"), - "gpt-oss-120b": ("openai/gpt-oss-120b", "z-lab/gpt-oss-120b-DFlash"), - "Qwen3-Coder-Next": (None, "z-lab/Qwen3-Coder-Next-DFlash"), - "Qwen3-4B": ("Qwen/Qwen3-4B", "z-lab/Qwen3-4B-DFlash-b16"), - "Qwen3-8B": ("Qwen/Qwen3-8B", "z-lab/Qwen3-8B-DFlash-b16"), - "Qwen3-Coder-30B-A3B": ("Qwen/Qwen3-Coder-30B-A3B-Instruct", "z-lab/Qwen3-Coder-30B-A3B-DFlash"), - "Llama-3.1-8B-Instruct": ("meta-llama/Llama-3.1-8B-Instruct", "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat"), -} - -# Build alias table: full HF repo path (e.g. "Qwen/Qwen3-4B") and basename -# (case-insensitive) → canonical short name. Lets users pass either form. -def _build_aliases(model_map): - aliases = {} - for short, (tlm_repo, _) in model_map.items(): - aliases[short.lower()] = short - if tlm_repo: - aliases[tlm_repo.lower()] = short - aliases[tlm_repo.split("/", 1)[-1].lower()] = short - return aliases - - -MODEL_ALIASES = _build_aliases(MODEL_MAP) - - -def resolve_model_name(name): - """Map a user-supplied model name (short, full HF path, or basename) to - the canonical short name used as a key in MODEL_MAP.""" - canonical = MODEL_ALIASES.get(name.lower()) - if canonical is None: - raise argparse.ArgumentTypeError( - f"unknown model_name '{name}'. Supported: " + ", ".join(sorted(MODEL_MAP.keys())) - ) - return canonical - - -# ───────────────────────────────────────────────────────────────────────────── -# Argument parsing -# ───────────────────────────────────────────────────────────────────────────── def parse_args(): p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) p.add_argument( @@ -130,106 +73,10 @@ def parse_args(): p.add_argument("--dataset", default="humaneval", choices=["humaneval", "gsm8k", "math500"]) p.add_argument("--num_samples", type=int, default=0, help="0 = all samples") p.add_argument("--output_dir", default=None, help="Default: ./results-") - p.add_argument("--noise_embed_path", default=None, help="Defaults to noise_embedding/_noise_embeds.npy") p.add_argument("--hf_token", default=os.environ.get("HF_TOKEN")) - - # Internal modes used by self-spawned compile subprocesses - p.add_argument("--_build", choices=["tlm", "dlm"], default=None, help=argparse.SUPPRESS) return p.parse_args() -# ───────────────────────────────────────────────────────────────────────────── -# Compilation helpers — mirror make_models.py but parameterised -# ───────────────────────────────────────────────────────────────────────────── -def _read_dlm_meta(dlm_repo, hf_token): - from utils import load_dflash_checkpoint - - state_dict, cfg = load_dflash_checkpoint(dlm_repo) - target_layer_ids = cfg.get("dflash_config", {}).get("target_layer_ids", []) - block_size = cfg.get("block_size", None) - return state_dict, target_layer_ids, block_size - - -def _build_tlm(args, tlm_repo, dlm_repo): - import torch - from transformers import AutoModelForCausalLM - from utils import build_tlm_model - - from QEfficient import QEFFAutoModelForCausalLM - - state_dict, target_layer_ids, block_size = _read_dlm_meta(dlm_repo, args.hf_token) - tlm_target_ids = [i + 1 for i in target_layer_ids] - - print(f"[build_tlm] base={tlm_repo} dlm={dlm_repo} block_size={block_size}") - base_model = AutoModelForCausalLM.from_pretrained(tlm_repo, torch_dtype=torch.float32, token=args.hf_token) - build_tlm_model(base_model, state_dict, tlm_target_ids) - - tlm_qeff = QEFFAutoModelForCausalLM(base_model, qaic_config={"target_layer_ids": tlm_target_ids}) - qpc = tlm_qeff.compile( - prefill_seq_len=args.prefill_seq_len, - ctx_len=args.ctx_len, - num_cores=args.tlm_cores, - num_devices=len(args.tlm_devices), - mxfp6_matmul=True, - mxint8_kv_cache=True, - mos=1, - dflash_block_size=block_size, - ) - print(f"TLM_QPC={qpc}") - return qpc - - -def _build_dlm(args, tlm_repo, dlm_repo): - import torch - from transformers import AutoModelForCausalLM - from utils import build_dlm_model, extract_lm_head - - from QEfficient import QEFFAutoModelForCausalLM - - _, _, block_size = _read_dlm_meta(dlm_repo, args.hf_token) - - print(f"[build_dlm] dlm={dlm_repo} block_size={block_size}") - base_model = AutoModelForCausalLM.from_pretrained(tlm_repo, torch_dtype=torch.float32, token=args.hf_token) - lm_head_w, lm_head_b = extract_lm_head(base_model) - del base_model - - dlm_model = build_dlm_model(dlm_repo, lm_head_w, lm_head_b) - dlm_qeff = QEFFAutoModelForCausalLM(dlm_model, qaic_config={"dflash_dlm": True}) - qpc = dlm_qeff.compile( - prefill_seq_len=block_size, - ctx_len=args.ctx_len, - num_cores=args.dlm_cores, - num_devices=len(args.dlm_devices), - mxfp6_matmul=True, - mxint8_kv_cache=True, - mos=1, - prefill_only=True, - ) - print(f"DLM_QPC={qpc}") - return qpc - - -def _spawn_compile(mode, argv_template): - """Run this same script with --_build {mode} in a fresh process and return - the qpc path printed on the line starting with TLM_QPC= or DLM_QPC=.""" - cmd = [sys.executable, os.path.abspath(__file__), "--_build", mode] + argv_template - print(f"\n>>> spawning compile subprocess: {' '.join(cmd)}") - proc = subprocess.run(cmd, check=False, capture_output=True, text=True) - sys.stdout.write(proc.stdout) - sys.stderr.write(proc.stderr) - if proc.returncode != 0: - raise SystemExit(f"compile subprocess (--_build {mode}) failed (rc={proc.returncode})") - - tag = "TLM_QPC=" if mode == "tlm" else "DLM_QPC=" - qpc_line = next((ln for ln in reversed(proc.stdout.splitlines()) if ln.startswith(tag)), None) - if qpc_line is None: - raise SystemExit(f"could not find {tag} line in compile output") - return qpc_line.split("=", 1)[1].strip() - - -# ───────────────────────────────────────────────────────────────────────────── -# Main -# ───────────────────────────────────────────────────────────────────────────── def main(): args = parse_args() @@ -238,70 +85,37 @@ def main(): if tlm_repo is None: raise SystemExit(f"No default TLM HF path for '{args.model_name}'. Pass --tlm_hf_path.") - # ── Sub-mode: this process exists only to compile one model ───────────── - if args._build == "tlm": - _build_tlm(args, tlm_repo, dlm_repo) - return - if args._build == "dlm": - _build_dlm(args, tlm_repo, dlm_repo) - return - - # ── Resolve / discover hidden_size + block_size from DLM config ──────── - import transformers - - config = transformers.AutoConfig.from_pretrained(dlm_repo, token=args.hf_token, trust_remote_code=True) - hidden_size = config.hidden_size - block_size = getattr(config, "block_size", None) - print(f"DLM repo : {dlm_repo}") - print(f"hidden_size : {hidden_size}") - print(f"block_size : {block_size}") - - # ── Resolve QPC paths (compile only the side that wasn't pre-supplied) ─ - forwarded = [ - "--model_name", - args.model_name, - "--ctx_len", - str(args.ctx_len), - "--prefill_seq_len", - str(args.prefill_seq_len), - "--tlm_cores", - str(args.tlm_cores), - "--dlm_cores", - str(args.dlm_cores), - "--tlm_devices", - *[str(d) for d in args.tlm_devices], - "--dlm_devices", - *[str(d) for d in args.dlm_devices], - ] - if args.tlm_hf_path: - forwarded += ["--tlm_hf_path", args.tlm_hf_path] - if args.hf_token: - forwarded += ["--hf_token", args.hf_token] - if args.tlm_qpc: print(f"[skip compile] using provided TLM qpc: {args.tlm_qpc}") tlm_qpc = args.tlm_qpc else: - tlm_qpc = _spawn_compile("tlm", forwarded) + tlm_qpc = compile_tlm_qpc( + tlm_repo, + dlm_repo, + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + num_cores=args.tlm_cores, + num_devices=len(args.tlm_devices), + hf_token=args.hf_token, + ) if args.dlm_qpc: print(f"[skip compile] using provided DLM qpc: {args.dlm_qpc}") dlm_qpc = args.dlm_qpc else: - dlm_qpc = _spawn_compile("dlm", forwarded) + dlm_qpc = compile_dlm_qpc( + tlm_repo, + dlm_repo, + ctx_len=args.ctx_len, + num_cores=args.dlm_cores, + num_devices=len(args.dlm_devices), + hf_token=args.hf_token, + ) print(f"TLM qpc : {tlm_qpc}") print(f"DLM qpc : {dlm_qpc}") - # ── Resolve noise embed path ─────────────────────────────────────────── - noise_embed = args.noise_embed_path or os.path.join( - THIS_DIR, "noise_embedding", f"{args.model_name}_noise_embeds.npy" - ) - if not os.path.exists(noise_embed): - raise SystemExit(f"noise embedding not found: {noise_embed}\nPass --noise_embed_path explicitly.") - output_dir = args.output_dir or os.path.join(THIS_DIR, f"results-{args.model_name}") - # ── Run the existing SPD eval script ─────────────────────────────────── eval_script = os.path.join(THIS_DIR, "dflash_spd_benchmark.py") cmd = [ sys.executable, @@ -316,8 +130,6 @@ def main(): tlm_repo, "--dlm_model_name", dlm_repo, - "--noise_embed_path", - noise_embed, "--iteration", str(args.iteration), "--ctx_len", diff --git a/examples/performance/dflash/dbg.log b/examples/performance/dflash/dbg.log deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/performance/dflash/dflash_spd_benchmark.py b/examples/performance/dflash/dflash_spd_benchmark.py index 0644bffc9b..643e05574d 100644 --- a/examples/performance/dflash/dflash_spd_benchmark.py +++ b/examples/performance/dflash/dflash_spd_benchmark.py @@ -76,7 +76,7 @@ def run_spd_inference_single( tokenizer, dlm_session: QAICInferenceSession, tlm_session: QAICInferenceSession, - mask_token_embed, + mask_token_id: int, vocab_size: int, prompt_chunk_size: int, ctx_len: int = 4096, @@ -110,7 +110,6 @@ def run_spd_inference_single( # Set output buffers tlm_session.set_buffers({"logits": np.zeros((batch_size, prompt_chunk_size), dtype=np.int32)}) tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) - tlm_session.set_buffers({"output_embeds": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) dlm_session.set_buffers({"logits": np.zeros((batch_size, block_size, vocab_size), dtype=np.float32)}) tlm_cache_index = np.array([0]) @@ -136,7 +135,7 @@ def run_spd_inference_single( :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size ] dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size - dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) dlm_session.run(dlm_inputs) ## Add support when prefill_seq_len is not a multiple of block_size @@ -151,7 +150,7 @@ def run_spd_inference_single( dlm_inputs["target_hidden"] = target_hidden_rem dlm_inputs["position_ids_target"] = pos_ids_target_rem dlm_inputs["position_ids"] = pos_ids_target_rem + block_size - dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) dlm_session.run(dlm_inputs) tlm_cache_index[0] += prompt_chunk_size dlm_cache_index[0] += prompt_chunk_size @@ -176,11 +175,11 @@ def run_spd_inference_single( :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size ] dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size - dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) dlm_session.run(dlm_inputs) - noise_embeds = np.tile(mask_token_embed, (1, block_size, 1)) - noise_embeds[:, 0, :] = tlm_last_prefill_outputs["output_embeds"][:, last_prefill_pos_in_chunk, :] + input_ids = np.full((1, block_size), mask_token_id, dtype=np.int64) + input_ids[:, 0] = new_tlm_token sub_start = last_sub * block_size ## Add support when prefill_seq_len is not a multiple of block_size @@ -202,7 +201,7 @@ def run_spd_inference_single( tlm_cache_index[0] + last_prefill_pos_in_chunk + 1, tlm_cache_index[0] + last_prefill_pos_in_chunk + 1 + block_size, ).reshape(1, -1) - dlm_inputs["noise_embeds"] = noise_embeds + dlm_inputs["input_ids"] = input_ids dlm_inputs["target_hidden"] = target_hidden dlm_outputs = dlm_session.run(dlm_inputs) @@ -217,7 +216,6 @@ def run_spd_inference_single( tlm_session.set_buffers({"logits": np.zeros((batch_size, block_size), dtype=np.int32)}) tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) - tlm_session.set_buffers({"output_embeds": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) while gen_idx < generation_len and iteration_count < max_iterations and continue_generation: iteration_count += 1 @@ -290,8 +288,8 @@ def run_spd_inference_single( spd_counter_idx += accepted_length + 1 dlm_inputs["position_ids_target"][:, accepted_length + 1 :] = -1 dlm_inputs["position_ids"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape(1, -1) - noise_embeds[:, 0, :] = tlm_decode_outputs["output_embeds"][:, accepted_length, :] - dlm_inputs["noise_embeds"] = noise_embeds + input_ids[:, 0] = new_tlm_token + dlm_inputs["input_ids"] = input_ids dlm_inputs["target_hidden"] = target_hidden dlm_outputs = dlm_session.run(dlm_inputs) metrics.dlm_decode_time += time.time() - dlm_decode_start @@ -387,7 +385,7 @@ def evaluate_dataset( tokenizer, dlm_session, tlm_session, - mask_token_embed, + mask_token_id: int, vocab_size: int, prompt_chunk_size: int, ctx_len: int = 4096, @@ -431,7 +429,7 @@ def evaluate_dataset( max_iterations=max_iterations, hidden_size=hidden_size, generation_len=generation_len, - mask_token_embed=mask_token_embed, + mask_token_id=mask_token_id, ) ar = metrics.acceptance_rate() @@ -533,7 +531,6 @@ def parse_args(): parser.add_argument("--dlm_qpc", required=True) parser.add_argument("--tlm_model_name", required=True) parser.add_argument("--dlm_model_name", required=True) - parser.add_argument("--noise_embed_path", required=True) parser.add_argument("--iteration", type=int, default=300) parser.add_argument("--ctx_len", type=int, default=4096) parser.add_argument("--generation_len", type=int, default=1024) @@ -560,7 +557,8 @@ def main(): vocab_size = config.vocab_size hidden_size = config.hidden_size block_size = config.block_size - mask_token_embed = np.load(args.noise_embed_path) + dflash_cfg = getattr(config, "dflash_config", None) or config.to_dict().get("dflash_config", {}) + mask_token_id = dflash_cfg["mask_token_id"] if isinstance(dflash_cfg, dict) else dflash_cfg.mask_token_id if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id @@ -588,7 +586,7 @@ def main(): tlm_session=tlm_session, vocab_size=vocab_size, prompt_chunk_size=prompt_chunk_size, - mask_token_embed=mask_token_embed, + mask_token_id=mask_token_id, ctx_len=args.ctx_len, block_size=block_size, max_iterations=args.iteration, diff --git a/examples/performance/dflash/dflash_spd_single_prompt.py b/examples/performance/dflash/dflash_spd_single_prompt.py index 5cff09c9cc..bcc11eb72b 100644 --- a/examples/performance/dflash/dflash_spd_single_prompt.py +++ b/examples/performance/dflash/dflash_spd_single_prompt.py @@ -72,7 +72,7 @@ def run_spd_inference_single( tokenizer, dlm_session: QAICInferenceSession, tlm_session: QAICInferenceSession, - mask_token_embed, + mask_token_id: int, vocab_size: int, prompt_chunk_size: int, ctx_len: int = 4096, @@ -105,7 +105,6 @@ def run_spd_inference_single( # Set output buffers tlm_session.set_buffers({"logits": np.zeros((batch_size, prompt_chunk_size), dtype=np.int32)}) tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) - tlm_session.set_buffers({"output_embeds": np.zeros((batch_size, prompt_chunk_size, hidden_size), dtype=np.float32)}) dlm_session.set_buffers({"logits": np.zeros((batch_size, block_size, vocab_size), dtype=np.float32)}) tlm_cache_index = np.array([0]) @@ -130,7 +129,7 @@ def run_spd_inference_single( :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size ] dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size - dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) dlm_session.run(dlm_inputs) if remainder > 0: sub_start = num_sub_blocks * block_size @@ -143,7 +142,7 @@ def run_spd_inference_single( dlm_inputs["target_hidden"] = target_hidden_rem dlm_inputs["position_ids_target"] = pos_ids_target_rem dlm_inputs["position_ids"] = pos_ids_target_rem + block_size - dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) dlm_session.run(dlm_inputs) tlm_cache_index[0] += prompt_chunk_size dlm_cache_index[0] += prompt_chunk_size @@ -167,11 +166,11 @@ def run_spd_inference_single( :, tlm_cache_index[0] + sub_start : tlm_cache_index[0] + sub_start + block_size ] dlm_inputs["position_ids"] = dlm_inputs["position_ids_target"] + block_size - dlm_inputs["noise_embeds"] = np.full((1, block_size, hidden_size), 1, dtype=np.float32) + dlm_inputs["input_ids"] = np.full((1, block_size), mask_token_id, dtype=np.int64) dlm_session.run(dlm_inputs) - noise_embeds = np.tile(mask_token_embed, (1, block_size, 1)) - noise_embeds[:, 0, :] = tlm_last_prefill_outputs["output_embeds"][:, last_prefill_pos_in_chunk, :] + input_ids = np.full((1, block_size), mask_token_id, dtype=np.int64) + input_ids[:, 0] = new_tlm_token sub_start = last_sub * block_size if last_sub < num_sub_blocks: target_hidden = tlm_last_prefill_outputs["hidden_states"][:, sub_start : sub_start + block_size, :] @@ -190,7 +189,7 @@ def run_spd_inference_single( tlm_cache_index[0] + last_prefill_pos_in_chunk + 1, tlm_cache_index[0] + last_prefill_pos_in_chunk + 1 + block_size, ).reshape(1, -1) - dlm_inputs["noise_embeds"] = noise_embeds + dlm_inputs["input_ids"] = input_ids dlm_inputs["target_hidden"] = target_hidden dlm_outputs = dlm_session.run(dlm_inputs) @@ -205,7 +204,6 @@ def run_spd_inference_single( tlm_session.set_buffers({"logits": np.zeros((batch_size, block_size), dtype=np.int32)}) tlm_session.set_buffers({"hidden_states": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) - tlm_session.set_buffers({"output_embeds": np.zeros((batch_size, block_size, hidden_size), dtype=np.float32)}) while gen_idx < generation_len and iteration_count < max_iterations and continue_generation: iteration_count += 1 @@ -276,8 +274,8 @@ def run_spd_inference_single( spd_counter_idx += accepted_length + 1 dlm_inputs["position_ids_target"][:, accepted_length + 1 :] = -1 dlm_inputs["position_ids"] = np.arange(spd_counter_idx + 1, spd_counter_idx + block_size + 1).reshape(1, -1) - noise_embeds[:, 0, :] = tlm_decode_outputs["output_embeds"][:, accepted_length, :] - dlm_inputs["noise_embeds"] = noise_embeds + input_ids[:, 0] = new_tlm_token + dlm_inputs["input_ids"] = input_ids dlm_inputs["target_hidden"] = target_hidden dlm_outputs = dlm_session.run(dlm_inputs) metrics.dlm_decode_time += time.time() - dlm_decode_start @@ -298,7 +296,6 @@ def parse_args(): parser.add_argument("--dlm_qpc", required=True) parser.add_argument("--tlm_model_name", required=True) parser.add_argument("--dlm_model_name", required=True) - parser.add_argument("--noise_embed_path", required=True) parser.add_argument("--iteration", type=int, default=300) parser.add_argument("--ctx_len", type=int, default=4096) parser.add_argument("--generation_len", type=int, default=256) @@ -310,6 +307,12 @@ def parse_args(): default="", help="Prompt category for formatting (math, coding, reasoning, …). Defaults to the general reasoning format.", ) + parser.add_argument( + "--format_prompt", + action="store_true", + help="If set, wrap the prompt with the category-specific template from utils.format_prompt. " + "Off by default — the prompt is used verbatim.", + ) return parser.parse_args() @@ -327,7 +330,8 @@ def main(): vocab_size = config.vocab_size hidden_size = config.hidden_size block_size = config.block_size - mask_token_embed = np.load(args.noise_embed_path) + dflash_cfg = getattr(config, "dflash_config", None) or config.to_dict().get("dflash_config", {}) + mask_token_id = dflash_cfg["mask_token_id"] if isinstance(dflash_cfg, dict) else dflash_cfg.mask_token_id if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id @@ -348,7 +352,8 @@ def main(): ) console.print(f"prompt_chunk_size = {prompt_chunk_size}") - messages = [{"role": "user", "content": format_prompt(args.prompt, args.category)}] + user_content = format_prompt(args.prompt, args.category) if args.format_prompt else args.prompt + messages = [{"role": "user", "content": user_content}] prompt_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False ) @@ -366,7 +371,7 @@ def main(): max_iterations=args.iteration, hidden_size=hidden_size, generation_len=args.generation_len, - mask_token_embed=mask_token_embed, + mask_token_id=mask_token_id, ) output_parts = ["Output: "] diff --git a/examples/performance/dflash/make_models.py b/examples/performance/dflash/make_models.py index 8cbeeae4e4..8ae7fcae4c 100644 --- a/examples/performance/dflash/make_models.py +++ b/examples/performance/dflash/make_models.py @@ -28,9 +28,10 @@ import torch from transformers import AutoModelForCausalLM -from utils import build_dlm_model, build_tlm_model, extract_lm_head, load_dflash_checkpoint +from utils import build_dlm_model, build_tlm_model, extract_embed, extract_lm_head, load_dflash_checkpoint from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.logging_utils import logger # ── Paths ───────────────────────────────────────────────────────────────────── TLM_MODEL_PATH = "Qwen/Qwen3-4B" @@ -59,20 +60,20 @@ def _load_dflash_meta(): target_layer_ids = cfg.get("dflash_config", {}).get("target_layer_ids", []) mask_token_id = cfg.get("dflash_config", {}).get("mask_token_id", []) block_size = cfg.get("block_size", None) - print(f" target_layer_ids : {target_layer_ids}") - print(f" mask_token_id : {mask_token_id}") - print(f" block_size : {block_size}") + logger.info(f" target_layer_ids : {target_layer_ids}") + logger.info(f" mask_token_id : {mask_token_id}") + logger.info(f" block_size : {block_size}") return dflash_state_dict, target_layer_ids, block_size def build_tlm(): - print(f"Loading DFlash checkpoint: {DFLASH_MODEL_PATH}") + logger.info(f"Loading DFlash checkpoint: {DFLASH_MODEL_PATH}") dflash_state_dict, target_layer_ids, block_size = _load_dflash_meta() - print(f"\nLoading base model: {TLM_MODEL_PATH}") + logger.info(f"\nLoading base model: {TLM_MODEL_PATH}") base_model = AutoModelForCausalLM.from_pretrained(TLM_MODEL_PATH, torch_dtype=torch.float32) - print("\n=== TLM ===") + logger.info("\n=== TLM ===") tlm_target_ids = [i + 1 for i in target_layer_ids] build_tlm_model(base_model, dflash_state_dict, tlm_target_ids) @@ -83,21 +84,22 @@ def build_tlm(): dflash_block_size=block_size, **COMPILE_KWARGS, ) - print(f"tlm_qpc_path: {tlm_qpc_path}") + logger.info(f"tlm_qpc_path: {tlm_qpc_path}") return tlm_qpc_path def build_dlm(): - print(f"Loading DFlash checkpoint: {DFLASH_MODEL_PATH}") + logger.info(f"Loading DFlash checkpoint: {DFLASH_MODEL_PATH}") _, _, block_size = _load_dflash_meta() - print(f"\nLoading base model (for lm_head): {TLM_MODEL_PATH}") + logger.info(f"\nLoading base model (for lm_head): {TLM_MODEL_PATH}") base_model = AutoModelForCausalLM.from_pretrained(TLM_MODEL_PATH, torch_dtype=torch.float32) lm_head_weight, lm_head_bias = extract_lm_head(base_model) + embed_weight = extract_embed(base_model) del base_model - print("\n=== DLM ===") - dlm_model = build_dlm_model(DFLASH_MODEL_PATH, lm_head_weight, lm_head_bias) + logger.info("\n=== DLM ===") + dlm_model = build_dlm_model(DFLASH_MODEL_PATH, lm_head_weight, lm_head_bias, embed_weight) dlm_qeff = QEFFAutoModelForCausalLM(dlm_model, qaic_config={"dflash_dlm": True}) dlm_qpc_path = dlm_qeff.compile( @@ -106,12 +108,12 @@ def build_dlm(): prefill_only=True, **COMPILE_KWARGS, ) - print(f"dlm_qpc_path: {dlm_qpc_path}") + logger.info(f"dlm_qpc_path: {dlm_qpc_path}") return dlm_qpc_path def _run_subprocess(mode: str): - print(f"\n>>> Spawning subprocess: --mode {mode}") + logger.info(f"\n>>> Spawning subprocess: --mode {mode}") result = subprocess.run( [sys.executable, os.path.abspath(__file__), "--mode", mode], check=False, @@ -137,7 +139,7 @@ def main(): else: _run_subprocess("tlm") _run_subprocess("dlm") - print("\n=== Done ===") + logger.info("\n=== Done ===") if __name__ == "__main__": diff --git a/examples/performance/dflash/noise_embedding/Llama-3.1-8B-Instruct_noise_embeds.npy b/examples/performance/dflash/noise_embedding/Llama-3.1-8B-Instruct_noise_embeds.npy deleted file mode 100755 index d81c2348d6046718e449e9c7896a4f15c61fcc8a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16512 zcmbW;jqlHO`#10pDU~9WzMoI0x)`a8X}EHet}JG#i_)*0vDCWZl>3`R z^7)kHW*F0Im||KbV_G#TwOU#0#!^`|LvdGJbKOl-~TQymF};UN)xN4(!1qSsWp2U zOot}?kJpq+Uy6MX$5xk0ool7i5ZW^Sf5dNw2jrGuF>M4MEoTYrk=GR;q~-dL8UH5x zP4=p~Qt2DcaygqhQlD* zQ?v?hhrckU7d{|11<%I!&^q7`;i7V>)Kl*L>^9Dz~F+w5pa$m-E-*fABw( z_aMGgyb<`4( zREwI}f~V3C;~v&wg!{k5{ zFaM0()aPM#1NL+n1x>`i!P)cuo^b&lMw?hJmF|;26t^|zGdv%sH*}KMoZVIKJjk3jQh%4Rvo@!UY0f@Pdzk$Xn9jb9cCGkP*Xe3Z=CcRFSBXs$8*GoVk30E! zuC;I{eGh*SY%c33{2Sq7_6(dpvanJrjdb0b8eGU9!=E9(81Ci&r)n)|=~WH!DbM<( zSkBwq@SONZ?!8EC3Vce-ULAJLqgNF3xlYdA^u%AM|G<8YwhYdvjjNSP{XJJZ=t4Uv zzK@o9H>Z8XUmRQR5_)<>>ii5WgnCGw?hs2|jl@ry=L?XzJ_`d}=N=d;o-^DZ-yp97 z*YgKf^+9Wpb6lpSeqN$Yfz;1Qd=4L`uf;3iecH>gftG7`trTN&rZcDC;W@aTzevt( z$g}K*3)ny5@y0|yWUhU1YG)lB&$l6S{+2e3om#nxU#;3B+Af^)pZ&^yuY}BP zh1~VH8N6H8NAPkvOW3=_-Z%Ct{$TbEFbkH8rAOsl=XtuF6zjpCiI>qb_pId&xQ}>0 z$n!R%jfa;#%QgH#;th<+I<rtVS)pW@~)9!iy>U!7sMVb8`(!$|9ig6I zC%%<7pZ_DjCI4FdI6u$w|N05exjjbD8hqe6rqCur>ZQHdJ@A!hS|X_q;frZ);8kNXuP^YuuzZgbMT*uFxSJUpq57B>EtuD$%k2tr+8K7^0 z%x4|HPR=F#%W%%mD{}XV&4t)I%;{I0GkO#3rH|%cCI3GD6!slB_s$vWj4z_4KV{8R zQ!OjbC_GKOf_8_wE_K}r?E7fn!FF?4-X3>uA03i*gQIi>%xC&c% z`Q5}C;}x|1c%OSUqgGBo%M{C^~#>Fz+LfF+!CjT zzJ%;;&P|@7%FbH%W>13d)#APxw-WwM>j)pRmqDKWE@OHavx5B@tY9)99z7`3R>^<=p&--2vZ%Q@`iK&T{eF> z_s19Ux6${qFNC?CVT@-QP0KkJYND-p=6OH<&ND1z zr(UxbS-(f|L9z6m!D3m*xB2yO5Gr!z_6+Bn7MsSq2TI>e%OaE&p_BCxXExoP{ zJGFPc^%{ax2Xo5BJMA5Or@ZH35xmH6C#N4C<9g4s^PbAHjKjC`C$Mvde{sDxTr2Ow ztm|0#9p=hAjt7aYE*E?I7f%1Jh%IFg7vBgaTGseg_qvRIH~Rs+jNjDtvv;Rxb&zkH zC#%KxSFYcZpXcf-C-1;VjLA8lZjb8mJus4*X*oQclO@D^*qBb)#4p}DZWN5 zXKID`D%ZM znd`-J@=TYoe}He{)(a=G*G&a`CsI7Pt|lpU%a{tEJL!a>ldMqjIKFbDi-~2>&eSx{CLWyh;4?*#mG3 z*Z70BO#CWZfA)v;)NN`fHF!pBhdqcrf%Z8}6;EAU00-eOvHbn!eK;sT^*ELN9Hejc zrDYy@*JLg0E5*C(_Qcq2#Zw<|8#g>YJKwzb@}FT}4mY#Yt9Mt6_t-w#C+y52efn+J z`Zr!n%d?D+zGuO~*g+a|Y(|vj=DJaA*SAgFNe1>_PZv_h^T^!rfQ8 zk8AEXrU^ZJJ;ik<%Dszzj(xqncX56{r#5C()f&u()L%US-77*RuPtqm9NnH`#-n$M$d=^ygRUw`OC`v4l#ABT~& z*emg^^wrRjKLN5Q10ZX+2xpHczzI3e($a4}$E$tMeM3th+<+J3hs6eY&e$)~ni-eA z5*$*e`Ht8_-zz6;n(x1?MIAfOU53;Ab#n8rUoZbj=uK+^+1sP=v{>F#sgH+ouKOK3 z>vV;@!L&p0JwNrFKJ*G??ec6_(yqqKao$(&%N@p_V_xZ{<6sBvNAbMl)8hwLig#e1 zH}}haH>9O6Zsn)HXFhrVTuZx#-3#}DY3%m+1)Q}wD1QdKDZ4X!DEk%|2oD(lHGO5Z zxcD-hj40tB;*LTG@{>$bKMD$Iy)?W1MOCJ>USZY&%cGe z3ExxGyFFuazv4d*&1gf#^PSh1f0q9gyh}SK?=jkTyiYu}ay@+(tqZ=HUt%}Kslk=< zp636o=32?mJ;$vndTsjkbbh{nviEbv*Rs#D?`BUE`w;SuyS-ZUuk^F4VLmP2z^&L- zd`^6FtvJ)?)Jiw@Y&a?}eSIs8fG+$o_G^>aMf`906`cBhrdsr-oSS8^#vF1kU!tep zvVU!9IcGKZJVHMpzJQ&&`bO*{+O?2#dm6ut-;(o{obBwqA3N}~PIWkEA=m2wGv&ws z73X~RWzU7=U5r=L(_a_6PXqqV?3|JR@O$7VXxLYsS6DZX%Xm`5SwYob;7Z z&db-vT?l!O(Q?vX@?PC#{Ti0lB|H0f5HG`*i>Dr1!}aV(X=~XZ;r94Z*uwtOy>l+p zKWDSIRExU%!gZUmH~5{pOMHj?qx^yF_whw?U$!?D{z;f%{DsgT&WPt3`{C^OUfM70 zypzsnzs+8cyFdl*aj)#f6YRSyYPMW_Yvo;Xaqc5F4i3?uur_%IAA`}+tWA64rm}a# z=5le3-NvP6QtxAE^X0sbSK{nzzJGGAlaoHOjh}Pfh<`ck;9tsL!Op!>m-R5a>>Tl% z;Hx3uMCluMx?lFDbxj|}JzO)jc8a|QvaYF#Mmgl*~O}K@zv)Ngv>`&%- zyjuKbXd0FiZ3DYcO-KO~klT|mpbH|KUE{Xu&=!kGKS^1jO2 z=6nvfUz72}@^Tgy$^D+6{mNP$pjGVgi|`b`6@8i9S;jBN>1!=%&+sS9-Nk=~p7+Zl zc50+1_B+n`NzeP3z7uYgmwD$bj)lAv6VLoNil@hoWH*9gm7;(5#r+`9eWCa%*beLX zhds-y)#4qq6*rf6TrB(1hW|ExQQit1J>yzbi?g|j{v-QR+OP06tc5o`Th8)C*BXnn z*E{ID_&HbaitlhHKccO1jqF=XvGk~I;(1U0X)SKWLuf<9=0o;n96SV_Tq8YXUd_JA zTY}Gdmea=1p=WQ-gU0;v_*>WOCogN7K9av9X3eVX0gyBC1@x{KXTF121LHpDKL?+( zZ=hxF5AbKRtN48S^R$t8E-a%h7aN9K@Uvz=;+?eSv^lVueU|<`yE!{G)kb~gtX(X3 zpt+?E(u>khJ{7--z729-8q%^JS-<@5$~oV!*1p8EXj$vbBei%AlD9;D&cH7Jp52ta zmo^A;-g@Dkw8`wB@hvb*-Z%WzT54_>PHo(ePr(plrt`mb-`Dwjrkt;w?C9nf0FuudzGgoTJrZ%j(P@?~t3{!d=CV zLI-(ImSWaKy4CSo163;i)K=x5~?wg))wVcc=HSn-Ce}KKF zVx94m@R2ds;;i?|YH{DZbI+&kz&USMuv2fjPER>`*6i8q(1X4TXARbr^ZUyr-XxsLrS;I1U*`FD;;HgW_$ZzRAHpVMp0dwd z>Di0D)xV#OpT2Wdt@!<(Gd+xdFLYveG;RPp>z&$qn3kGM9rvJp06Q`tIOG25BdLq5 zPv-rB+%0mB!e;g#a?+bl$nPMho1DJ#8?*DC%ek7%ZvnNFa8a1 z2km-!^IUHny(4=QPVIgwmL8jT!8Y0=c53-twRi{C+_RI|Becw8HGKlU684+t9k@?T z55*(!8F~HLkFv9;_u+-Kq4+xfcH=v-4b|>|g^=|Lp>1!_-vkSVh^Gr93%@LcA(~Eo3 zr$Y8Nf190!Q*&9no9LsB$vgNoZVfBS#lHT9zr3=j#oqFAjokYpob#T#x`clo|3|s+ z^1s1zY3WxvYctv1an41~de(oaIby~HB_w{a^@70{A*Kk*37s|=+^P6EftPuM* zyBQoT7kz0gEobyY+8Xv?`U!h}Av-lNk^UWYXJ^eD(4XV4XJ5Ok3v!A40Aa)Jy2tRA`GA;GE%)0F|r@Ys)j#uFu@mSa+FYD1AR>}PhKThB5z1tI~ z9XjX98>mbrI%X=^1rO(T|TK;VQezARW2J*MiSF%%oPqB0VW#U=G z{r(P-=lP1CdOigWVT#xWYm^@TvTLVT=jvq$+KLSyz&>s1KX({+Eo~$}@2I`})$-=S#r&L$)JdMPCNFEU6py5B zpv{KV^v(QZ_$A01HN=_YNBAAwPP~qucg8d2;tZxf@2vCt0KYA75afN)iQm!OvImbC zy9MVt^DQ`xy@{Rkdk;>}n_Vluu@hg;9tGXS^LwWZdCv6ieHG`9{?nS`EH0_)&-~QH zboMcMfATYzKD77P-?P&n`a*i+4E_eFz>EBR3#`Di`8gY3)Qb1al(Kqo@2pAozArm{ zsyDy4G3h%8aOyI1OJ8|Y&L;X?Ib-Mcuw6AFyKac$?dprEZ9}C&vi7=V{Ieh}ocj$xV z;(NS_*bM$G*E=qD6Eqgfb3ZDUeli8}?487RuopqD+pb)^FPgDmq2+JVO`!pOn%vp; zJ@d@G57WM)En)9wA7n3LkAgrj1AK+fGgkiPLw;+5iC<2;EzYjVYlXH;2ZQdsD zD)t(@fxpq1x$K;c2k3L<9$}Ajowm3G&U3G%=l|z+C4Ygu)b}&cmA_Ve1mtYh;8EAT zUrx^UcvuQI@zaO$j%;B}H#yIqmbulUMc>@(U-lL-2&>Nx3P=mwx&179caf~XEi_l zBj@M@zX9$KdG@>`a-FPk`hEk*8DCv4_I-t%Ns#$nYV2j=sg-fKE8Ys(hm&=lNzO`k zo@=e#T15>(fBJj!awc*HGrv54Q#q~0hqJR+Z{W<~3tWX0>@hg|^g&twbdOK*1Y@$k zsh!NHxA-7>TQxI4>?VHJ@?x=i_ZY@ME1p_SE#}{SKNH`~?pQ5;W8J`RNW0XS)b(WJ zwz3zn7mLji8;@JidcY`t`bhRZbunJvwK!+#2;2=_Am{&OIoGf=t`+}H{^PhK>^1HI z`oGy7*q7oV{7&#a|26YCN=xs*6E39phr8J6(cgHMtVQ(f_ZO~R!q32b zYQ-KrN$=$p#h}kI~v!I{kfdAcsstF{-JB#$6g4z-UhKn@DY7P+4~hn^2b4A_KaHbu04(C zc=q&xTrYh+|G&M_uG5y^UhW+BMfCKMWzlJGi7kcv{qU4}c>?bj>x-{9_o4FZ;YoJR zR?h9!{A1#Ik3AuG0{dN@HOrjugEp{De97N->SRKl{vdC={1IZ`!t#nb=kKHUXOD&F z_~EWKy^DRDakj0Cwal5G2zkCB&$p4DzgMrTs5QtQof4a6Tnn-E zvgEd6-yqf#zg{iAOK;@o9l8M?VZTSudiS$N@3SWu_cA@-bg%I57F)n?4Fklo=HJkU z`CXCUeZN2(div5k)-unQv-GRE^ujG6XJi$gLhB@NJ-);5!pyBKmVLObTs;3L@^+fb zY}ftW_1f^KiEo03AU$Zc@!6YkVvQ@sTDNAuWnB8wI(B+Y>Mm<~wp`5d1KKqCncH<@ z-_d`E2kNUw3cp4bFRHa=DmC{x0$kJL}wp-B~QplYWst^@=gg z*qP7paxwM=TCUfIe;*#oZz%o%J8L}NbxyN0CUudz&oiZ_=JPYJvvqnCzY8qor>|#k z@8joNszc2_^Lyg-puCH-ukC1uU`5$EqSfK-S$f0-$amx*d0Dd_@L09zUz5dV@!wTn zHFo|TCFgaLczVqkoM%W)+8-xhC$r;R%+{{Vha{3hC6;>@aqIV|-NmzwnkE AQ~&?~ diff --git a/examples/performance/dflash/noise_embedding/Qwen3-4B_noise_embeds.npy b/examples/performance/dflash/noise_embedding/Qwen3-4B_noise_embeds.npy deleted file mode 100755 index 31c70ee77c37848870e813fadb1f72d0186758ba..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10368 zcmbW-iT{>U{s!>7$WqACMApiZNQ=_)eToP%gtC(?31f>g1`U-;h)U6>lqMupNUG4Fj%s%~l z_iIwsut~?W+B9j_u*q3{`(4uS^b1ez+waUi@gIEp`ThIw_dnM&}8Mf+F6#v9~RTssR{6Ek$=^YjqMP2$ZzH8Yn*tg=2 zd;_bBVkCcSZM*Qz;k%db0DK4k12}!%B356{0{jfrE*HfH9g1R+?B0j@?e~x9pkaqUh`$4&kdQHlE!sT6`n+hHx8u8s4Yc^YwQhdmDL&^Ud^L zeQ~8VWLzbjG4zJo&;-U;>)#lv@cQh9`W`0t9JrS5&v!lk z8}x*2VSK454&*;on~glLK3^?5bH6SBrf?B`Apd;1TP-Y#$LY16v#Yt;T5LJJnY@ho ze6gc&D?D0%w`j8xHj;CuobBkm(<5cI0_(|b$KMOGUR~hEQc+BVG4LB$&x5SrOl{WZJ7Ymn+(2K8&k);!?^E_*`dH(+n0;TlDE7y{;2HeS;|thb zqik;&00vK)=TSG=FMjHUCTadF?(E%e=o#Z)wiolK%}~_D#lD zPn+B6Ik1ZT6|{!BupDkEsWG?%ma!kl7wYS1>$o3Y!u}doie-KehZ>&I)fk$KXJ0IW zcIC=BF~jpFi}zt?ji-x^VecDeKf(VUdlzHvg{Q!$);l$Hm3UwAjq$teN&IiKH>fU( zKjb~6O;7d~u)0dE%R5lsZS42ix3Sk_Ph{_eFT{DjoDoZL);4=~9Bu$_z!L2q)PK%_ z?4J!_qjFJ95X+op9#3TN0>8`69PB~o{cGTU?3Wi6#gTM#xRZYd9uAr39%5;8D4xXL zgWj>)Jkr0>dG4<48O(*|u=?Smc)hGfp}9WWWMAOBS4 zP2xY6ok8rU#CF8h>~H9=+Z9D)@ti-k_4jf|^)H^iUlUK~A1ikZojsZJDD$_K_=x!7 zR_!0czvJ}rs&~JHeKD-l{?3wpMK|-jjm5KWv-l3ina>*VJA0{|hIoN~2GULWGjFfa z>%p)3eU*JPbkbM$Z+-rzkQ%#+{}p;M&N=jVXsM6a78S)o{6F$9NEyYKQ% zrVpX3;O9T%v*{6Z_S@EF?~NzJCZ2mXdyIGcLjNVc)aSPJ`TYArau37Rcr|_-FT_3Y zGoH7eKK^??(5Z>d*{Q`_@GE@(hW+K9>&$uxkA{uC{|Y$`78S)1@n-bL&`0bH?VhHm z)7e))%9$=_7_Q0xIOI%QA+|tGT#2{gdtLl~@3D?Ok3C77oF(tmAK-k0ccYJ^A6=wB zoV}CzJG4|;%gpPq{CA7z`}``-x4Z+q#Xr2G{lk|!>VmJr4``F`SR?uox|W<##&x@Q zJb~SpJya}vtcARZ{2lPaP!>N!tPM^LG?w!k^pLZd|9bdC>~_4R{9WPx?6ZDfc+!Bvq57(|cdo#KB%gcVtI%gdQ(T~YVzf<8zZ8H90^z(R)w&#hD zV&^RB1bdp}tkX#GiDDUB6P$HCmdEAKYw$-fzUsNBJL314S;&U?8e+6Yt)ka-P4(cX_EYp7(L; zcz@42h5ieSl-HB*N_sE+ptjrTV+@@&>x)lg?`_TVTjVyr7ho=5&dRzt_o;399@=AJ)J)com*2d4IV%gAZWOGKL9qC(|9p&VoE|x}5C2tW(bCtbJc?hO#s7>#?_j zoC}}wWqt;-XW-Jp%AI89%rGkG=3({sDZ;VI#i!a^`El0sSK^g0%KR=23ly7ecN8% z9{eZ6@qCNLmhs-?4+}dh(8k+%=bpzezuc^P4C2l(Q}CyrKALI^WRT zEpyJMKC_Sa@%_j*@lJWIv}r2#BAw^9lyeV080Yur_Rx;KG2ANmUOZK7J|2a~zzF+d()L$cdG0s_WjNh?O>w8nd@0R>VXD<+s@A|~G&t2;cv5o03@U}3G{s+$clwk~gg0{8t7Nxy(~)CgUpP-J_q0 zbS=CqPVOpg?kiWm*Yof~{)6zXa62r3tle-ttAw>(LD$iLC;T)#2Kk+s_sczDiayul zt1G95=iU}SeI4IA$lmMC^k+A~I~&hO zcsLB;tL=Gr%Uf#ATk@?G&-~`u@6n_1fu5P~Xlu9(@@-uNz3EXn`!K)zJ}Xzwu|JF_ z;~XrW?|1g)WU*!Va-2OgivMJ&$!PT(3@}Pl68ZAo-v<&Kb`UA z9Zo5$8@@Ym&hX?uMQ;vY8(;44=|AuJ6wW#g!h2R#zWEQzI|GJ5C;s&PEIkVU?OqA?suA48{B^P6!B`*6tl_dZ|h z>PEa9tfJ@O-{4kAO>R@J?EUui$HtTI+ogQXpbh-0%~SAYsZ!r7@d@nLd35e8JMcdt{>(ymA)MbxrSv2IwY>k{zuBp~ zc{nxLOU_jJm(WvToV>R9g>q%TX5U=mZ&0bjto0)NZ}FBeQ0|?0DS!6B$M_At?TzJn zdD&l$+1a1zGvEKQa`LY0!~0?fvycAMZsgoOimt~u8Rxx9Fb1;s-o}&QI{I$h6*h-I z%9Y=OD<^xOZuB7j?1h|1W!M6;5B7mJ`uG~ZSFW7xqv`F;MK%9s^s|tC)u~*` zuT5t!H<2?8uA|?SyD`1WxR1o`_I2bY|ou#&!=FXO!#7RcS3e+nMS zUPh;;GUqpWhyD0^nY(fLasE83cE?JcOysKtd+>cK){1|yb~!`0fx3ME5WfjNfXx3& zoPG0yc;@9SzGkr4GqR_%uA|^F`Y@cDSO*#Ni8yCR{>J(gdlsF&QG@*nUJV=aEh$yL ziP_^f8r#qKaXha#jt+hw?-~zw8@>RBVW$h zkNCRD*-!T&pdHGYM!GyVi-(lhb7iz@e@+4L=t{WJ?Q-i+-n@kw;%CUctk z9uM90xy*B_`M;FY5~qC^dPlXL-*3OudH3US_VkhT+^Wj&@~qWJ`a_)G|7~#UFJtNr z8?v|W;Ecy}*l*+KA@lNycJ1(~?0ipCTUB%!j^occnR{yL^Pl=oy?rj0JvT$!&*(vL zo3{B)@I2p*+Vq67+>gZO!wB}%?7T}$I)4kz{r+2eCENg6oBQxcSPVIn`@<@@9^QiO zs+@V{%J<+2cE;8X8qhz z6@0mG52O2tt;Rzj{iTl|`1Y%^CUo}VaD670+M9^a!Rb5uXr0_f#``_)%g&zK6tdpE zoKI(s=g?H_@=)YTh! zO{sE*r^Yf@kFoEEZ?)+vCuc@#D|L4*)D)YepB^~ht#|RxxWD`j=@E45;0kv7NqtU% z-tap6M4Y-^qK`3jYx+R^J2axl!Go}~XLiAp&E0SM`l?*PBihZN4;Rb4PL9T*Gk#5Slh ze)?UpA#$Ff-=I6gb#MWDv3~1|{UNs%-;p?LoN;_Tzj9yAx!aZRL!AA73H=>(hj;XG zFW+7C9on|XHTjnE4Tk;TS-y_AF+LX-(3e8a?g!ZG$|L52{L;4+#Gp3gzV_j4BU51>4 z&$3gaw?f|G0r|(kDE%H^uAJv>a9?Ao79R&+u~R?C$-97^`kpQS9eOU_3a_v4rsjKp z`uTEYJxll_b}Pu=DF(8q(=+k;+C{IWSLk;b?##}+k7vJtn?d~d!_M+&;}&=_{6LQ> zIVW&q&%VsrF_-?7?+Ca;PWIKq^g_s(Qr8*hZtTaZd<)qpL+0X0$QsPX8BgjWwfV4q zp5cEK65E5FddT1Z{>5I(9^u}azb&t2Z!dN?-T_j_+3z{C?u5^^$=~))GS(rb%H3%) z9>xE^-^i!14}nYcn{~XbT$zir>51~PS99L(iEEkD_3?}BUGby(z7ppv$M8?;Ayx>{bZd}^WE?#Fi>CjFHN{ScitC%J(hboCoNlqUnRDQv8|$4(2Ma8xEjx}o-^pL%*SuI24A85MBId3X!i@=o$mx(hy5D9 z3@+Dy?l77E51}iabLAU2vRt{J?j+~0?8D(s=)u3ay!_kT2c%Z6)qg#l@7gl` zWK6BJTf?3M)A&+zL!p42q4(hS`nrow+q_HHg_Yk*smJ%=WxiG7bIa+Y-Ux!x}3xzKD7__b`6gKi# zvm3Ebtt=Ggvd?6HfxGehh>vDJLc50D0?)=*bnf_|<`@n9wJ?Xjb)BKC! zcGyXuiw9N~3U9M7#>;Rcyip&S!#(WG^$0kts!;d|AG)}F{TA%p`(aoN^D10l&K25h z=PxUn2ma^mLt+PT)A+EG{R7mLdoC<=y_)oDr9z<?u#!E3)_{F2 zyQKXj{>}V}kp0q5Up?{b_#VifJmy}O%T4?b*LV&$$D>Mx!pAtd55vtcSF8av7R&vt zqvd*uuZ0$3CvdT3k7~Px-A_FGc_*!gcwPRLV(;=-)BC#jLG%&)7W8TOa-1=z{T*=U zE^C8_!}gLGu8oJ`o8)C2`}ixMi~HHe&m5twn$E@=}b{}@`G4-_BrJ5Cwlhh*ZN#b%UB=fZ($E$ z4`Dy$+HKhx?>X`gi8W!jfe(xJYGwJH3>AA{EHz~f{Tu#d_G!4aoDal4g{Ror124%* zE$zzB{`^5K`|dgV32mB*?c`5~+{-bTSyd=(mUj@g$myFr`Pa}dsIa!|>F)gk@n6^( zV+VFiZBs`U<304$pi9_WX}^p2<&Wd%J{z%9cdx`P`Qyafvm3)~_B;3!$hH227fa?$ z-nH82dabp+l%4wiAb+-TJuN4-ZV_Z%e;3PooGM;RZgoiQo53GuFP*KA8T{w?`=N{Y z5`IIRJ#>b6_I{B)OFY+Y2U}_LX{p6e7u8SK>yCHRvM%qlo4~#7UHC_NnUkLU2XSi7 zAZ@Zwrj*M2XC6EEKZ8Dqz1#RkilzND{+*CBxe?xwGaWyI58x%0<+F8j(KAwB1wZ5Z zQEVVC@f)%?vj?zS$!W#@khT_g#MNmN#k#Qf(z@#VA^v0R1~3$A?DCwGm%sI;LZLT% z7&M1mJL{VH>xL#$Jmz;mb;DiM)1|_OO2c=C@&whCFAE%Db30S1fy{ChZCN z4~P%Loy5*y|42UzUxDk`@6G7b*?ErD(KdTKwWWeKxXL+XA7rNa;(tJ| z0mtQL?jOdv{`cZ@`5)tU+8n@n4rVW1$j)5%#IL!}%jJDkDnIu>!KaI@!L`Nf!Rg{z z>v!$ZA?&~PF@oM%PDgqzcG_QqPu6Yme*EA0sa$GR^Oq}&^fa~*n z=<8bJdyD;o>zu$R|HQAXEI*e<)7lox{T`R|5@g+Pp{>Lp;`OxDtW9#8R(O`v@8f5T zci=wsz4pr#c6YcDn!s3kf1J8CQ@fn!)Tm|rJfkM*_kQ++_*8z*c6&(OnpSkis?4o6 z|6zZteQNOnNUhIaJ6k++^eX=?*Ux-kq-}Hh3Gtk#*I+VapJYv9FDg0XP)p8ic!+kd zSk8V6cE(b_N?pTqb`=V*nET#hU1=Yrt(@BIe(a}RcM1C<@gmNCUsPGH$@OsF6UK-w z!h6N$XuFI)7$-g(o}_=qKjK<-qIG2NgvR^}`KbY~LgqYc(Gq{b zKewbW?em_WeQ`E^8Rr@=;F|Jg(5_`v@*G=?b0!m8VE(iJ=lZrtd<8u< z|2Xc%|1B|gVtJHv`!adg z@N=dQ^V3J_;%t5mycF_Yo$H-N+hzaM#NWuvzFUv8FD_#bkdvAa`zP1Tz3*|&%yaJd zfLIsyICf&s$jR8JYnQV<2G8VY3@6t{-hJ%cZ+o13$iB$g&ig|4esANM=o$lY+AeWl z)y3ase+sFIdA2=A%Q}>BEu3+0Vs9;}3H)|=XjS=JF86vmTwPS-_4_${idfEQf4Mv8 zS&vg`&)_`key1hpJ9d!RnPRCWsbQJZ2K;8w118A1h}J=Dr7^#SXG8XU0WZWa;UM1( z-{I7U8uZ(=8wTy+J^poKr-frJCL9yIdU3TX9ZP&Voow_v< zZjKEFIS0gY{=Q&$h2iu!v^nfr^VlnC?O}zSzvxF`wLSKbSoTWJWY#{{UXSxV(GQQt z)o8!4KXr{u+3y$2XCe3Z|MqKr_CftM6CVgU*V&s{+tjF5a0qg(oa3CQP5cJx++5dK z&9BeTJJ7@C{aOA3s7Bjc`e*H`%ID`?+5xy-yZU0;lS}bD+6fp)&pp0F%lB6L&-+=P z^MA0{iS31)hg;=j{+^askADZA$p2P+Jl@8?gWeIccT$rZiOsJne{+wPlk=6e%HBB# zMncY8iIy{P3hoK3^|icM-q+a&vuT<4srnp_bH3gYpDr)&d#QE#dq%8eZddbffKMTH z>MHs9%N%vnCTH;7QhDw#(dT-81F_VbydUKEzv?*8_w(r4M}Nv6MthIH1hTg~(R<@* z&<=8@e-_K07=UNeQa`RJmFM_0@k#nleY{01?*t!-{Um1sKmCn%em3&=LF(9XSX?qD z$bPAp9LS!^9CpW9=f_-g4Za8J$?cB&&|0%ga+bK}VEGrYzo2zy55iseP2^5xx1;Tn z_qRFh$Icw&{5Qp^k4vG*Zz*S~HrZ3_AkVFgH}9dnwKx$aBsMwQM391_c!8AYGZ?gXg^lVQ7lelL1o(8p$W9oiPT8F%(* zV{P)znCEr+xLiEvr#&7_`xPD*pGaHCu1?Q&PSB>}?D-#Hj#v+V*6&=n+_k1Sk8Rnv zh))vFx&6i3<_zAe{ayUbQRX{yeI88JF7FIE-%sGP#quuJn_ruqYk#eewPHEzYex5f?*q@2l!6W(ED|hi%7(;4Z_FCrlPJYhC zFxbP-zU%`H`Ki}A|2dNf@MrLXc75OleVjvYF85mRs`ukh_!r8_{NFEU6|8qZ$7q+z zPtEHuwh?}$?~vb*eJ3<#=UI3yEMrfs^1X=X!X(<5V!1A5k5ATDYGCHFmss{rYQiA; zMq2c=dlK?DQtV^b%d;!*z5VE^Uzv}r;qT^eA%8{D9)r}|=3;;2)TLFhMqXEbo?{vJ z3FyVYj-U0qU3@pcQAK$Vw`8Xl=HAc3-{G9K#`ukr^^yNNUO?*v+vr!|Hng5%TaD?c zzBWS*eue8S!foJHu^Z&4&(!Ly&lvH`_*Yk!_iOsdeWYGbVqZxg&fhQhDt7u<0jc}B z$JDL-Thv7HvFx$ z&mHA=hwS0(%MJ2&@z>$rRpsxL@7Z_4EVxNdPwi44`xz`eG zE|uS(7O*cA%Y0?5c@Cw|?7#o)GzLig2yH3MEP1EEdB4j&r@w+WS-(f*^`d7S+wl>2 z2A0rA(e9C#+WZT9AT8s#PW%Ge~0} z&2wiq`#ycllXHN7G330zOn+Ycrnmz?-=e7}*^8TLTjeEZj<)&hBAzkKp=Yn8%}$&; zp8b;R&Y(@fb+tL2U5{Q3*3q)2ec9C^^WGLNVULA;w?CszYTdX>dlnjL(@ULfNNWuf z#b3gCRy6ttUl~(>+QCX^3^J!V8-Hkf4ZF6yuKY>-_uxMMa(uhkGHW_PSy^b6Uk zYe)GRckX+h+?QSBG+Ism4u0062ou?PkH|GMpQ9n`K8PO9+~&+?JySpbgEL3(;`a1C z`kT$44^#Q6JDu-V63a}R@DJ2k5%`(v>W;2L&kdTVXg@LyzSt*h}nuyfWjZ#jP@_NTO- zyUOEF9eJO>NW10g`4W7yywr=yVhbVrK6B8Fp7Gv-^IXdu_JFMMrTB_c`8^_QnCma6 zEyZWbUm&Lw|5^UYXCmH6JB0IWoeCLCYIQT(moUDve9voXXBRy?A@`knlKUDe{uOO8 zUI}&CBg8Uy*-P^v&y<|;TI?6`0I}5KS~z=SA#D(UnD}7+6Z}VDzWM%%o^PCPu9bSw z5~ud<#!YsWpW}7p=6-SpQnP1>XP#@rWAsJvf}D(hEuN#@H1?B_eLYfMzI{@gXUKhy zUW5J_WUl+rcj9(Xq)mYA`!nTcf32Wzvp;syF4j+KYOeW**jRR3_A}aa#7{vFcAn=S zS6UPKsVQ^m+25b@k81aqcn|T(@CK~~&bfbI&hV=8KKqLOqy5sD{~$YScCEHG<(&m* z@+aU=isd!0!~T?(b5e&^qsn;j4rA|(kI?JVa(?TwN8=CVXAj&6sRijHdoc4IKYO}1 zPHoI{@{oS|@vG5?;KZ}e9bu{1HMHc`=ilVMa#nI*f8rM*ex3`bLgqblle7CZp3bj} zYx47svY4HA`S-T9u!?qqKi>UhkECwooog-aF1gidGa=Xil9q4L1N^g%Bm1Qf{Z94+ z^!)9WQ;mHG{eXV1##iw(uV0C^XFuSZVISi_j9iEw_PlC`w`ejs|VfKBP+}CSNMbZsYBTVZ{Y0RpUib?>lE5voOSJo z`$KADFMdy$0wegT2iY$JGzqr?Cgye^e(h_*aLTY-phTSUqRo@&UvaqAA@s-t`lpDPr*Yg%fJ8q#6Dgs zukU(1nZ5xpfI~%ROMdD}p8vg}4!;xs4ECw)Q{)_1m<&bNQ7LT9r@|@WzaAQ$i z?0Z`N=E@mMe*hY=GuPQ)PvGYK3VQa{2>H*^s{NzCNzZp=1K3L&1s}tO&RJ)q9$OQGO51oaH;^FT6o4H7LKo<_z^`Z^1LfXW)sI<@p`J?kx9VeiJyV|03;C$hdMI zdg1*2u5G(wxt2AdrQMz4Z#$0%<>#y%kdx=y6_9oA4v(@k)|~4%wZ8!8Y*d5H{VTKr z`{db{-;teX$6xGuVx8n&vMjR>Qhx?oX9_gZNT*##^NSm-ZCDj@+f}?3*Ewvlf36 z)a74{`~VHd!-R=Al|P1Vc5=JR#l#(3-pm%^n`Pl^|?_j zd;MN^D>=*fc{b$Pcm^DxXU}K9Wqp(LGduI1-&Tq^HTZPM{_h57R+=BNPPFIQsh@Yk zM6u-NdfB(z<-BMt7xNdvX)ubOb;uk)TUGu&Gc|BJe+A_Fuk-Jdo3VDl$M}~)WBpxB z%l>%*dWr49dtm_kp{nvbLSuHuotpI+WbN~poFV+qcoOueGXL79j^!`!&3R7LupT+P zm+|xLsEZr%lfMOa^7}&8?6=ZR(gvzBYvivweSo;JuhQvXh4AE3Q2 zmUA|dzrpx~yz6n!P4-&8C2Ptpuv1U#Lh}9#9ZPDa{CQ$?*aK-< zw})_T-{z07uhDjucxprDaS?wjo&#NIr||E96LR~|vMzPBKg!R#JTIppo_)H?cssfN zcsaS=lj7CsX}gp?0Plc&JKZcN_nH0q3!aO!RwLjW*d^YtsJ6j9V$Eq;himZ5w9lZr ze!k=X+@rk5kMcAB`8R~rluhhE!nWTEOnNj+PNR{&h)i@n&7sT#(>|a z=e}|lZ^6CcP3;!&7n;|-cn?m!snAyo{IU2m+8w~zS?mCp%ipA3 zo~uXXEoQHy?U(lzzb-xFUk|CBPvh~l8gLpr@yzQC{JFOC*#p>H*vIvAn4MbmQ^~#K zpXpm;!-sM&qrHv4fSl<9&Um-tlW|&6PWI>Dw9WWu7(vhYuVBxjAB8+m@=nx>J=L6K zPpsqT+^;H?$C~q=^VX7fKCL!nj@Ho5fsgoE!?7@wy%^>d%m3c7nYNAJkv|UFYnT7t zC3EpUKkt1@*%{Xev7xjJVKYBgGd^oN3rA}M-|1sFrI6dRb^C#<=Ip~73mf1U5 zgSEIJ|0j0pYx-CPncwR4x+Uk-Sh7zSv8T`$v(Luavzf>2lg9X8^4^eh9s4vWRQiUL zdp+%IcoNn^`pEg633*mO44W&RjbeGPoX*dj-NV0AejWB)$i7HjYQVk*dWz?L;2k_% zEZ+gShi)*MKTPZ<$nQf}8^xnfgj_tE~t&c83^z3P5`{xbguT=N9%re&V$YM1-{nmq}Q!TWG6d?`QAzY9v{ zliig)0PYietjhPh=iCDJYVpaq8SMjh#@+<#%RL1?qUSv!>vkIqr9D|RFT2z-W4v0t z4SXYZBhKF2DxRA28=hEct=U82PWEQncu37#hPTnPpZnnC*2fk249M9UET^ydH0S|C zUAH$Z5Wi0DG4^NBoxgxTMg9`p5igXV{jp0w(c7{!rgrp<^{tBXeNU%-&CYLl^Zgc> zI+5{Z{c2R2XYnnrF`nH37D4vR?&wvXyV@MnCVL@u_$!?Kvq<|@_y_*?>{Hk;$SYKq z|E(o^AbWK{n3i>UfuD1ceYYP^hSc{5a1Zd8kqYj6R36Z-_aCLEHV=g1Cr>fq(#7vj|F1^7Gqe~abk<|VXyU@PrPedn4LkbUtJ z&b?=U&*o>}T*e-&kNo$5C;z>QcDw7{njC2Ax|hIe_7HhF3wi(OgB!su^z4maan5{a zcD|=p@$;-&%HPL+4!W5C)T43iH((50sjnAto|)lhr81|_eXtT1!^Qfp&|l6+{_T1| zv0QiO^6xPA2gQ1L&&+z~9Ly6>ja=Y*r{g@Y8oO53`*waU`egVVuAq4X5sOV`trR7V`W!$WLFHtE0c z3#Ib!Hl4)Vy7njk_%p82C>ojcY6NY`PJ;y+-3af^3IhvmNrx@=VX<2 z9Z&z6-%s9T_6PLryUqA){==}lhrWHs6zO&Ncf&04!;tfp=jc_iU(L&Yxu2fe*qD~H zIsyMm%UWj)=h52YVbD0a^pVhkJ&Cqfd^`JT_T%iX?4IK5*nMf|v8TYj{NZ9XXgNoX zAbro^AEBquN8DHDYC63xJ_irNkK%K1+Rdr3{w4dBo%_z%53`R$_IA$SO|-0WJ-M&r z9{gYETky$oRXH=Rb%J(<_-0&>|0Z4{mS;fr@)mP_BwDHb+dw1SN*}2~`KIbBmUVc{ zb$a8IYp0({?Cjegq1Lid=KA+vo~_aa<;F6>hy0Qd%i6{^R!o9S9aFAvwm0M+~@5O zZ6D%b*Fn}L z>$?~>%WsE&mh-LL#$vM}{pP!*K%W9TjrTbI9;Y@u1=*jOkBP3QZRzn*yBooj^K=8O(XM~d a{K|VCzr+7@r{6{Rx#ypEmEXN*RsA2Ns`=pn diff --git a/examples/performance/dflash/noise_embedding/gpt-oss-20b_noise_embeds.npy b/examples/performance/dflash/noise_embedding/gpt-oss-20b_noise_embeds.npy deleted file mode 100755 index 3c0852ffc0223181e7fbba2b14ca5cbed53dd307..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11648 zcmbW-iQCs>+6M4nQCbv5N}@%il%=fI^E|CG5eA7DZ&D(%l|4(lw5yb&EZG`M$r5^& zYQ#_?Gug`0j3w)2Uxw!WoX77!cpb-iT-SA8%YA?E{r6<{N&Wksa#pp%HHA@ahFmam z@Q5~@TeRsq^oTa?TeKN^`G~7V47zOKNBD=uIkba?v~9gjKY zu(s`6+|=U#_g|etVSZ(y@N8+J@Hx4hY|^bz_;gjFaA`%M@JvaeFqz)FtWc<;Hz3dE z9}heD|1AD$^7ryW;cr!i!hmD>#VioJne5xOPjsP1d@Ida}3TZ@@M4wth6w}8Gwt(Rfw z3xz`4%0i(t%!9FFc2*eG2pSNnBw`-rQ>K9X-GJ8geSM&I7We5@=K zPUu`H+@`%NwSNfz814N3?C1D4k&WeKeCxtCI0L@`Ucld}ofcwFlrxd9+KNJ9 z5BzERR<71c{7m{6V&0)=jqDGp`;PcWp&9*EawB;n%ocwaS*Nm47>2(ZrX{b!_uzK& zjp^@y@%7AoPkcdJkCS8hn^zSIhlwA^KZ$)hKI`&nG0V;U!*Yk)si#?eu06mu)zv5+-&GyYoqoAgI;U#QiKT#MU>uNwId-`nc$B==fe z!S@UxzZW-%>>E2I^r6NxgTE*W@>j0 z`@iA`^Y0?IH~xRfvA7fYvJbz4y6R=mWb7X%KP0mrcNM#hd>J=XAAcbasw%F_h32_J zy`#lUDse_ZeRc1)eqOW}9^~5{-qVl6)u=7*KytFUX6jtS9?oAAI{Su$$uayJ5~Hnl zYQ02XF8=RgC(!4~`#}4n*q`F3(yME4IDM)7diXE3-A5)zB;VCt8e@%t)a(5S(`LBVmVBaAo z`(;1=!7vtXz`rJDKXsdmTOucG=tlTmyXTR6TW90MJ;ir#X>rdCA}?lVZLKlCpUOR5 zou&9TVtbO=Pu<}?agSFTL;7p%w9q{4fOtgfv=WCOm2 zWEXsWZKUSYYM)DQgh0mz5ONq2(p%(-g2`h@@D$es^Z=`2R~6e z`-&S%PWO$ssxd{}mwcJCoS`>&)hBr?U>&}L`rq(9S5hc^OwLm;_g`<{If}kS%-!<3 zi_f0iP265;=ZroOpL1!e_AR7=^|LX5`k!-g92_TaCE1y8viJ_#{h7@B z?jU^i(&D@NT;DYce;DNKe~E8QmGfh@KH(0<)upc%UxvR7SKmC=Syg;H4aPsEW`1kt zyw1AXh|9j2BDO=hyAEFq_YFO7f!_2R)Xn*t->28J&lkU)J^(JJ=T7*HY|ZzC+AUy) zoXk~CsL$TO?ki>;Os{kgm-q(XSphTkG29})oMEl?w>LRY?(yQ<^1a0OxIVRDcc4$k zb>Pq4*oplZ`7zvES)9w|G<@uOx*7H?1_BO z4)%5QJ=MyY^Q->0)VFK!Ti6@q9?d`5xOXFGle@~>if^ROOZ1Lity%He^ZAYOXTC{% z2e9kGKwSE=T>dC>t2T0PH-qokM;O1?<=+mQ=xfz`PHad1?{WRf3&i!H|Cem4f2ZRv zW#21)cgVS!_TO|4#XbY}Qez+dljJ94YHpUZvds7CW7d5IK5IVbYTBDi<}B@qn%kPh&a-WcQF@G~$yYM}{H^tl!HTjNIqXC>npCf;)yvJdtv3N%ON^#TV zXP!E@Do9nZu1|UG^gOB6YHVpNF^T z>&2gmkA0q)SIzS*c5`+g?frzCFYaJ+9PT~H`lzqX=gF@4{`^@RSK_9~{|f)TxU9R8 zP}7^MhPeE8zD-{?tulw=K2?7z+y@=ixSKy~ZG_mozYm1%(1bs4t(=uR@#hV=f&LIZ z@64>Jc4W@C)6}?I{TswQ4o|}_>gCM8Ltj?V_lK;7XXMYsE#(~}?ga7K&qw2SVc&-TsB7`gt|_O6I!pL26VpfRH|(BbvUZ2U6Jqb; zy9wI!-vMc_xmr`@-ESUV6@Qjm*{9F52jcTv?KJs^$p4J|QBL;x)qE$2&HDP0J%T+R zPOLD_xUp)q7Jqy<>qDD6%ksN#CFkS&PGP(_x(4VffwXXNY-@JyPyNWySmY3%)MmcjeETo(1Dz zXZteuLO=DsCyx=caHakFLb3K$xWl2(>f&4FGWj{1=d+ure>u#M^M{;Uj6=mL<3p|& zbBaFnmYcEP2cE?@mcKvb9rhqu2E8iGNoBG9PU0s(2l?Bzv5UR)V0poXTs)^<6t`Ae zU$Zms`Hen7?0?An_?{@SugTo=8|5v7cJu}O)$}dz#c}-EJFj#%hH{(WKf*P_wTcfn zmOh@mgngWxjO#eao8k|Be+a)u%;moENZdNUjNxuo_6FZ=YOWLa2fi&k=T+V~OCjHz zGw@GjL-N~F@1*CoL*@)!LS~+SC+~rr?K?rv=wVPLuN`?coMVgE|Dl{(<~i^4W%QgWxx0>#KOJ8yx$Md0 z!~B<%%^>fLTI^-?#if4Fkh_wZ-8?FjIf#;a(MUKUr>0FA;Ye zU+%5%@%!Q*#Lc7sN!}Iotk>oI$BNrbUt;|JM!%15D1N@!|7AY|59r(KZpK<$r|{=4 z?1}#oB4>3iuG6J_bJ-uOccGZm$s@{)JAJgV$s9h;e_N$D17v^i#y$j}zHb)uEqMWZ zIQf=(Js{`kp|D!5->s{W`m?9JY2rS?pQqOA?8Y#M-cy}*`a4H|yZWxz#B}6auhxa+ z<*-riWMlOQxxc#U|2XwV;4AUJ!aV)SJvkj$n>-!26?Vd}lT#1+vB&GfHDYFx zGs&&?XJ7jLtNd+1-HcCumvrP?Opb(qR21Jk^KrHKvlm;!1nX=FdA69%Vy0?0e~0Ld zJDI%)*<1V)-*60lUwD{5=frIK2J&Gt=jhJ#ylw7f&w~6L${2NXZnR@J2w&>{!eyQ^ z9tYt@$jP}iSlmy>dl&7UMn4}n4@R-CD|L?JYRJvnA4M+iY8{Xn;~9K|jBRys2ji-= z-CNzIWcru;Zmall_w)`&R*a6930NHn|A?M!h6~*5~%lTW-%kb0b+v$U~b2xo5J^!wkzxB>$SHdm0rsnh% zdV9!x9Z2RreSw~_enI;$;xY#f*pJZ<^S$|7!jua2#pQ46HR%2M&(P=GtIfoHM~=f^ zDmUMLs<^yU9uZdqCW(EC>qYxT+O>xb3)1b-O22`pSytW{IYdiF_h0evTJ{8O!8$@$iI)@9E8yw`t} za~|0T@_y)2S$qo=*sbZ4dd|JLt=)XJXR#-Fq3P5#$$ACPCOR}X)HarlOO z0lt-UrMQdfSuX{?J8?tk$N7#WeEC~Qb#{4a@i)(1*2PHnu`rl_hPb?I7V&q7ciAuT zmC1XWyu3m_t~+^fS@GVUtIz$}8HY)7vj&EXnT0E-cb3jNik)t%KOx**TBWugUCA;(rr6j^5E* z`WOD~^!F=@=j@-w^~F6S<|y1x^q$>~k=ncB#`@kP)yw+2hd+0~&5--&L-s$&yR`9V z{;bn3^54V1M$g~gPp!1paQl%Ni%aP(wK>^XzL^|+_T?C8DKGo_YjVBzvyR^p|AFxy zgIgjebM-ZQ1-S=g-2Q28R7*eUd2?=nm-sr8FT$g6y*9^HI$!WvA75&FI=M|djZ5^! z_&tex6b9p$LH2U?-)cEek`3WdeaL!lj+>?*w~`wndpp0w8}VfwHI_RJHqmcT^Zb(H zZ@s_L^Sd;EJKc$Xni_BE`}^YR<1_!`;WBnpXiHyC9!O3pbB<^~=gM5j-+3w^<9a6K z9lNJ?nv2Qaxl9|mXNI#AUjwT14<)Y?yHvmTVCTF0>ie@|Z-fQ#8f@dso8%7bYo&hH zlGEMVsExm!K1ocz<#Kjg`u~!{@z-iM=SSwIdxhHMNqjjobJvb%&k!?!-mXjW+|KXZ z%=@|IZv2^>yzi#+zo}+JW0X5-n3$aZv*_FTpH^oXJ9ox0rTQrDCvtyr+5f-D&%Rux z?Yq_73!mTJE#&8ZTZPLS><0PULHN<7+E%X-nSG!0IpdghGzhjr-tvbyYgeiDu9}}$ z;mMwSQ_0*-Ewz0eS;zRz#$BziZ;J#g({U*c-(4*Va;8-b|0-v)^|mbKZ1SJ96bvv)Ut(<4tMk5D7`?{fK9(LbbTKb%13y__}{sCkyyoPpcb z%)jeYtFRB@Ay}$U*YnMz-y{AT@~JZ8qfRw-b4TW_b0s_TTtzNZFK^!A@}{V{JNr1k zQe5_R?%=P~Xe~c;^osoKt6qFP$Se4&)5pNc`uP@_J-r5hlUg;%tk<{6G5Xs|PTsM{ z;aZ99P4ybWztVoAXDwF~d#$`fE38Lyw)~!abHt6{AHv@mKZk!Hdx|=(`JN>| z7rP7pP#Dj*6u#oCBX+a)+u<+7tzoaGXMaCbS^VwPt=!q9#`EGQ@~6&oeCOh_j`w5F zrTwUHmA>*^<4M_eS2hxpRK-jcderV!y(k3G?MXh07Yg*|?1> z^)~}(q(;ua+|LcfWuND}3swFtfG_iucikbzeqa0$dhW#h?Podv@JjQGTP*G#{L|KB z_V<1CEnTfG`px!J#`;9|#rUhVUyW>~X6ASveHwWPShcF zYrCGfJNZ&$1>`>Z9hdRD4D#mNN&LlpIr|&Q*;BiTIhg-`e9pA_@w3PWe677J^9Jk#Gx(~so%zZ*H;}it z`nQVBIXi@%JN{TWkN+I;Tk&=2ts&=f&iix49SuM5Pb5olS?gz&IorrZ_}<1P>w77^ zjK4iu1s{~S#~^F(2{<0I*Ei|c9Tm>S#P%qj3v=i>le3N+iTQ}!T_3VW^Uld{m0z_n z!?$(Q-htL&YiCCG%jx|2{S&<@|C{2n-e#(~mTw_UhuheXX!~cmKk+{*CueHLejI str: + """Map a user-supplied model name (short, full HF path, or basename) to + the canonical short name used as a key in MODEL_MAP.""" + canonical = MODEL_ALIASES.get(name.lower()) + if canonical is None: + raise argparse.ArgumentTypeError( + f"unknown model_name '{name}'. Supported: " + ", ".join(sorted(MODEL_MAP.keys())) + ) + return canonical + def build_target_layer_ids(num_target_layers: int, num_draft_layers: int): if num_draft_layers == 1: @@ -50,6 +104,220 @@ def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor: return torch.multinomial(probs, num_samples=1).view(bsz, seq_len) +_TARGET_ABSMAX = 128.0 + + +def load_dflash_checkpoint(dflash_model_path: str) -> tuple[dict, dict]: + """Download and load the DFlash safetensors checkpoint and config. + + Returns + ------- + state_dict : dict[str, Tensor] — all tensors in fp32 + cfg : dict — parsed config.json + """ + bin_path = hf_hub_download(repo_id=dflash_model_path, filename="model.safetensors") + config_path = hf_hub_download(repo_id=dflash_model_path, filename="config.json") + + with open(config_path, "r") as f: + cfg = json.load(f) + + state_dict = {} + with safe_open(bin_path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(torch.float32) + + return state_dict, cfg + + +def extract_lm_head(model: AutoModelForCausalLM) -> tuple[torch.Tensor, torch.Tensor | None]: + """Return (lm_head_weight, lm_head_bias) from a HuggingFace causal LM (fp32).""" + sd = model.state_dict() + weight = sd["lm_head.weight"].to(torch.float32) + bias = sd.get("lm_head.bias") + if bias is not None: + bias = bias.to(torch.float32) + return weight, bias + + +def extract_embed(model: AutoModelForCausalLM) -> torch.Tensor: + """Return embed_tokens.weight from a HuggingFace causal LM (fp32).""" + sd = model.state_dict() + return sd["model.embed_tokens.weight"].to(torch.float32) + + +def build_tlm_model( + base_model: AutoModelForCausalLM, + dflash_state_dict: dict, + target_layer_ids: list[int], + target_absmax: float = _TARGET_ABSMAX, +) -> AutoModelForCausalLM: + """Attach fc + hidden_norm to *base_model*, inject DFlash weights, and scale fc. + + Modifies *base_model* in-place and returns it. + """ + inner = base_model.model + hidden_size = base_model.config.hidden_size + model_type = getattr(base_model.config, "model_type", "") + n = len(target_layer_ids) + + # Add fc and hidden_norm + inner.fc = nn.Linear(n * hidden_size, hidden_size, bias=False) + + if "qwen3" in model_type: + from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm + + inner.hidden_norm = Qwen3RMSNorm(hidden_size, eps=base_model.config.rms_norm_eps) + elif "llama" in model_type: + from transformers.models.llama.modeling_llama import LlamaRMSNorm + + inner.hidden_norm = LlamaRMSNorm(hidden_size, eps=base_model.config.rms_norm_eps) + else: + warnings.warn(f"Unknown model_type '{model_type}'; using nn.RMSNorm for hidden_norm.") + inner.hidden_norm = nn.RMSNorm(hidden_size, eps=getattr(base_model.config, "rms_norm_eps", 1e-6)) + + # Inject weights from DFlash checkpoint + fc_tensor = dflash_state_dict["fc.weight"].to(torch.float32) + inner.fc.weight.data.copy_(fc_tensor) + + hn_tensor = dflash_state_dict["hidden_norm.weight"].to(torch.float32) + inner.hidden_norm.weight.data.copy_(hn_tensor) + + # Scale fc weights so activations stay within fp16 range + # RMSNorm(x/s) == RMSNorm(x), so this is zero-accuracy-cost + with torch.no_grad(): + in_feat = inner.fc.in_features + max_row_norm = inner.fc.weight.data.norm(dim=1).max().item() + fc_out_bound = (in_feat**0.5) * max_row_norm + s = max(fc_out_bound / target_absmax, 1.0) + inner.fc.weight.data.div_(s) + logger.info( + f"[TLM] fc scale: in_features={in_feat}, max_row_norm={max_row_norm:.4f}, " + f"fc_out_bound={fc_out_bound:.2f}, s={s:.6f}" + ) + + logger.info(f"[TLM] fc ({n * hidden_size} -> {hidden_size}) and hidden_norm attached and scaled") + return base_model + + +def build_dlm_model( + dflash_model_path: str, + lm_head_weight: torch.Tensor, + lm_head_bias: torch.Tensor | None = None, + embed_weight: torch.Tensor | None = None, +) -> AutoModelForCausalLM: + """Load the DFlash model and inject lm_head (and optionally embed_tokens) weights from the base TLM model. + + Also removes fc / hidden_norm if the DFlash checkpoint has them. + """ + dlm_model = AutoModelForCausalLM.from_pretrained(dflash_model_path, torch_dtype=torch.float32) + + with torch.no_grad(): + dlm_model.lm_head.weight.copy_(lm_head_weight) + if embed_weight is not None: + dlm_model.model.embed_tokens.weight.copy_(embed_weight) + + if lm_head_bias is not None: + if dlm_model.lm_head.bias is None: + dlm_model.lm_head.bias = nn.Parameter(lm_head_bias) + else: + with torch.no_grad(): + dlm_model.lm_head.bias.copy_(lm_head_bias) + + # DFlash checkpoints occasionally carry fc / hidden_norm — strip them + for attr in ("fc", "hidden_norm"): + if hasattr(dlm_model, attr): + delattr(dlm_model, attr) + logger.info(f"[DLM] Removed dlm_model.{attr}") + + logger.info(f"[DLM] lm_head injected (shape: {lm_head_weight.shape})") + if embed_weight is not None: + logger.info(f"[DLM] embed_tokens injected (shape: {embed_weight.shape})") + return dlm_model + + +def read_dlm_meta(dlm_repo: str, hf_token: Optional[str] = None): + """Load a DFlash checkpoint and return (state_dict, target_layer_ids, block_size).""" + state_dict, cfg = load_dflash_checkpoint(dlm_repo) + target_layer_ids = cfg.get("dflash_config", {}).get("target_layer_ids", []) + block_size = cfg.get("block_size", None) + return state_dict, target_layer_ids, block_size + + +def compile_tlm_qpc( + tlm_repo: str, + dlm_repo: str, + *, + prefill_seq_len: int, + ctx_len: int, + num_cores: int, + num_devices: int, + hf_token: Optional[str] = None, +) -> str: + """Build the TLM (base model + fc/hidden_norm from the DFlash checkpoint) and + compile it to a QPC. Returns the qpc directory path.""" + from QEfficient import QEFFAutoModelForCausalLM + + state_dict, target_layer_ids, block_size = read_dlm_meta(dlm_repo, hf_token) + tlm_target_ids = [i + 1 for i in target_layer_ids] + + logger.info(f"[compile_tlm] base={tlm_repo} dlm={dlm_repo} block_size={block_size}") + base_model = AutoModelForCausalLM.from_pretrained(tlm_repo, torch_dtype=torch.float32, token=hf_token) + build_tlm_model(base_model, state_dict, tlm_target_ids) + + tlm_qeff = QEFFAutoModelForCausalLM(base_model, qaic_config={"target_layer_ids": tlm_target_ids}) + qpc = tlm_qeff.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, + dflash_block_size=block_size, + ) + qpc = str(qpc) + logger.info(f"[compile_tlm] qpc={qpc}") + return qpc + + +def compile_dlm_qpc( + tlm_repo: str, + dlm_repo: str, + *, + ctx_len: int, + num_cores: int, + num_devices: int, + hf_token: Optional[str] = None, +) -> str: + """Build the DLM (DFlash model with lm_head/embed_tokens copied from the base + TLM, and fc/hidden_norm stripped) and compile it to a QPC.""" + from QEfficient import QEFFAutoModelForCausalLM + + _, _, block_size = read_dlm_meta(dlm_repo, hf_token) + + logger.info(f"[compile_dlm] dlm={dlm_repo} block_size={block_size}") + base_model = AutoModelForCausalLM.from_pretrained(tlm_repo, torch_dtype=torch.float32, token=hf_token) + lm_head_w, lm_head_b = extract_lm_head(base_model) + embed_w = extract_embed(base_model) + del base_model + + dlm_model = build_dlm_model(dlm_repo, lm_head_w, lm_head_b, embed_w) + dlm_qeff = QEFFAutoModelForCausalLM(dlm_model, qaic_config={"dflash_dlm": True}) + qpc = dlm_qeff.compile( + prefill_seq_len=block_size, + ctx_len=ctx_len, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, + prefill_only=True, + ) + qpc = str(qpc) + logger.info(f"[compile_dlm] qpc={qpc}") + return qpc + + def load_and_process_dataset(data_name: str): # Math datasets if data_name == "gsm8k": @@ -174,139 +442,3 @@ def reformat_jsonl_by_category(questions: list) -> list: category = q.get("category", "") q["turns"][0] = format_prompt(q["turns"][0], category) return questions - - -_TARGET_ABSMAX = 128.0 - - -def print_stats(x, name: str) -> None: - if isinstance(x, torch.Tensor): - x_np = x.detach().cpu().to(torch.float32).numpy() - elif isinstance(x, np.ndarray): - x_np = x.astype(np.float32) - else: - raise TypeError("Input must be a torch.Tensor or numpy.ndarray") - print(f"[STATS] {name}") - print(f" Shape : {x_np.shape}") - print(f" Min : {x_np.min():.6f}") - print(f" Max : {x_np.max():.6f}") - print(f" Mean : {x_np.mean():.6f}") - print(f" Median: {np.median(x_np):.6f}") - print(f" Std : {x_np.std():.6f}") - - -def load_dflash_checkpoint(dflash_model_path: str) -> tuple[dict, dict]: - """Download and load the DFlash safetensors checkpoint and config. - - Returns - ------- - state_dict : dict[str, Tensor] — all tensors in fp32 - cfg : dict — parsed config.json - """ - bin_path = hf_hub_download(repo_id=dflash_model_path, filename="model.safetensors") - config_path = hf_hub_download(repo_id=dflash_model_path, filename="config.json") - - with open(config_path, "r") as f: - cfg = json.load(f) - - state_dict = {} - with safe_open(bin_path, framework="pt", device="cpu") as f: - for key in f.keys(): - state_dict[key] = f.get_tensor(key).to(torch.float32) - - return state_dict, cfg - - -def extract_lm_head(model: AutoModelForCausalLM) -> tuple[torch.Tensor, torch.Tensor | None]: - """Return (lm_head_weight, lm_head_bias) from a HuggingFace causal LM (fp32).""" - sd = model.state_dict() - weight = sd["lm_head.weight"].to(torch.float32) - bias = sd.get("lm_head.bias") - if bias is not None: - bias = bias.to(torch.float32) - return weight, bias - - -def build_tlm_model( - base_model: AutoModelForCausalLM, - dflash_state_dict: dict, - target_layer_ids: list[int], - target_absmax: float = _TARGET_ABSMAX, -) -> AutoModelForCausalLM: - """Attach fc + hidden_norm to *base_model*, inject DFlash weights, and scale fc. - - Modifies *base_model* in-place and returns it. - """ - inner = base_model.model - hidden_size = base_model.config.hidden_size - model_type = getattr(base_model.config, "model_type", "") - n = len(target_layer_ids) - - # Add fc and hidden_norm - inner.fc = nn.Linear(n * hidden_size, hidden_size, bias=False) - - if "qwen3" in model_type: - from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm - - inner.hidden_norm = Qwen3RMSNorm(hidden_size, eps=base_model.config.rms_norm_eps) - elif "llama" in model_type: - from transformers.models.llama.modeling_llama import LlamaRMSNorm - - inner.hidden_norm = LlamaRMSNorm(hidden_size, eps=base_model.config.rms_norm_eps) - else: - warnings.warn(f"Unknown model_type '{model_type}'; using nn.RMSNorm for hidden_norm.") - inner.hidden_norm = nn.RMSNorm(hidden_size, eps=getattr(base_model.config, "rms_norm_eps", 1e-6)) - - # Inject weights from DFlash checkpoint - fc_tensor = dflash_state_dict["fc.weight"].to(torch.float32) - inner.fc.weight.data.copy_(fc_tensor) - - hn_tensor = dflash_state_dict["hidden_norm.weight"].to(torch.float32) - inner.hidden_norm.weight.data.copy_(hn_tensor) - - # Scale fc weights so activations stay within fp16 range - # RMSNorm(x/s) == RMSNorm(x), so this is zero-accuracy-cost - with torch.no_grad(): - in_feat = inner.fc.in_features - max_row_norm = inner.fc.weight.data.norm(dim=1).max().item() - fc_out_bound = (in_feat**0.5) * max_row_norm - s = max(fc_out_bound / target_absmax, 1.0) - inner.fc.weight.data.div_(s) - print( - f"[TLM] fc scale: in_features={in_feat}, max_row_norm={max_row_norm:.4f}, " - f"fc_out_bound={fc_out_bound:.2f}, s={s:.6f}" - ) - - print(f"[TLM] fc ({n * hidden_size} -> {hidden_size}) and hidden_norm attached and scaled") - return base_model - - -def build_dlm_model( - dflash_model_path: str, - lm_head_weight: torch.Tensor, - lm_head_bias: torch.Tensor | None = None, -) -> AutoModelForCausalLM: - """Load the DFlash model and inject lm_head weights from the base TLM model. - - Also removes fc / hidden_norm if the DFlash checkpoint has them. - """ - dlm_model = AutoModelForCausalLM.from_pretrained(dflash_model_path, torch_dtype=torch.float32) - - with torch.no_grad(): - dlm_model.lm_head.weight.copy_(lm_head_weight) - - if lm_head_bias is not None: - if dlm_model.lm_head.bias is None: - dlm_model.lm_head.bias = nn.Parameter(lm_head_bias) - else: - with torch.no_grad(): - dlm_model.lm_head.bias.copy_(lm_head_bias) - - # DFlash checkpoints occasionally carry fc / hidden_norm — strip them - for attr in ("fc", "hidden_norm"): - if hasattr(dlm_model, attr): - delattr(dlm_model, attr) - print(f"[DLM] Removed dlm_model.{attr}") - - print(f"[DLM] lm_head injected (shape: {lm_head_weight.shape})") - return dlm_model From 7cb95a397887beb29640f977c2ec4a425316d933 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 27 May 2026 14:02:05 -0700 Subject: [PATCH 3/5] DFlash: accept comma-separated device IDs in benchmark scripts Signed-off-by: Vahid Janfaza Signed-off-by: Vahid Janfaza --- examples/performance/dflash/README.md | 4 ++-- .../performance/dflash/basic_inference.py | 18 +++++++++++++++-- examples/performance/dflash/benchmark.py | 20 ++++++++++++++++--- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/examples/performance/dflash/README.md b/examples/performance/dflash/README.md index 8b0181d307..8b6377a34c 100644 --- a/examples/performance/dflash/README.md +++ b/examples/performance/dflash/README.md @@ -51,8 +51,8 @@ python basic_inference.py --model_name Qwen3-4B \ | Flag | Default | Notes | |---|---|---| -| `--tlm_devices` | `0 1 2 3` | TLM device IDs | -| `--dlm_devices` | `0 1 2 3` | DLM device IDs | +| `--tlm_devices` | `0,1,2,3` | TLM device IDs | +| `--dlm_devices` | `0,1,2,3` | DLM device IDs | | `--tlm_cores` / `--dlm_cores` | `8` | per-side core count | | `--ctx_len` | `4096` | | | `--prefill_seq_len` | `128` | | diff --git a/examples/performance/dflash/basic_inference.py b/examples/performance/dflash/basic_inference.py index 5fbb91d983..79005c77a2 100644 --- a/examples/performance/dflash/basic_inference.py +++ b/examples/performance/dflash/basic_inference.py @@ -45,6 +45,10 @@ from QEfficient.utils.logging_utils import logger # noqa: E402 +def parse_device_list(s): + return [int(x) for x in s.split(",") if x.strip() != ""] + + def parse_args(): p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) p.add_argument( @@ -70,8 +74,18 @@ def parse_args(): p.add_argument("--dlm_qpc", default=None, help="Pre-compiled DLM qpc dir (skip DLM compile).") # Devices / cores - p.add_argument("--tlm_devices", nargs="+", type=int, default=[60, 61, 62, 63]) - p.add_argument("--dlm_devices", nargs="+", type=int, default=[60, 61, 62, 63]) + p.add_argument( + "--tlm_devices", + type=parse_device_list, + default=[60, 61, 62, 63], + help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", + ) + p.add_argument( + "--dlm_devices", + type=parse_device_list, + default=[60, 61, 62, 63], + help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", + ) p.add_argument("--tlm_cores", type=int, default=8) p.add_argument("--dlm_cores", type=int, default=8) diff --git a/examples/performance/dflash/benchmark.py b/examples/performance/dflash/benchmark.py index 1a7601d83b..49cf98d2ca 100644 --- a/examples/performance/dflash/benchmark.py +++ b/examples/performance/dflash/benchmark.py @@ -24,7 +24,7 @@ # Custom devices / cores / dataset python benchmark.py --model_name Llama-3.1-8B-Instruct \\ - --tlm_devices 0 1 2 3 --dlm_devices 4 5 6 7 \\ + --tlm_devices 0,1,2,3 --dlm_devices 4,5,6,7 \\ --tlm_cores 8 --dlm_cores 8 --dataset gsm8k """ @@ -41,6 +41,10 @@ from utils import MODEL_MAP, compile_dlm_qpc, compile_tlm_qpc, resolve_model_name # noqa: E402 +def parse_device_list(s): + return [int(x) for x in s.split(",") if x.strip() != ""] + + def parse_args(): p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) p.add_argument( @@ -58,8 +62,18 @@ def parse_args(): p.add_argument("--dlm_qpc", default=None, help="Pre-compiled DLM qpc dir (skip DLM compile).") # Devices / cores - p.add_argument("--tlm_devices", nargs="+", type=int, default=[0, 1, 2, 3]) - p.add_argument("--dlm_devices", nargs="+", type=int, default=[0, 1, 2, 3]) + p.add_argument( + "--tlm_devices", + type=parse_device_list, + default=[0, 1, 2, 3], + help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", + ) + p.add_argument( + "--dlm_devices", + type=parse_device_list, + default=[0, 1, 2, 3], + help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", + ) p.add_argument("--tlm_cores", type=int, default=8) p.add_argument("--dlm_cores", type=int, default=8) From e468a15319df7e25808b664d473b17c812390f29 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 27 May 2026 14:04:10 -0700 Subject: [PATCH 4/5] DFlash: accept comma-separated device IDs in benchmark scripts Signed-off-by: Vahid Janfaza --- examples/performance/dflash/basic_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/performance/dflash/basic_inference.py b/examples/performance/dflash/basic_inference.py index 79005c77a2..712788af7b 100644 --- a/examples/performance/dflash/basic_inference.py +++ b/examples/performance/dflash/basic_inference.py @@ -77,13 +77,13 @@ def parse_args(): p.add_argument( "--tlm_devices", type=parse_device_list, - default=[60, 61, 62, 63], + default=[0, 1, 2, 3], help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", ) p.add_argument( "--dlm_devices", type=parse_device_list, - default=[60, 61, 62, 63], + default=[0, 1, 2, 3], help="Comma-separated device IDs, e.g. '0,1,2,3' or '0'.", ) p.add_argument("--tlm_cores", type=int, default=8) From 4ec6a39012318335151482c89dbaff015c966cf1 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Thu, 28 May 2026 21:58:55 -0700 Subject: [PATCH 5/5] Adding embed layer to DLM to remove passing noise embed requirements Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 2 -- examples/performance/dflash/README.md | 1 - 2 files changed, 3 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f6837d2e42..c49ce7125b 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3153,14 +3153,12 @@ def export( if self.dflash_dlm: example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), - # "noise_embeds": torch.ones((bs, seq_len, self.hidden_size), dtype=torch.float), "target_hidden": torch.ones((bs, seq_len, self.hidden_size), dtype=torch.float), "position_ids": torch.arange(seq_len, 2 * seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), "position_ids_target": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), "past_key_values": [[] for _ in range(self.num_layers)], } dynamic_axes = { - # "noise_embeds": {0: "batch_size", 1: "seq_len"}, "input_ids": {0: "batch_size", 1: "seq_len"}, "target_hidden": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, diff --git a/examples/performance/dflash/README.md b/examples/performance/dflash/README.md index 8b6377a34c..0a17946ca0 100644 --- a/examples/performance/dflash/README.md +++ b/examples/performance/dflash/README.md @@ -57,7 +57,6 @@ python basic_inference.py --model_name Qwen3-4B \ | `--ctx_len` | `4096` | | | `--prefill_seq_len` | `128` | | | `--generation_len` | `1024` (benchmark) / `256` (single) | | -| `--noise_embed_path` | `noise_embedding/_noise_embeds.npy` | override if needed | | `--hf_token` | `$HF_TOKEN` | required for gated repos | | `--tlm_hf_path` | from `MODEL_MAP` | required when the map entry has `None` |