-
Notifications
You must be signed in to change notification settings - Fork 870
Add recurrent gated delta rule custom op for Qwen3.5 attention #18088
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,6 +48,8 @@ def forward( | |
|
|
||
|
|
||
| ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {} | ||
| _RECURRENT_GATED_DELTA_RULE_OP = None | ||
| _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False | ||
|
|
||
|
|
||
| def register_attention(name: str): | ||
|
|
@@ -60,6 +62,37 @@ def decorator(cls: Type[Attention]): | |
| return decorator | ||
|
|
||
|
|
||
| def _get_recurrent_gated_delta_rule_op(): | ||
| global _RECURRENT_GATED_DELTA_RULE_OP | ||
| global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP | ||
|
|
||
| if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP: | ||
| return _RECURRENT_GATED_DELTA_RULE_OP | ||
|
|
||
| _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True | ||
| try: | ||
| _RECURRENT_GATED_DELTA_RULE_OP = ( | ||
| torch.ops.llama.recurrent_gated_delta_rule.default | ||
| ) | ||
| return _RECURRENT_GATED_DELTA_RULE_OP | ||
| except (AttributeError, RuntimeError): | ||
| pass | ||
|
|
||
| try: | ||
| from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 | ||
| except Exception: | ||
| return None | ||
|
|
||
| try: | ||
| _RECURRENT_GATED_DELTA_RULE_OP = ( | ||
| torch.ops.llama.recurrent_gated_delta_rule.default | ||
| ) | ||
| except (AttributeError, RuntimeError): | ||
| _RECURRENT_GATED_DELTA_RULE_OP = None | ||
|
|
||
| return _RECURRENT_GATED_DELTA_RULE_OP | ||
|
|
||
|
|
||
| class KVCache(nn.Module): | ||
| def __init__( | ||
| self, | ||
|
|
@@ -668,6 +701,18 @@ def _recurrent_gated_delta_rule( | |
| scale = 1.0 / (query.shape[-1] ** 0.5) | ||
| query = query * scale | ||
|
|
||
| recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op() | ||
| if recurrent_gated_delta_rule_op is not None: | ||
| core_attn_out = recurrent_gated_delta_rule_op( | ||
| query, | ||
| key, | ||
| value, | ||
| g, | ||
| beta, | ||
| self.recurrent_state[:batch_size], | ||
| ) | ||
| return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) | ||
|
|
||
| core_attn_out = torch.zeros( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you put this logic in some function called like "naive_gated_delta_rule_op" and then just have the if statement switch between them to tidy this function up a bit. |
||
| batch_size, | ||
| num_heads, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ endif() | |
|
|
||
| set(_common_compile_options | ||
| $<$<CXX_COMPILER_ID:MSVC>:/wd4996> | ||
| $<$<CXX_COMPILER_ID:MSVC>:/Zc:__cplusplus> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What codepath are you doing down that isnt triggering properly without this? Typically the c10 pattern is to just have explicit msvc conditions and not rely on the c++ version on windows iirc. I could be wrong on that though. |
||
| $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wno-deprecated-declarations -fPIC> | ||
| ) | ||
| if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64|aarch64") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_get_recurrent_gated_delta_rule_op()swallows all exceptions when importingexecutorch.extension.llm.custom_ops.custom_ops. Catching broadExceptioncan hide real load/link errors and make debugging difficult; consider narrowing toImportError/OSError(or logging the exception at debug level) so unexpected failures surface.