Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions examples/models/llama/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def forward(


ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}
_RECURRENT_GATED_DELTA_RULE_OP = None
_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False


def register_attention(name: str):
Expand All @@ -60,6 +62,37 @@ def decorator(cls: Type[Attention]):
return decorator


def _get_recurrent_gated_delta_rule_op():
    """Return the ``llama::recurrent_gated_delta_rule`` custom op, or ``None``.

    The resolution result (including failure) is memoized in module globals so
    the potentially expensive custom-op library import is attempted at most
    once per process.

    Returns:
        The ``torch.ops.llama.recurrent_gated_delta_rule.default`` overload if
        it is registered (or becomes registered after importing the custom-ops
        extension), otherwise ``None``.
    """
    global _RECURRENT_GATED_DELTA_RULE_OP
    global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP

    # Fast path: a previous call already resolved (or failed to resolve) the
    # op; return the cached result either way.
    if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP:
        return _RECURRENT_GATED_DELTA_RULE_OP

    _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True

    # The op may already be registered, e.g. if another module imported the
    # custom-ops extension earlier in the process.
    try:
        _RECURRENT_GATED_DELTA_RULE_OP = (
            torch.ops.llama.recurrent_gated_delta_rule.default
        )
        return _RECURRENT_GATED_DELTA_RULE_OP
    except (AttributeError, RuntimeError):
        pass

    # Importing the extension registers the custom ops as a side effect.
    # Catch only ImportError/OSError (missing module or shared-library load
    # failure) so unrelated errors are not silently hidden.
    try:
        from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
    except (ImportError, OSError):
        return None

    try:
        _RECURRENT_GATED_DELTA_RULE_OP = (
            torch.ops.llama.recurrent_gated_delta_rule.default
        )
    except (AttributeError, RuntimeError):
        # Import succeeded but the op still is not registered; cache None.
        _RECURRENT_GATED_DELTA_RULE_OP = None

    return _RECURRENT_GATED_DELTA_RULE_OP


class KVCache(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -668,6 +701,18 @@ def _recurrent_gated_delta_rule(
scale = 1.0 / (query.shape[-1] ** 0.5)
query = query * scale

recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op()
if recurrent_gated_delta_rule_op is not None:
core_attn_out = recurrent_gated_delta_rule_op(
query,
key,
value,
g,
beta,
self.recurrent_state[:batch_size],
)
return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

core_attn_out = torch.zeros(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you put this logic in some function called like "naive_gated_delta_rule_op" and then just have the if statement switch between them to tidy this function up a bit.

batch_size,
num_heads,
Expand Down
73 changes: 73 additions & 0 deletions examples/models/llama/tests/test_export_llama_lib.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
Expand All @@ -6,6 +6,10 @@
# LICENSE file in the root directory of this source tree.

import unittest
import json
import tempfile

from pathlib import Path

from executorch.devtools.backend_debug import get_delegation_info

Expand All @@ -25,6 +29,7 @@

from executorch.examples.models.llama.export_llama_lib import (
_export_llama,
_prepare_for_llama_export,
build_args_parser,
get_quantizer_and_quant_params,
)
Expand All @@ -37,6 +42,39 @@


class ExportLlamaLibTest(unittest.TestCase):
def _make_tiny_qwen35_params(self) -> dict:
return {
"dim": 64,
"hidden_dim": 128,
"n_heads": 4,
"head_dim": 16,
"n_kv_heads": 2,
"n_layers": 4,
"norm_eps": 1e-6,
"rope_theta": 10000000.0,
"use_scaled_rope": False,
"vocab_size": 256,
"use_hf_rope": True,
"partial_rotary_factor": 0.25,
"attention_qkv_bias": False,
"use_qk_norm": True,
"qk_norm_before_rope": True,
"attention_type": "mha",
"use_q_gate": True,
"rms_norm_add_unit_offset": True,
"linear_conv_kernel_dim": 4,
"linear_key_head_dim": 8,
"linear_value_head_dim": 8,
"linear_num_key_heads": 4,
"linear_num_value_heads": 4,
"layer_types": [
"linear_attention",
"full_attention",
"linear_attention",
"full_attention",
],
}

def test_has_expected_ops_and_op_counts(self):
"""
Checks the presence of unwanted expensive ops.
Expand Down Expand Up @@ -66,6 +104,41 @@
for op, _op_info in delegation_info.delegation_by_operator.items():
self.assertTrue(op not in UNWANTED_OPS)

def test_tiny_qwen35_export_uses_recurrent_gated_delta_rule(self):
with tempfile.TemporaryDirectory() as temp_dir:
params_path = Path(temp_dir) / "tiny_qwen35.json"
params_path.write_text(json.dumps(self._make_tiny_qwen35_params()))

parser = build_args_parser()
args = parser.parse_args(
[
"--model",
"qwen3_5_0_8b",
"--params",
str(params_path),
"--use_kv_cache",
"--disable_dynamic_shape",
"--max_seq_length",
"8",
"--max_context_length",
"8",
]
)

llm_config = LlmConfig.from_args(args)
builder = _prepare_for_llama_export(llm_config).export()
assert builder.pre_autograd_graph_module is not None

recurrent_nodes = [
node
for node in builder.pre_autograd_graph_module.graph.nodes
if "auto_functionalized_v2" in str(node.target)
and node.args
and "llama.recurrent_gated_delta_rule" in str(node.args[0])
]

self.assertEqual(len(recurrent_nodes), 2)

@unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self):
llm_config = LlmConfig()
Expand Down
102 changes: 102 additions & 0 deletions examples/models/llama/tests/test_qwen3_5_attention.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -7,6 +7,7 @@
import unittest

import torch
import executorch.examples.models.llama.attention as attention_module
from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
Expand Down Expand Up @@ -123,6 +124,107 @@
torch.allclose(state_after_first, state_after_second, atol=1e-5)
)

def test_gated_deltanet_chunked_prefill_matches_full_sequence(self):
torch.manual_seed(0)
args = self._make_args(
use_kv_cache=True,
use_q_gate=True,
linear_conv_kernel_dim=4,
linear_key_head_dim=4,
linear_value_head_dim=4,
linear_num_key_heads=2,
linear_num_value_heads=4,
)
rope = Rope(args)
attn_full = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
attn_chunked = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
attn_chunked.load_state_dict(attn_full.state_dict())

x = torch.randn(1, 5, args.dim)
dummy_freq = torch.zeros(1, 1)

full_output, _ = attn_full(
x,
dummy_freq,
dummy_freq,
input_pos=torch.tensor([0], dtype=torch.long),
)

chunk_outputs = []
for start, end in ((0, 3), (3, 4), (4, 5)):
output, _ = attn_chunked(
x[:, start:end],
dummy_freq,
dummy_freq,
input_pos=torch.tensor([start], dtype=torch.long),
)
chunk_outputs.append(output)

chunked_output = torch.cat(chunk_outputs, dim=1)

self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5))
self.assertTrue(
torch.allclose(
attn_chunked.recurrent_state, attn_full.recurrent_state, atol=1e-5
)
)
self.assertTrue(
torch.allclose(attn_chunked.conv_state, attn_full.conv_state, atol=1e-5)
)

def test_gated_deltanet_custom_op_matches_fallback(self):
recurrent_op = attention_module._get_recurrent_gated_delta_rule_op()
if recurrent_op is None:
self.skipTest("llama::recurrent_gated_delta_rule is not available")

torch.manual_seed(0)
args = self._make_args(
use_kv_cache=True,
use_q_gate=True,
linear_conv_kernel_dim=4,
linear_key_head_dim=4,
linear_value_head_dim=4,
linear_num_key_heads=2,
linear_num_value_heads=4,
)
rope = Rope(args)
attn_custom = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
attn_fallback = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
attn_fallback.load_state_dict(attn_custom.state_dict())

query = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
key = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
value = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_v_dim)
g = torch.randn(1, 3, attn_custom.num_v_heads)
beta = torch.sigmoid(torch.randn(1, 3, attn_custom.num_v_heads))

original_op = attention_module._RECURRENT_GATED_DELTA_RULE_OP
original_tried_loading = attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP
try:
attention_module._RECURRENT_GATED_DELTA_RULE_OP = recurrent_op
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
custom_output = attn_custom._recurrent_gated_delta_rule(
query, key, value, g, beta
)

attention_module._RECURRENT_GATED_DELTA_RULE_OP = None
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
fallback_output = attn_fallback._recurrent_gated_delta_rule(
query, key, value, g, beta
)
finally:
attention_module._RECURRENT_GATED_DELTA_RULE_OP = original_op
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = (
original_tried_loading
)

self.assertTrue(torch.allclose(custom_output, fallback_output, atol=1e-5))
self.assertTrue(
torch.allclose(
attn_custom.recurrent_state, attn_fallback.recurrent_state, atol=1e-5
)
)


# Allow running this test module directly (e.g. `python test_qwen3_5_attention.py`).
if __name__ == "__main__":
    unittest.main()
1 change: 1 addition & 0 deletions extension/llm/custom_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ endif()

set(_common_compile_options
$<$<CXX_COMPILER_ID:MSVC>:/wd4996>
$<$<CXX_COMPILER_ID:MSVC>:/Zc:__cplusplus>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What codepath are you doing down that isnt triggering properly without this? Typically the c10 pattern is to just have explicit msvc conditions and not rely on the c++ version on windows iirc. I could be wrong on that though.

$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wno-deprecated-declarations -fPIC>
)
if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64|aarch64")
Expand Down
Loading
Loading