
Add recurrent gated delta rule custom op for Qwen3.5 attention#18088

Open
Phineas1500 wants to merge 3 commits into pytorch:main from Phineas1500:feature/recurrent-gated-delta-rule-windows

Conversation

@Phineas1500
Contributor

Summary

This PR adds a fused llama::recurrent_gated_delta_rule custom op and wires Qwen3.5 GatedDeltaNet attention to use it instead of the Python per-token recurrence loop when the op is available.

It also tightens local custom-op loading so we no longer implicitly scan repo-local cmake-out* directories, and adds coverage for recurrent-state correctness, chunked prefill behavior, and export graph selection.
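For context, the per-token recurrence being fused is a gated delta rule update. The pure-Python sketch below illustrates one common formulation (gated state decay followed by a rank-1 delta update, then a query readout); it is an illustration based on the standard gated delta rule, not the exact kernel in this PR, and all names are hypothetical:

```python
def gated_delta_rule_reference(q, k, v, alpha, beta):
    """Per-token gated delta rule recurrence on plain Python lists.

    q, k: lists of [k_dim] vectors; v: list of [v_dim] vectors;
    alpha, beta: per-token scalar gates. Returns per-token outputs
    and the final recurrent state.
    """
    k_dim, v_dim = len(k[0]), len(v[0])
    S = [[0.0] * k_dim for _ in range(v_dim)]  # state: maps keys to values
    outputs = []
    for t in range(len(q)):
        # Gated decay of the state.
        for i in range(v_dim):
            for j in range(k_dim):
                S[i][j] *= alpha[t]
        # Prediction of v_t from the decayed state: S @ k_t.
        pred = [sum(S[i][j] * k[t][j] for j in range(k_dim)) for i in range(v_dim)]
        # Rank-1 delta update pushing the state toward v_t.
        for i in range(v_dim):
            for j in range(k_dim):
                S[i][j] += beta[t] * (v[t][i] - pred[i]) * k[t][j]
        # Read out with the query: S @ q_t.
        outputs.append(
            [sum(S[i][j] * q[t][j] for j in range(k_dim)) for i in range(v_dim)]
        )
    return outputs, S
```

Running this loop token by token in Python is exactly the overhead the fused op avoids; the fused kernel can compute the same recurrence in one call over the whole sequence.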

What changed

  • added llama::recurrent_gated_delta_rule runtime and AOT registrations
  • updated Qwen3.5 GatedDeltaNet attention to use the fused op with Python fallback preserved
  • tightened custom_ops_aot_lib discovery:
    • default to package-local discovery
    • allow explicit override via EXECUTORCH_CUSTOM_OPS_AOT_LIB
    • removed implicit repo-local cmake-out* scanning
  • added tests for:
    • recurrent op parity vs reference
    • .out variant behavior
    • chunked-state parity vs full-sequence execution
    • custom-op vs fallback attention parity
    • tiny Qwen3.5 export selecting llama.recurrent_gated_delta_rule

Validation

Linux CPU-only (aarch64)

Built custom_ops_aot_lib successfully and loaded it via EXECUTORCH_CUSTOM_OPS_AOT_LIB.

Passed:

  • pytest extension/llm/custom_ops/test_update_cache.py::RecurrentGatedDeltaRuleTest -q
    • 3 passed
  • pytest examples/models/llama/tests/test_qwen3_5_attention.py -q
    • 7 passed
  • pytest examples/models/llama/tests/test_export_llama_lib.py::ExportLlamaLibTest::test_tiny_qwen35_export_uses_recurrent_gated_delta_rule -q
    • 1 passed

Real-model CPU validation

On a real Qwen3.5-0.8B CPU run, fused recurrence matched the fallback path on next-token selection with very small logit drift, and improved eager prefill latency on the tested prompt.

Observed on local CPU validation:

  • same next token from fused path vs fallback
  • max logit diff on the order of 1e-5
  • eager prefill speedup about 1.6x on the tested prompt
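A parity check of this kind can be sketched as follows (pure-Python illustration; `compare_logits` is a hypothetical helper, not code from this PR, and the real validation compared actual model logits):

```python
def compare_logits(fused, fallback):
    """Return (same next token, max abs logit diff) for two logit vectors."""

    def argmax(xs):
        return max(range(len(xs)), key=xs.__getitem__)

    max_diff = max(abs(a - b) for a, b in zip(fused, fallback))
    return argmax(fused) == argmax(fallback), max_diff
```

The "same next token" condition is the one that matters for greedy decoding; the max-abs-diff bound (here on the order of 1e-5) quantifies numerical drift between the fused and fallback paths.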

Windows note

A local Windows-only FFHT/MSVC workaround was used during development to keep the local build usable, but that workaround is intentionally not included in this PR.

Non-goals / separate issues

I did not treat the local program.fbs serialization issue as part of this change.

This branch does not modify exir/_serialize/* or schema/program.fbs, and serialization-focused checks passed on both this branch and clean main once the local environment was set up correctly.

A separate end-to-end tiny Qwen3.5 .pte export probe hit:

  • RuntimeError: Missing out variants: {'aten::alias'}

That appears to be a separate pre-existing export issue outside this change set.

Copilot AI review requested due to automatic review settings March 11, 2026 04:36
@pytorch-bot

pytorch-bot bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18088

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Awaiting Approval, 29 New Failures, 1 Cancelled Job, 3 Pending

As of commit e5540ad with merge base 8c0a60b:

AWAITING APPROVAL - The following workflow needs approval before CI can run:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label on Mar 11, 2026. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

Adds a fused llama::recurrent_gated_delta_rule custom op and integrates it into Qwen3.5 GatedDeltaNet attention to avoid the Python per-token recurrence loop when the op is available, along with tighter custom-op library discovery/loading and new test coverage.

Changes:

  • Implemented and registered llama::recurrent_gated_delta_rule (runtime kernel + ATen/AOT registrations) and updated attention to use it with a fallback path.
  • Refined custom_ops_aot_lib discovery/loading (package-local by default, optional EXECUTORCH_CUSTOM_OPS_AOT_LIB override).
  • Added tests for recurrent-state correctness/parity, chunked prefill behavior, and export graph op selection.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.

Summary per file:

  • extension/llm/custom_ops/test_update_cache.py: Adds unit tests for recurrent gated delta rule correctness, .out behavior, and chunking parity.
  • extension/llm/custom_ops/op_tile_crop_aot.cpp: Replaces WRAP_TO_ATEN usage with explicit ET↔ATen conversion helpers for .out.
  • extension/llm/custom_ops/op_sdpa_aot.cpp: Adds ATen bindings for the recurrent op; refactors multiple .out wrappers to explicit conversions.
  • extension/llm/custom_ops/op_sdpa.h: Declares the new recurrent_gated_delta_rule_out kernel signature.
  • extension/llm/custom_ops/op_sdpa.cpp: Implements the recurrent kernel logic and registers the ExecuTorch kernel.
  • extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp: Refactors the .out binding to explicit ET↔ATen conversion helpers.
  • extension/llm/custom_ops/custom_ops.py: Tightens custom op library discovery/loading; adds a meta impl for the recurrent op.
  • extension/llm/custom_ops/CMakeLists.txt: Adds the MSVC /Zc:__cplusplus compile option.
  • examples/models/llama/tests/test_qwen3_5_attention.py: Adds chunked prefill parity and fused-op vs fallback parity tests.
  • examples/models/llama/tests/test_export_llama_lib.py: Adds a tiny Qwen3.5 export test asserting recurrent op selection in the graph.
  • examples/models/llama/attention.py: Adds lazy lookup/loading for the fused recurrent op and uses it when available.


Comment on lines +37 to +68
def _get_custom_ops_library_override() -> Path | None:
    override = os.environ.get("EXECUTORCH_CUSTOM_OPS_AOT_LIB")
    if override is None:
        return None

    lib_path = Path(override).expanduser().resolve()
    assert lib_path.is_file(), (
        "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing "
        f"custom_ops_aot_lib, but got {lib_path}"
    )
    return lib_path


def _find_custom_ops_library() -> Path:
    override = _get_custom_ops_library_override()
    if override is not None:
        return override

    package_path = Path(__file__).parent.resolve()
    logging.info(f"Looking for libcustom_ops_aot_lib.so in {package_path}")
    candidates = []
    patterns = (
        "**/custom_ops_aot_lib.dll",
        "**/custom_ops_aot_lib.so",
        "**/custom_ops_aot_lib.dylib",
    )

    for pattern in patterns:
        candidates.extend(package_path.glob(pattern))

    libs = sorted({path.resolve() for path in candidates if path.is_file()})
    assert libs, f"Could not find custom_ops_aot_lib under {package_path}"
    return max(libs, key=lambda path: path.stat().st_mtime)

Copilot AI Mar 11, 2026


Avoid using assert for runtime validation of EXECUTORCH_CUSTOM_OPS_AOT_LIB / library discovery. Assertions can be stripped with python -O, turning these into silent misconfigurations; raise a ValueError/FileNotFoundError with the same message instead.

Comment on lines +81 to +84
try:
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
except Exception:
    return None

Copilot AI Mar 11, 2026


_get_recurrent_gated_delta_rule_op() swallows all exceptions when importing executorch.extension.llm.custom_ops.custom_ops. Catching broad Exception can hide real load/link errors and make debugging difficult; consider narrowing to ImportError/OSError (or logging the exception at debug level) so unexpected failures surface.

Comment on lines +756 to +758
std::vector<float> kv_mem(v_head_dim);
std::vector<float> delta(v_head_dim);


Copilot AI Mar 11, 2026


recurrent_gated_delta_rule_out allocates std::vector buffers (kv_mem, delta) inside the per-(batch, head) loop. For long sequences / many heads this adds repeated heap allocations and can dominate runtime; allocate these buffers once per call (or reuse a scratch buffer) and resize as needed, or use stack/arena allocation when sizes are small.

Comment on lines +20 to +35
namespace {
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}

Copilot AI Mar 11, 2026


The to_et_arg / copy_et_result_to_out helpers are duplicated here and in other custom-op AOT wrappers (tile_crop / sdpa / fast_hadamard_transform). Consider factoring them into a shared utility header to reduce copy-paste and keep conversion semantics consistent across ops.

Comment on lines +82 to +86
if os.name == "nt":
    os.add_dll_directory(str(lib_path.parent))
    torch_lib_dir = Path(torch.__file__).resolve().parent / "lib"
    if torch_lib_dir.is_dir():
        os.add_dll_directory(str(torch_lib_dir))

Copilot AI Mar 11, 2026


On Windows, os.add_dll_directory() returns a handle that must be kept alive; otherwise the directory is removed immediately (CPython refcounting), which can cause torch.ops.load_library() to fail to resolve dependent DLLs. Store the returned handles (e.g., in a module-level list) at least through the load (and ideally for process lifetime).

@Phineas1500
Contributor Author

@lucylq this is a PR for optimizing the recurrence in Qwen 3.5, which we discussed here: #17801 (comment)

I'm next going to make a PR for quantization. I'll let you know once that's up.

m.impl("tile_crop", torch::executor::native::tile_crop_aten);
m.impl(
    "tile_crop.out",
    WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2));
Contributor


What problem did you run into with WRAP_TO_ATEN?

@JacobSzwejbka
Contributor

Can you tell me a bit more about the serialization issue you ran into as well as the MSVC one?

)
return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

core_attn_out = torch.zeros(
Contributor


Can you put this logic in a function called something like "naive_gated_delta_rule_op", and then just have the if statement switch between them, to tidy this function up a bit?


set(_common_compile_options
$<$<CXX_COMPILER_ID:MSVC>:/wd4996>
$<$<CXX_COMPILER_ID:MSVC>:/Zc:__cplusplus>
Contributor


What codepath are you going down that isn't triggering properly without this? Typically the c10 pattern is to just have explicit MSVC conditions and not rely on the C++ version on Windows, IIRC. I could be wrong on that, though.
