From 5cb453423e2cf55246d00faef95660aabab9f7bb Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 18 Mar 2026 10:02:44 -0700 Subject: [PATCH] benchmark gdn Differential Revision: D97132275 --- examples/models/llama/attention.py | 169 ++++++- examples/models/llama/model_args.py | 1 + examples/models/llama/tests/BUCK | 18 + .../test_gated_deltanet_coreml_export.py | 453 ++++++++++++++++++ .../llama/tests/test_qwen3_5_attention.py | 90 ++++ 5 files changed, 730 insertions(+), 1 deletion(-) create mode 100644 examples/models/llama/tests/test_gated_deltanet_coreml_export.py diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 7f2b2d4e337..dde76f0cfd0 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -564,6 +564,7 @@ def __init__( self.head_v_dim = args.linear_value_head_dim self.key_dim = self.head_k_dim * self.num_k_heads self.value_dim = self.head_v_dim * self.num_v_heads + self.chunk_size = args.deltanet_chunk_size self.conv_kernel_size = args.linear_conv_kernel_dim assert ( @@ -702,6 +703,165 @@ def _recurrent_gated_delta_rule( return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + def _chunked_recurrent_gated_delta_rule( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + ) -> torch.Tensor: + initial_dtype = query.dtype + query = _l2norm(query, dim=-1, eps=1e-6) + key = _l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + chunk_size = self.chunk_size + assert chunk_size is not None + + # Pad T to next multiple of chunk_size. + remainder = sequence_length % chunk_size + if remainder != 0: + pad_len = chunk_size - remainder + query = F.pad(query, (0, 0, 0, pad_len)) + key = F.pad(key, (0, 0, 0, pad_len)) + value = F.pad(value, (0, 0, 0, pad_len)) + g = F.pad(g, (0, 0, 0, pad_len)) + beta = F.pad(beta, (0, 0, 0, pad_len)) + padded_len = sequence_length + pad_len + else: + padded_len = sequence_length + + num_chunks = padded_len // chunk_size + + # Pre-scale by beta. + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + # Reshape into chunks: (B, H, N, C, dim). + q_c = query.reshape( + batch_size, num_heads, num_chunks, chunk_size, k_head_dim + ) + k_c = key.reshape( + batch_size, num_heads, num_chunks, chunk_size, k_head_dim + ) + v_c = v_beta.reshape( + batch_size, num_heads, num_chunks, chunk_size, v_head_dim + ) + k_beta_c = k_beta.reshape( + batch_size, num_heads, num_chunks, chunk_size, k_head_dim + ) + g_c = g.reshape(batch_size, num_heads, num_chunks, chunk_size) + + # Cumulative log-gate within each chunk. + decay = g_c.cumsum(dim=-1) + decay_exp = decay.exp().unsqueeze(-1) + + # Relative gate-decay matrix (lower triangular, C×C per chunk). + L_mask = ( + (decay.unsqueeze(-1) - decay.unsqueeze(-2)).tril().exp().tril() + ) + + # --- WY representation (parallel across all chunks) --- + # M = (β·k @ k^T · L_mask), strictly lower triangular C×C. + kkt = k_beta_c @ k_c.transpose(-1, -2) + M = (kkt * L_mask).tril(diagonal=-1) + + # Solve (I + M) X = [v_c | k_beta_c·exp(decay)] for corrected values/keys. + I_plus_M = M + torch.eye( + chunk_size, device=key.device, dtype=key.dtype + ) + rhs = torch.cat([v_c, k_beta_c * decay_exp], dim=-1) + solved = torch.linalg.solve_triangular( + I_plus_M, rhs, upper=False, unitriangular=True + ) + v_wy = solved[..., :v_head_dim] + k_cumdecay = solved[..., v_head_dim:] + + # Intra-chunk causal attention: Q @ K^T · L_mask, upper-triangular masked. + causal_mask = torch.triu( + torch.ones( + chunk_size, chunk_size, dtype=torch.bool, device=key.device + ), + diagonal=1, + ) + intra_attn = ( + q_c @ k_c.transpose(-1, -2) * L_mask + ).masked_fill(causal_mask, 0) + + # --- Inter-chunk state propagation via torch.scan --- + # Permute scan dimension first: (N, B, H, ...). + intra_attn_s = intra_attn.permute(2, 0, 1, 3, 4) + v_wy_s = v_wy.permute(2, 0, 1, 3, 4) + k_cumdecay_s = k_cumdecay.permute(2, 0, 1, 3, 4) + q_c_s = q_c.permute(2, 0, 1, 3, 4) + k_c_s = k_c.permute(2, 0, 1, 3, 4) + decay_s = decay.permute(2, 0, 1, 3) + + init_state = self.recurrent_state[:batch_size].to(value.dtype) + + def inter_chunk_fn(carry, x): + attn_i = x[0] + v_wy_i = x[1] + k_cumdecay_i = x[2] + q_ci = x[3] + k_ci = x[4] + decay_i = x[5] + + # Subtract state contribution from WY-corrected values. + v_prime = k_cumdecay_i @ carry + v_new = v_wy_i - v_prime + + # Output = inter (query × state) + intra (causal attn × corrected v). + o_inter = (q_ci * decay_i.unsqueeze(-1).exp()) @ carry + o_i = o_inter + attn_i @ v_new + + # State update: gate-decay + rank-C outer product. + decay_last = decay_i[:, :, chunk_size - 1] + k_decayed = k_ci * ( + decay_last.unsqueeze(-1).unsqueeze(-1) + - decay_i.unsqueeze(-1) + ).exp() + next_carry = ( + carry * decay_last.unsqueeze(-1).unsqueeze(-1).exp() + + k_decayed.transpose(-1, -2) @ v_new + ) + return next_carry, o_i + + xs = ( + intra_attn_s, + v_wy_s, + k_cumdecay_s, + q_c_s, + k_c_s, + decay_s, + ) + final_state, all_outputs = torch.scan( + inter_chunk_fn, init=init_state, xs=xs + ) + + # (N, B, H, C, V) -> (B, H, T_padded, V) + all_outputs = all_outputs.permute(1, 2, 0, 3, 4).reshape( + batch_size, num_heads, padded_len, v_head_dim + ) + + core_attn_out = all_outputs[:, :, :sequence_length, :] + + with torch.no_grad(): + self.recurrent_state[:batch_size].copy_( + final_state.to(self.recurrent_state.dtype) + ) + + return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + def forward( self, x: torch.Tensor, @@ -740,7 +900,14 @@ def forward( beta = b.sigmoid() g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) - core_attn_out = self._recurrent_gated_delta_rule(query, key, value, g, beta) + if self.chunk_size is not None: + core_attn_out = self._chunked_recurrent_gated_delta_rule( + query, key, value, g, beta + ) + else: + core_attn_out = self._recurrent_gated_delta_rule( + query, key, value, g, beta + ) core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) z = z.reshape(-1, self.head_v_dim) diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index a3380417316..d0ae5cdb60f 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -50,6 +50,7 @@ class ModelArgs: linear_value_head_dim: Optional[int] = None linear_num_key_heads: Optional[int] = None linear_num_value_heads: Optional[int] = None + deltanet_chunk_size: Optional[int] = None # Qwen3.5 RMSNorm uses (1 + weight) scaling. rms_norm_add_unit_offset: bool = False multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 diff --git a/examples/models/llama/tests/BUCK b/examples/models/llama/tests/BUCK index 431c3c92814..cd59db5386f 100644 --- a/examples/models/llama/tests/BUCK +++ b/examples/models/llama/tests/BUCK @@ -99,6 +99,24 @@ fbcode_target(_kind = python_unittest, ], ) +fbcode_target(_kind = python_unittest, + name = "test_gated_deltanet_coreml_export", + srcs = [ + "test_gated_deltanet_coreml_export.py", + ], + preload_deps = [ + "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:export_library", + "//executorch/examples/models/llama:llama_transformer", + "//executorch/examples/models/llama:sdpa", + "//executorch/examples/models/llama:custom_kv_cache", + "//executorch/extension/pybindings:portable_lib", + ], +) + fbcode_target(_kind = python_unittest, name = "test_export_llama_lib", srcs = [ diff --git a/examples/models/llama/tests/test_gated_deltanet_coreml_export.py b/examples/models/llama/tests/test_gated_deltanet_coreml_export.py new file mode 100644 index 00000000000..ee718e119e1 --- /dev/null +++ b/examples/models/llama/tests/test_gated_deltanet_coreml_export.py @@ -0,0 +1,453 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time +import unittest + +import torch +from torch import nn + +from executorch.examples.models.llama.attention import ATTENTION_REGISTRY +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope + +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, +) +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower +from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, +) + + +SEQUENCE_LENGTHS = [1, 2, 4, 8, 16, 32, 64, 256, 512, 604] + +WARMUP_ITERS = 5 +BENCH_ITERS = 20 + + +class _AttentionExportWrapper(nn.Module): + """Wraps AttentionGatedDeltaNet for torch.export with explicit arguments. + + Supplies dummy frequency tensors (unused by DeltaNet) and routes input_pos. + """ + + def __init__(self, attn: nn.Module) -> None: + super().__init__() + self.attn = attn + self.register_buffer("_dummy_freq", torch.zeros(1, 1)) + + def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor: + out, _ = self.attn( + x, self._dummy_freq, self._dummy_freq, input_pos=input_pos + ) + return out + + +class TestGatedDeltaNetCoreMLExport(unittest.TestCase): + """Verify AttentionGatedDeltaNet exports for CoreML at various sequence lengths. + + Hyperparameters match the Qwen3.5 0.8B linear-attention layer config: + dim=1024, linear_key_head_dim=128, linear_value_head_dim=128, + linear_num_key_heads=16, linear_num_value_heads=16, conv_kernel=4. + """ + + DIM = 1024 + + @staticmethod + def _make_args(seq_len: int) -> ModelArgs: + return ModelArgs( + dim=1024, + n_layers=1, + n_heads=8, + n_kv_heads=2, + head_dim=256, + hidden_dim=3584, + norm_eps=1e-6, + vocab_size=248320, + max_seq_len=max(seq_len, 64), + max_context_len=max(seq_len, 64), + max_batch_size=1, + use_kv_cache=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=16, + ) + + def _build_and_export(self, seq_len: int): + torch.manual_seed(0) + args = self._make_args(seq_len) + rope = Rope(args) + attn = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn.eval() + + wrapper = _AttentionExportWrapper(attn) + example_inputs = ( + torch.randn(1, seq_len, args.dim), + torch.tensor([0], dtype=torch.long), + ) + + exported = torch.export.export(wrapper, example_inputs, strict=False) + out = exported.module()(*example_inputs) + self.assertEqual(out.shape, (1, seq_len, args.dim)) + return exported + + def _lower_to_coreml(self, exported): + partitioner = get_coreml_partitioner( + coreml_ios=18, + embedding_quantize=None, + pt2e_quantize=None, + coreml_quantize=None, + coreml_compute_units="cpu_and_ne", + ) + edge = to_edge_transform_and_lower( + exported, + partitioner=[partitioner], + ) + return edge + + def _export_to_buffer(self, seq_len: int) -> bytes: + exported = self._build_and_export(seq_len) + edge = self._lower_to_coreml(exported) + et_program = edge.to_executorch() + return et_program.buffer + + def test_torch_export(self): + """Verify torch.export succeeds at all target sequence lengths.""" + for seq_len in SEQUENCE_LENGTHS: + with self.subTest(seq_len=seq_len): + self._build_and_export(seq_len) + + def test_coreml_lower(self): + """Verify CoreML lowering succeeds at all target sequence lengths.""" + for seq_len in SEQUENCE_LENGTHS: + with self.subTest(seq_len=seq_len): + exported = self._build_and_export(seq_len) + edge = self._lower_to_coreml(exported) + self.assertIsNotNone(edge) + + def test_coreml_benchmark(self): + """Export, load via ET pybindings, and benchmark forward at each seq length.""" + results = [] + + for seq_len in SEQUENCE_LENGTHS: + with self.subTest(seq_len=seq_len): + pte_buffer = self._export_to_buffer(seq_len) + et_module = _load_for_executorch_from_buffer(pte_buffer) + + inputs = [ + torch.randn(1, seq_len, self.DIM), + torch.tensor([0], dtype=torch.long), + ] + + # Warmup + for _ in range(WARMUP_ITERS): + et_module.forward(inputs) + + # Benchmark + start = time.perf_counter() + for _ in range(BENCH_ITERS): + et_module.forward(inputs) + elapsed = time.perf_counter() - start + + avg_ms = (elapsed / BENCH_ITERS) * 1000.0 + results.append((seq_len, avg_ms)) + + # Display results table + print("\n") + print("=" * 50) + print("GatedDeltaNet CoreML Forward Benchmark") + print(f" Qwen3.5 0.8B hyperparams, batch_size=1") + print(f" warmup={WARMUP_ITERS}, iters={BENCH_ITERS}") + print("=" * 50) + print(f"{'seq_len':>10} {'avg (ms)':>10} {'throughput (tok/s)':>18}") + print("-" * 50) + for seq_len, avg_ms in results: + toks_per_sec = seq_len / (avg_ms / 1000.0) + print(f"{seq_len:>10} {avg_ms:>10.3f} {toks_per_sec:>18.1f}") + print("=" * 50) + + +XNNPACK_SEQUENCE_LENGTHS = [1, 2, 4, 8, 16, 32, 64] +CONTEXT_LEN = 1024 +CONTEXT_FILL_REPS = 3 + + +class _MHAExportWrapper(nn.Module): + """Wraps AttentionMHA for torch.export with precomputed RoPE freq tables.""" + + def __init__(self, attn: nn.Module, rope: Rope) -> None: + super().__init__() + self.attn = attn + self.register_buffer("_freqs_cos", rope.freqs_cos.clone()) + self.register_buffer("_freqs_sin", rope.freqs_sin.clone()) + + def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor: + freqs_cos = self._freqs_cos[input_pos] + freqs_sin = self._freqs_sin[input_pos] + out, _ = self.attn(x, freqs_cos, freqs_sin, input_pos=input_pos) + return out + + +def _export_and_lower_xnnpack(wrapper, example_inputs): + exported = torch.export.export(wrapper, example_inputs, strict=False) + edge = to_edge_transform_and_lower( + exported, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=True, + _skip_dim_order=True, + ), + ) + return edge + + +def _benchmark_single(edge, inputs): + pte_buffer = edge.to_executorch().buffer + et_module = _load_for_executorch_from_buffer(pte_buffer) + for _ in range(WARMUP_ITERS): + et_module.forward(inputs) + start = time.perf_counter() + for _ in range(BENCH_ITERS): + et_module.forward(inputs) + elapsed = time.perf_counter() - start + return (elapsed / BENCH_ITERS) * 1000.0 + + +def _benchmark_context_fill(edge, seq_len, dim, context_len, pos_fn): + """Run actual multi-step context fill, advancing input_pos each step. + + pos_fn(step, seq_len) returns the input_pos tensor for that step. + """ + pte_buffer = edge.to_executorch().buffer + et_module = _load_for_executorch_from_buffer(pte_buffer) + num_steps = context_len // seq_len + + # Warmup: one full context fill + for step in range(num_steps): + inputs = [torch.randn(1, seq_len, dim), pos_fn(step, seq_len)] + et_module.forward(inputs) + + # Benchmark + start = time.perf_counter() + for _ in range(CONTEXT_FILL_REPS): + for step in range(num_steps): + inputs = [torch.randn(1, seq_len, dim), pos_fn(step, seq_len)] + et_module.forward(inputs) + elapsed = time.perf_counter() - start + return (elapsed / CONTEXT_FILL_REPS) * 1000.0 + + +def _gdn_pos(step, seq_len): + return torch.tensor([step * seq_len], dtype=torch.long) + + +def _mha_pos(step, seq_len): + start = step * seq_len + return torch.arange(start, start + seq_len, dtype=torch.long) + + +def _make_gdn_args(seq_len: int, max_ctx: int = 0, chunk_size=None) -> ModelArgs: + max_ctx = max(max_ctx, seq_len, 64) + return ModelArgs( + dim=1024, + n_layers=1, + n_heads=8, + n_kv_heads=2, + head_dim=256, + hidden_dim=3584, + norm_eps=1e-6, + vocab_size=248320, + max_seq_len=max_ctx, + max_context_len=max_ctx, + max_batch_size=1, + use_kv_cache=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=16, + deltanet_chunk_size=chunk_size, + ) + + +def _make_mha_args(seq_len: int, max_ctx: int = 0) -> ModelArgs: + max_ctx = max(max_ctx, seq_len, 64) + return ModelArgs( + dim=1024, + n_layers=1, + n_heads=16, + n_kv_heads=4, + head_dim=64, + hidden_dim=3584, + norm_eps=1e-6, + vocab_size=248320, + max_seq_len=max_ctx, + max_context_len=max_ctx, + max_batch_size=1, + use_kv_cache=True, + ) + + +def _build_gdn(seq_len, max_ctx=0, chunk_size=None): + torch.manual_seed(0) + args = _make_gdn_args(seq_len, max_ctx, chunk_size) + rope = Rope(args) + attn = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn.eval() + wrapper = _AttentionExportWrapper(attn) + example_inputs = ( + torch.randn(1, seq_len, args.dim), + torch.tensor([0], dtype=torch.long), + ) + return wrapper, example_inputs + + +def _build_mha(seq_len, max_ctx=0): + from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + replace_kv_cache_with_custom_kv_cache, + ) + from executorch.examples.models.llama.source_transformation.sdpa import ( + replace_sdpa_with_custom_op, + ) + + torch.manual_seed(0) + args = _make_mha_args(seq_len, max_ctx) + rope = Rope(args) + attn = ATTENTION_REGISTRY["mha"](args, 0, rope) + attn.eval() + for p in attn.parameters(): + p.requires_grad_(False) + for b in attn.buffers(): + b.requires_grad_(False) + wrapper = _MHAExportWrapper(attn, rope) + example_inputs = ( + torch.randn(1, seq_len, args.dim), + torch.arange(0, seq_len, dtype=torch.long), + ) + return wrapper, example_inputs + + +def _print_table(title, subtitle, results, col2="avg (ms)"): + print("\n") + print("=" * 55) + print(title) + print(f" {subtitle}") + print("=" * 55) + print(f"{'seq_len':>10} {col2:>12}") + print("-" * 30) + for row in results: + print(f"{row[0]:>10} {row[1]:>12.3f}") + print("=" * 55) + + +class TestGatedDeltaNetXnnpackBenchmark(unittest.TestCase): + """Export AttentionGatedDeltaNet to XNNPACK and benchmark at various seq lengths.""" + + DIM = 1024 + + def test_xnnpack_sequential_benchmark(self): + results = [] + for seq_len in XNNPACK_SEQUENCE_LENGTHS: + wrapper, inputs = _build_gdn(seq_len) + edge = _export_and_lower_xnnpack(wrapper, inputs) + avg_ms = _benchmark_single(edge, list(inputs)) + results.append((seq_len, avg_ms)) + _print_table( + "GatedDeltaNet XNNPACK — fp32", + "Qwen3.5 0.8B hyperparams, batch_size=1", + results, + ) + + def test_xnnpack_context_fill(self): + results = [] + for seq_len in XNNPACK_SEQUENCE_LENGTHS: + wrapper, inputs = _build_gdn(seq_len, max_ctx=CONTEXT_LEN) + edge = _export_and_lower_xnnpack(wrapper, inputs) + total_ms = _benchmark_context_fill( + edge, seq_len, self.DIM, CONTEXT_LEN, _gdn_pos + ) + results.append((seq_len, total_ms)) + _print_table( + f"GatedDeltaNet XNNPACK — fill {CONTEXT_LEN} tokens (fp32)", + f"Qwen3.5 0.8B hyperparams, batch_size=1, reps={CONTEXT_FILL_REPS}", + results, + col2="total (ms)", + ) + +class TestMHAXnnpackBenchmark(unittest.TestCase): + """Export Qwen3.5 0.8B MHA block to XNNPACK and benchmark.""" + + DIM = 1024 + + def test_xnnpack_mha_benchmark(self): + results = [] + for seq_len in XNNPACK_SEQUENCE_LENGTHS: + wrapper, inputs = _build_mha(seq_len) + edge = _export_and_lower_xnnpack(wrapper, inputs) + avg_ms = _benchmark_single(edge, list(inputs)) + results.append((seq_len, avg_ms)) + _print_table( + "MHA XNNPACK — fp32 (custom SDPA + KV cache)", + "Qwen3.5 0.8B hyperparams, batch_size=1", + results, + ) + + def test_xnnpack_mha_context_fill(self): + results = [] + for seq_len in XNNPACK_SEQUENCE_LENGTHS: + wrapper, inputs = _build_mha(seq_len, max_ctx=CONTEXT_LEN) + edge = _export_and_lower_xnnpack(wrapper, inputs) + total_ms = _benchmark_context_fill( + edge, seq_len, self.DIM, CONTEXT_LEN, _mha_pos + ) + results.append((seq_len, total_ms)) + _print_table( + f"MHA XNNPACK — fill {CONTEXT_LEN} tokens (fp32, custom ops)", + f"Qwen3.5 0.8B hyperparams, batch_size=1, reps={CONTEXT_FILL_REPS}", + results, + col2="total (ms)", + ) + +class TestXnnpackGraphs(unittest.TestCase): + """Print before/after XNNPACK lowering graphs for GDN and MHA.""" + + def _print_before_after(self, label, wrapper, example_inputs): + exported = torch.export.export(wrapper, example_inputs, strict=False) + print("\n") + print("=" * 70) + print(f"{label} — BEFORE XNNPACK lowering") + print("=" * 70) + print(exported.graph_module.graph) + + edge = to_edge_transform_and_lower( + exported, + partitioner=[XnnpackPartitioner()], + compile_config=EdgeCompileConfig( + _check_ir_validity=True, + _skip_dim_order=True, + ), + ) + print("\n") + print("=" * 70) + print(f"{label} — AFTER XNNPACK lowering") + print("=" * 70) + print(edge.exported_program().graph_module.graph) + + def test_gdn_graph(self): + wrapper, inputs = _build_gdn(1) + self._print_before_after("GDN (seq_len=1)", wrapper, inputs) + + def test_mha_graph(self): + wrapper, inputs = _build_mha(1) + self._print_before_after("MHA (seq_len=1)", wrapper, inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index 5a9f67d57cf..aa7995d6de2 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -96,6 +96,96 @@ def test_gated_deltanet_resets_state_on_new_sequence(self): state_after_reset = attn.recurrent_state.clone() self.assertTrue(torch.allclose(state_after_first, state_after_reset, atol=1e-5)) + def _make_deltanet_pair(self, chunk_size, seq_len): + """Build two identical gated_deltanet layers: one chunked, one sequential.""" + torch.manual_seed(42) + base_args = { + "use_kv_cache": True, + "use_q_gate": True, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 4, + "linear_value_head_dim": 4, + "linear_num_key_heads": 2, + "linear_num_value_heads": 4, + } + args_seq = self._make_args(**base_args, deltanet_chunk_size=None) + args_chunk = self._make_args(**base_args, deltanet_chunk_size=chunk_size) + rope = Rope(args_seq) + + attn_seq = ATTENTION_REGISTRY["gated_deltanet"](args_seq, 0, rope) + attn_chunk = ATTENTION_REGISTRY["gated_deltanet"](args_chunk, 0, rope) + attn_chunk.load_state_dict(attn_seq.state_dict()) + attn_seq.eval() + attn_chunk.eval() + + x = torch.randn(1, seq_len, args_seq.dim) + dummy_freq = torch.zeros(1, 1) + return attn_seq, attn_chunk, x, dummy_freq + + def test_chunked_deltanet_matches_sequential(self): + attn_seq, attn_chunk, x, dummy_freq = self._make_deltanet_pair( + chunk_size=4, seq_len=12 + ) + out_seq, _ = attn_seq(x, dummy_freq, dummy_freq) + out_chunk, _ = attn_chunk(x, dummy_freq, dummy_freq) + + self.assertTrue( + torch.allclose(out_seq, out_chunk, atol=1e-5), + f"Max diff: {(out_seq - out_chunk).abs().max().item()}", + ) + self.assertTrue( + torch.allclose( + attn_seq.recurrent_state, attn_chunk.recurrent_state, atol=1e-5 + ), + "Recurrent states diverged.", + ) + + def test_chunked_deltanet_non_divisible_seq_len(self): + attn_seq, attn_chunk, x, dummy_freq = self._make_deltanet_pair( + chunk_size=4, seq_len=7 + ) + out_seq, _ = attn_seq(x, dummy_freq, dummy_freq) + out_chunk, _ = attn_chunk(x, dummy_freq, dummy_freq) + + self.assertTrue( + torch.allclose(out_seq, out_chunk, atol=1e-5), + f"Max diff: {(out_seq - out_chunk).abs().max().item()}", + ) + self.assertTrue( + torch.allclose( + attn_seq.recurrent_state, attn_chunk.recurrent_state, atol=1e-5 + ), + "Recurrent states diverged for non-divisible seq_len.", + ) + + def test_chunked_deltanet_export(self): + torch.manual_seed(42) + args = self._make_args( + use_kv_cache=True, + use_q_gate=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + deltanet_chunk_size=4, + ) + rope = Rope(args) + attn = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn.eval() + + x = torch.randn(1, 8, args.dim) + dummy_freq = torch.zeros(1, 1) + try: + exported = torch.export.export( + attn, + (x, dummy_freq, dummy_freq), + strict=False, + ) + self.assertIsNotNone(exported) + except Exception as e: + self.fail(f"torch.export.export failed with chunked deltanet: {e}") + def test_gated_deltanet_no_input_pos_does_not_leak_state(self): torch.manual_seed(0) args = self._make_args(