Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions QEfficient/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,14 @@ def forward(
sin = self.sin_cached[position_ids].unsqueeze(1)
cos = self.cos_cached[position_ids].unsqueeze(1)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
self.target_layer_ids = getattr(self, "target_layer_ids", None)
target_hidden_list = []

for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.target_layer_ids and idx in self.target_layer_ids:
target_hidden_list.append(hidden_states)
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
Expand All @@ -316,6 +320,16 @@ def forward(
if return_legacy_cache:
past_key_values = past_key_values.to_legacy_cache()

if self.target_layer_ids:
target_hidden = torch.cat(target_hidden_list, dim=-1)
target_hidden_fc = self.fc(target_hidden)
target_hidden_final = self.hidden_norm(target_hidden_fc)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=target_hidden_final,
)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
Expand Down Expand Up @@ -354,6 +368,10 @@ def forward(
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

if getattr(self.model, "target_layer_ids", None):
output_hidden_states = False

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
Expand All @@ -371,6 +389,20 @@ def forward(

# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)

if getattr(self.model, "target_layer_ids", None):
target_hidden = outputs.hidden_states
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states).float()
predicted_token_ids = logits.argmax(dim=-1).to(torch.int32)
return CausalLMOutputWithPast(
loss=None,
logits=predicted_token_ids,
past_key_values=outputs.past_key_values,
hidden_states=target_hidden,
attentions=outputs.attentions,
)

hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
logits = self.lm_head(hidden_states).float()

Expand Down
44 changes: 43 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
)
from QEfficient.transformers.models.pytorch_transforms import (
CustomOpsTransform,
DFlashTLMTransform,
DFlashTransform,
KVCacheExternalModuleMapperTransform,
KVCacheTransform,
PoolingTransform,
Expand Down Expand Up @@ -2878,6 +2880,20 @@ def __init__(
if self.is_tlm:
self.model.qaic_config["return_pdfs"] = True

self.dflash_dlm = False
self.hidden_size = self.model.config.hidden_size
self.vocab_size = self.model.config.vocab_size
if qaic_config is not None:
self.dflash_dlm = qaic_config.get("dflash_dlm", False)
if self.dflash_dlm:
self.model, _ = DFlashTransform.apply(self.model, qaic_config)

self.dflash_tlm = False
if qaic_config is not None:
self.dflash_tlm = bool(qaic_config.get("target_layer_ids", None))
if self.dflash_tlm:
self.model, _ = DFlashTLMTransform.apply(self.model, qaic_config)

def __repr__(self) -> str:
    """Return the wrapper's class name, a newline, then the wrapped model's repr."""
    parts = [self.__class__.__name__, self.model.__repr__()]
    return "\n".join(parts)

Expand Down Expand Up @@ -3079,7 +3095,7 @@ def export(
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS

kv_cache_shape = get_padding_shape_from_config(
self.model.config, fbs if self.continuous_batching else bs, seq_len
self.model.config, fbs if self.continuous_batching else bs, seq_len * 2
)
enable_chunking = kwargs.get("enable_chunking", False)
if (
Expand Down Expand Up @@ -3133,6 +3149,22 @@ def export(
"input_ids": {0: "batch_size", 1: "seq_len"},
"position_ids": {0: "batch_size", 1: "seq_len"},
}

if self.dflash_dlm:
example_inputs = {
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
"target_hidden": torch.ones((bs, seq_len, self.hidden_size), dtype=torch.float),
"position_ids": torch.arange(seq_len, 2 * seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
"position_ids_target": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
"past_key_values": [[] for _ in range(self.num_layers)],
}
dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"target_hidden": {0: "batch_size", 1: "seq_len"},
"position_ids": {0: "batch_size", 1: "seq_len"},
"position_ids_target": {0: "batch_size", 1: "seq_len"},
}

if self.ccl_enabled:
example_inputs["comp_ctx_lengths"] = torch.randint(0, 127, (512,), dtype=torch.int8)
dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
Expand Down Expand Up @@ -3250,6 +3282,9 @@ def export(
qaic_config=self.model.qaic_config,
)

if self.dflash_tlm:
output_names.append("hidden_states")

return self._export(
example_inputs,
output_names=output_names,
Expand Down Expand Up @@ -3334,6 +3369,7 @@ def build_decode_specialization(
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
dflash_block_size: Optional[int] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -3377,6 +3413,9 @@ def build_decode_specialization(

spec["num_logits_to_keep"] = (num_speculative_tokens + 1) if self.is_tlm else None

if self.dflash_tlm or self.dflash_dlm:
spec["seq_len"] = dflash_block_size

if self.continuous_batching:
spec["full_batch_size"] = kv_cache_batch_size
else:
Expand All @@ -3397,6 +3436,7 @@ def compile(
batch_size: int = 1,
full_batch_size: Optional[int] = None,
kv_cache_batch_size: Optional[int] = None,
dflash_block_size: Optional[int] = None,
num_devices: int = 1,
num_cores: int = 16, # FIXME: Make this mandatory arg
mxfp6_matmul: bool = False,
Expand Down Expand Up @@ -3612,6 +3652,7 @@ def compile(
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
dflash_block_size=dflash_block_size,
)
if decode_spec:
specializations.append(decode_spec)
Expand All @@ -3624,6 +3665,7 @@ def compile(
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
dflash_block_size=dflash_block_size,
prefill_only=prefill_only,
)
if decode_spec:
Expand Down
144 changes: 144 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,18 @@
QEffQwen3ForCausalLM,
QEffQwen3Model,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import (
QEffQwen3Attention as QEffQwen3DFlashAttention,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import (
QEffQwen3DecoderLayer as QEffQwen3DFlashDecoderLayer,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import (
QEffQwen3ForCausalLM as QEffQwen3DFlashForCausalLM,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3_dflash_draft import (
QEffQwen3Model as QEffQwen3DFlashModel,
)
from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
QEffPrefillChunkedQwen3MoeSparseMoeBlock,
QEffQwen3MoeAttention,
Expand Down Expand Up @@ -954,6 +966,138 @@ def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -
return model, transformed


class DFlashTransform(ModuleMappingTransform):
    """
    Swap the QEff Qwen3 modules for their DFlash draft-model counterparts.

    The mapping keys are the QEff Qwen3 classes, so the swap only matches a
    model whose modules are already of those types. The original author notes
    that this transform is meant to run after KVCacheTransform; that ordering
    is not enforced here.

    The swap is opt-in: it happens only when ``qaic_config["dflash_dlm"]`` is
    truthy.
    """

    _module_mapping = {
        QEffQwen3Attention: QEffQwen3DFlashAttention,
        QEffQwen3DecoderLayer: QEffQwen3DFlashDecoderLayer,
        QEffQwen3Model: QEffQwen3DFlashModel,
        QEffQwen3ForCausalLM: QEffQwen3DFlashForCausalLM,
    }

    @classmethod
    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
        """
        Apply the DFlash module mapping when it has been requested.

        Args:
            model: Model whose modules may be replaced.
            qaic_config: Optional configuration dict; only ``"dflash_dlm"`` is read.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            Whatever the parent ``apply`` returns when ``dflash_dlm`` is truthy,
            otherwise ``(model, False)`` with the model untouched.
        """
        dflash_requested = bool(qaic_config) and qaic_config.get("dflash_dlm", False)
        if dflash_requested:
            return super().apply(model)
        return model, False


class DFlashTLMTransform:
    """
    Attach the DFlash target-model (TLM) projection layers to a causal LM.

    When ``qaic_config["target_layer_ids"]`` is non-empty, :meth:`apply`

    * adds ``fc`` (``nn.Linear(n * hidden_size, hidden_size, bias=False)`` with
      ``n = len(target_layer_ids)``) and ``hidden_norm`` (an RMSNorm) to
      ``model.model`` unless both attributes already exist,
    * tries to restore their weights from the local checkpoint directory named
      by ``qaic_config["pretrained_model_name_or_path"]``, and
    * stores ``target_layer_ids`` on ``model.model``.

    NOTE(review): the original author states that these two weights are dropped
    as unexpected keys during ``from_pretrained`` and that ``from_pretrained``
    stores the checkpoint path in ``qaic_config``; neither is visible from this
    class — confirm against the caller.
    """

    # Checkpoint tensor name -> attribute on the inner model whose ``weight``
    # parameter receives that tensor.
    _WEIGHT_KEYS = {
        "model.fc.weight": "fc",
        "model.hidden_norm.weight": "hidden_norm",
    }

    @classmethod
    def _load_tlm_weights(cls, checkpoint_path: str) -> dict:
        """
        Read only the fc and hidden_norm weights from a local checkpoint directory.

        ``*.safetensors`` files are tried first. If any exist and the
        ``safetensors`` package imports, the result of scanning them is returned
        even when it is empty. ``pytorch_model*.bin`` files are read only when
        there are no safetensors files or the package is not installed.

        Args:
            checkpoint_path: Directory to scan. A path that does not exist
                simply yields no files.

        Returns:
            Dict mapping each checkpoint key that was found to its tensor;
            possibly empty or partial.
        """
        from pathlib import Path

        import torch

        ckpt_dir = Path(checkpoint_path)
        target_keys = set(cls._WEIGHT_KEYS)
        found: dict = {}

        # safetensors (preferred modern format)
        shards = sorted(ckpt_dir.glob("*.safetensors"))
        if shards:
            # Only the import is expected to fail; keep the try body minimal so
            # genuine read errors are not mistaken for a missing package.
            try:
                from safetensors import safe_open
            except ImportError:
                warnings.warn("safetensors not installed; falling back to .bin loading.")
            else:
                for shard in shards:
                    with safe_open(str(shard), framework="pt", device="cpu") as handle:
                        for key in handle.keys():
                            if key in target_keys and key not in found:
                                found[key] = handle.get_tensor(key)
                    if len(found) == len(target_keys):
                        break
                return found

        # pytorch .bin fallback
        for bin_file in sorted(ckpt_dir.glob("pytorch_model*.bin")):
            state_dict = torch.load(str(bin_file), map_location="cpu", weights_only=True)
            for key in target_keys:
                if key in state_dict and key not in found:
                    found[key] = state_dict[key]
            if len(found) == len(target_keys):
                break

        return found

    @classmethod
    def _build_hidden_norm(cls, config, hidden_size: int) -> nn.Module:
        """Create ``hidden_norm`` with the RMSNorm class matching ``config.model_type``."""
        model_type = getattr(config, "model_type", "")
        if "qwen3" in model_type:
            from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm

            return Qwen3RMSNorm(hidden_size, eps=config.rms_norm_eps)
        if "llama" in model_type:
            from transformers.models.llama.modeling_llama import LlamaRMSNorm

            return LlamaRMSNorm(hidden_size, eps=config.rms_norm_eps)

        warnings.warn(
            f"DFlashTLMTransform: unknown model_type '{model_type}'. Using nn.RMSNorm as hidden_norm fallback."
        )
        return nn.RMSNorm(hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6))

    @classmethod
    def _restore_weights(cls, inner: nn.Module, ckpt_path: Optional[str]) -> None:
        """Copy fc / hidden_norm weights from ``ckpt_path`` into ``inner``, warning about anything missing."""
        if not ckpt_path:
            warnings.warn(
                "DFlashTLMTransform: no checkpoint path available — "
                "use QEFFAutoModelForCausalLM.from_pretrained() so the path is "
                "stored automatically, or set qaic_config['pretrained_model_name_or_path']."
            )
            return

        weights = cls._load_tlm_weights(ckpt_path)
        if not weights:
            warnings.warn(
                "DFlashTLMTransform: fc/hidden_norm weights not found in checkpoint "
                f"at '{ckpt_path}'. Layers are randomly initialized."
            )
            return

        for key, attr in cls._WEIGHT_KEYS.items():
            if key in weights:
                getattr(inner, attr).weight.data.copy_(weights[key])
            else:
                # A partially populated checkpoint used to pass silently,
                # leaving one layer at its constructor initialisation.
                warnings.warn(
                    f"DFlashTLMTransform: '{key}' not found in checkpoint at '{ckpt_path}'. "
                    f"Layer '{attr}' keeps its default initialization."
                )

    @classmethod
    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
        """
        Enable TLM hidden-state collection on ``model`` when requested.

        Args:
            model: Causal-LM wrapper exposing ``.model`` (the inner model) and ``.config``.
            qaic_config: Optional configuration dict. ``"target_layer_ids"`` triggers
                the transform; ``"pretrained_model_name_or_path"`` names the
                checkpoint directory used to restore weights.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            ``(model, False)`` untouched when ``target_layer_ids`` is missing or
            empty, otherwise ``(model, True)`` after the inner model has been updated.
        """
        target_layer_ids = qaic_config.get("target_layer_ids", None) if qaic_config else None
        if not target_layer_ids:
            return model, False

        inner = model.model  # QEffQwen3Model or QEffLlamaModel
        hidden_size = model.config.hidden_size

        # Add fc and hidden_norm only if not already present: per the original
        # comment, scripts such as tlm_maker.py inject the layers and their
        # weights beforehand, and those must not be overwritten.
        layers_already_present = hasattr(inner, "fc") and hasattr(inner, "hidden_norm")
        if not layers_already_present:
            inner.fc = nn.Linear(len(target_layer_ids) * hidden_size, hidden_size, bias=False)
            inner.hidden_norm = cls._build_hidden_norm(model.config, hidden_size)
            cls._restore_weights(inner, qaic_config.get("pretrained_model_name_or_path", None))

        # Activate the TLM collection path in forward.
        inner.target_layer_ids = target_layer_ids
        return model, True


class VlmKVOffloadTransform(ModuleMappingTransform):
# supported architectures
_module_mapping = {
Expand Down
Loading