Add recurrent gated delta rule custom op for Qwen3.5 attention #18088
Phineas1500 wants to merge 3 commits into pytorch:main
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18088
Note: Links to docs will display an error until the docs builds have been completed.
As of commit e5540ad with merge base 8c0a60b: ❌ 1 Awaiting Approval, 29 New Failures, 1 Cancelled Job, 3 Pending.
Pull request overview
Adds a fused llama::recurrent_gated_delta_rule custom op and integrates it into Qwen3.5 GatedDeltaNet attention to avoid the Python per-token recurrence loop when the op is available, along with tighter custom-op library discovery/loading and new test coverage.
Changes:
- Implemented and registered `llama::recurrent_gated_delta_rule` (runtime kernel + ATen/AOT registrations) and updated attention to use it with a fallback path.
- Refined `custom_ops_aot_lib` discovery/loading (package-local by default, optional `EXECUTORCH_CUSTOM_OPS_AOT_LIB` override).
- Added tests for recurrent-state correctness/parity, chunked prefill behavior, and export graph op selection.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| extension/llm/custom_ops/test_update_cache.py | Adds unit tests for recurrent gated delta rule correctness, .out behavior, and chunking parity. |
| extension/llm/custom_ops/op_tile_crop_aot.cpp | Replaces WRAP_TO_ATEN usage with explicit ET↔ATen conversion helpers for .out. |
| extension/llm/custom_ops/op_sdpa_aot.cpp | Adds ATen bindings for recurrent op; refactors multiple .out wrappers to explicit conversions. |
| extension/llm/custom_ops/op_sdpa.h | Declares the new recurrent_gated_delta_rule_out kernel signature. |
| extension/llm/custom_ops/op_sdpa.cpp | Implements recurrent kernel logic and registers the ExecuTorch kernel. |
| extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp | Refactors .out binding to explicit ET↔ATen conversion helpers. |
| extension/llm/custom_ops/custom_ops.py | Tightens custom op library discovery/loading; adds meta impl for recurrent op. |
| extension/llm/custom_ops/CMakeLists.txt | Adds MSVC /Zc:__cplusplus compile option. |
| examples/models/llama/tests/test_qwen3_5_attention.py | Adds chunked prefill parity + fused-op vs fallback parity tests. |
| examples/models/llama/tests/test_export_llama_lib.py | Adds tiny Qwen3.5 export test asserting recurrent op selection in graph. |
| examples/models/llama/attention.py | Adds lazy lookup/loading for fused recurrent op and uses it when available. |
```python
def _get_custom_ops_library_override() -> Path | None:
    override = os.environ.get("EXECUTORCH_CUSTOM_OPS_AOT_LIB")
    if override is None:
        return None

    lib_path = Path(override).expanduser().resolve()
    assert lib_path.is_file(), (
        "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing "
        f"custom_ops_aot_lib, but got {lib_path}"
    )
    return lib_path


def _find_custom_ops_library() -> Path:
    override = _get_custom_ops_library_override()
    if override is not None:
        return override

    package_path = Path(__file__).parent.resolve()
    logging.info(f"Looking for libcustom_ops_aot_lib.so in {package_path}")
    candidates = []
    patterns = (
        "**/custom_ops_aot_lib.dll",
        "**/custom_ops_aot_lib.so",
        "**/custom_ops_aot_lib.dylib",
    )

    for pattern in patterns:
        candidates.extend(package_path.glob(pattern))

    libs = sorted({path.resolve() for path in candidates if path.is_file()})
    assert libs, f"Could not find custom_ops_aot_lib under {package_path}"
    return max(libs, key=lambda path: path.stat().st_mtime)
```
Avoid using assert for runtime validation of EXECUTORCH_CUSTOM_OPS_AOT_LIB / library discovery. Assertions can be stripped with python -O, turning these into silent misconfigurations; raise a ValueError/FileNotFoundError with the same message instead.
```python
try:
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
except Exception:
    return None
```
_get_recurrent_gated_delta_rule_op() swallows all exceptions when importing executorch.extension.llm.custom_ops.custom_ops. Catching broad Exception can hide real load/link errors and make debugging difficult; consider narrowing to ImportError/OSError (or logging the exception at debug level) so unexpected failures surface.
```cpp
std::vector<float> kv_mem(v_head_dim);
std::vector<float> delta(v_head_dim);
```
recurrent_gated_delta_rule_out allocates std::vector buffers (kv_mem, delta) inside the per-(batch, head) loop. For long sequences / many heads this adds repeated heap allocations and can dominate runtime; allocate these buffers once per call (or reuse a scratch buffer) and resize as needed, or use stack/arena allocation when sizes are small.
```cpp
namespace {
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}
```
The to_et_arg / copy_et_result_to_out helpers are duplicated here and in other custom-op AOT wrappers (tile_crop / sdpa / fast_hadamard_transform). Consider factoring them into a shared utility header to reduce copy-paste and keep conversion semantics consistent across ops.
```python
if os.name == "nt":
    os.add_dll_directory(str(lib_path.parent))
    torch_lib_dir = Path(torch.__file__).resolve().parent / "lib"
    if torch_lib_dir.is_dir():
        os.add_dll_directory(str(torch_lib_dir))
```
On Windows, os.add_dll_directory() returns a handle that must be kept alive; otherwise the directory is removed immediately (CPython refcounting), which can cause torch.ops.load_library() to fail to resolve dependent DLLs. Store the returned handles (e.g., in a module-level list) at least through the load (and ideally for process lifetime).
@lucylq this is a PR for optimizing the recurrence in Qwen 3.5, which we discussed here: #17801 (comment). I'm next going to make a PR for quantization. I'll let you know once that's up.
```cpp
m.impl("tile_crop", torch::executor::native::tile_crop_aten);
m.impl(
    "tile_crop.out",
    WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2));
```
what problem did you run into with WRAP_TO_ATEN?
Can you tell me a bit more about the serialization issue you ran into as well as the MSVC one?
```python
    )
    return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

core_attn_out = torch.zeros(
```
Can you put this logic in a function called something like `naive_gated_delta_rule_op`, and then just have the if statement switch between them, to tidy this function up a bit?
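The suggested restructuring might look like the following skeleton (names, signatures, and the placeholder update rule are illustrative, not the PR's actual math):

```python
def naive_gated_delta_rule_op(inputs, state):
    """Stand-in for the existing per-token Python recurrence loop."""
    out = []
    for x in inputs:                 # one step per token
        state = 0.9 * state + x      # placeholder update, not the real rule
        out.append(state)
    return out, state


def gated_delta_rule(inputs, state, fused_op=None):
    """Single dispatch point: fused custom op when available, else the loop."""
    if fused_op is not None:
        return fused_op(inputs, state)
    return naive_gated_delta_rule_op(inputs, state)
```

The point is structural: the branch chooses between two implementations with identical signatures, so the calling function shrinks to one dispatch line.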
```cmake
set(_common_compile_options
    $<$<CXX_COMPILER_ID:MSVC>:/wd4996>
    $<$<CXX_COMPILER_ID:MSVC>:/Zc:__cplusplus>
```
What codepath are you going down that isn't triggering properly without this? Typically the c10 pattern is to just have explicit MSVC conditions and not rely on the C++ version on Windows, IIRC. I could be wrong on that though.
Summary
This PR adds a fused `llama::recurrent_gated_delta_rule` custom op and wires Qwen3.5 GatedDeltaNet attention to use it instead of the Python per-token recurrence loop when the op is available. It also tightens local custom-op loading so we no longer implicitly scan repo-local `cmake-out*` directories, and adds coverage for recurrent-state correctness, chunked prefill behavior, and export graph selection.

What changed

- `llama::recurrent_gated_delta_rule` runtime and AOT registrations
- `custom_ops_aot_lib` discovery: `EXECUTORCH_CUSTOM_OPS_AOT_LIB` override, no more implicit `cmake-out*` scanning
- Tests covering `.out` variant behavior and `llama.recurrent_gated_delta_rule`

Validation
Linux CPU-only (aarch64)
Built `custom_ops_aot_lib` successfully and loaded it via `EXECUTORCH_CUSTOM_OPS_AOT_LIB`. Passed:

- `pytest extension/llm/custom_ops/test_update_cache.py::RecurrentGatedDeltaRuleTest -q`: 3 passed
- `pytest examples/models/llama/tests/test_qwen3_5_attention.py -q`: 7 passed
- `pytest examples/models/llama/tests/test_export_llama_lib.py::ExportLlamaLibTest::test_tiny_qwen35_export_uses_recurrent_gated_delta_rule -q`: 1 passed

Real-model CPU validation
On a real `Qwen3.5-0.8B` CPU run, fused recurrence matched the fallback path on next-token selection with very small logit drift, and improved eager prefill latency on the tested prompt. Observed on local CPU validation:

- logit drift on the order of `1e-5`
- eager prefill speedup of about `1.6x` on the tested prompt

Windows note
A local Windows-only FFHT/MSVC workaround was used during development to keep the local build usable, but that workaround is intentionally not included in this PR.
Non-goals / separate issues
I did not treat the local `program.fbs` serialization issue as part of this change. This branch does not modify `exir/_serialize/*` or `schema/program.fbs`, and serialization-focused checks passed on both this branch and clean `main` once the local environment was set up correctly.

A separate end-to-end tiny Qwen3.5 `.pte` export probe hit `RuntimeError: Missing out variants: {'aten::alias'}`. That appears to be a separate pre-existing export issue outside this change set.