diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 7f2b2d4e337..7a4cb5645af 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -48,6 +48,8 @@ def forward( ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {} +_RECURRENT_GATED_DELTA_RULE_OP = None +_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False def register_attention(name: str): @@ -60,6 +62,37 @@ def decorator(cls: Type[Attention]): return decorator +def _get_recurrent_gated_delta_rule_op(): + global _RECURRENT_GATED_DELTA_RULE_OP + global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP + + if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP: + return _RECURRENT_GATED_DELTA_RULE_OP + + _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True + try: + _RECURRENT_GATED_DELTA_RULE_OP = ( + torch.ops.llama.recurrent_gated_delta_rule.default + ) + return _RECURRENT_GATED_DELTA_RULE_OP + except (AttributeError, RuntimeError): + pass + + try: + from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 + except Exception: + return None + + try: + _RECURRENT_GATED_DELTA_RULE_OP = ( + torch.ops.llama.recurrent_gated_delta_rule.default + ) + except (AttributeError, RuntimeError): + _RECURRENT_GATED_DELTA_RULE_OP = None + + return _RECURRENT_GATED_DELTA_RULE_OP + + class KVCache(nn.Module): def __init__( self, @@ -668,6 +701,18 @@ def _recurrent_gated_delta_rule( scale = 1.0 / (query.shape[-1] ** 0.5) query = query * scale + recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op() + if recurrent_gated_delta_rule_op is not None: + core_attn_out = recurrent_gated_delta_rule_op( + query, + key, + value, + g, + beta, + self.recurrent_state[:batch_size], + ) + return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + core_attn_out = torch.zeros( batch_size, num_heads, diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index 130a55f658c..03714823a16 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -6,6 +6,10 @@ # LICENSE file in the root directory of this source tree. import unittest +import json +import tempfile + +from pathlib import Path from executorch.devtools.backend_debug import get_delegation_info @@ -25,6 +29,7 @@ from executorch.examples.models.llama.export_llama_lib import ( _export_llama, + _prepare_for_llama_export, build_args_parser, get_quantizer_and_quant_params, ) @@ -37,6 +42,39 @@ class ExportLlamaLibTest(unittest.TestCase): + def _make_tiny_qwen35_params(self) -> dict: + return { + "dim": 64, + "hidden_dim": 128, + "n_heads": 4, + "head_dim": 16, + "n_kv_heads": 2, + "n_layers": 4, + "norm_eps": 1e-6, + "rope_theta": 10000000.0, + "use_scaled_rope": False, + "vocab_size": 256, + "use_hf_rope": True, + "partial_rotary_factor": 0.25, + "attention_qkv_bias": False, + "use_qk_norm": True, + "qk_norm_before_rope": True, + "attention_type": "mha", + "use_q_gate": True, + "rms_norm_add_unit_offset": True, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 8, + "linear_value_head_dim": 8, + "linear_num_key_heads": 4, + "linear_num_value_heads": 4, + "layer_types": [ + "linear_attention", + "full_attention", + "linear_attention", + "full_attention", + ], + } + def test_has_expected_ops_and_op_counts(self): """ Checks the presence of unwanted expensive ops. @@ -66,6 +104,41 @@ def test_has_expected_ops_and_op_counts(self): for op, _op_info in delegation_info.delegation_by_operator.items(): self.assertTrue(op not in UNWANTED_OPS) + def test_tiny_qwen35_export_uses_recurrent_gated_delta_rule(self): + with tempfile.TemporaryDirectory() as temp_dir: + params_path = Path(temp_dir) / "tiny_qwen35.json" + params_path.write_text(json.dumps(self._make_tiny_qwen35_params())) + + parser = build_args_parser() + args = parser.parse_args( + [ + "--model", + "qwen3_5_0_8b", + "--params", + str(params_path), + "--use_kv_cache", + "--disable_dynamic_shape", + "--max_seq_length", + "8", + "--max_context_length", + "8", + ] + ) + + llm_config = LlmConfig.from_args(args) + builder = _prepare_for_llama_export(llm_config).export() + assert builder.pre_autograd_graph_module is not None + + recurrent_nodes = [ + node + for node in builder.pre_autograd_graph_module.graph.nodes + if "auto_functionalized_v2" in str(node.target) + and node.args + and "llama.recurrent_gated_delta_rule" in str(node.args[0]) + ] + + self.assertEqual(len(recurrent_nodes), 2) + @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available") def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self): llm_config = LlmConfig() diff --git a/examples/models/llama/tests/test_qwen3_5_attention.py b/examples/models/llama/tests/test_qwen3_5_attention.py index 5a9f67d57cf..b255f08105c 100644 --- a/examples/models/llama/tests/test_qwen3_5_attention.py +++ b/examples/models/llama/tests/test_qwen3_5_attention.py @@ -7,6 +7,7 @@ import unittest import torch +import executorch.examples.models.llama.attention as attention_module from executorch.examples.models.llama.attention import ATTENTION_REGISTRY from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.norm import RMSNorm @@ -123,6 +124,107 @@ def test_gated_deltanet_no_input_pos_does_not_leak_state(self): torch.allclose(state_after_first, state_after_second, atol=1e-5) ) + def test_gated_deltanet_chunked_prefill_matches_full_sequence(self): + torch.manual_seed(0) + args = self._make_args( + use_kv_cache=True, + use_q_gate=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn_full = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_chunked = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_chunked.load_state_dict(attn_full.state_dict()) + + x = torch.randn(1, 5, args.dim) + dummy_freq = torch.zeros(1, 1) + + full_output, _ = attn_full( + x, + dummy_freq, + dummy_freq, + input_pos=torch.tensor([0], dtype=torch.long), + ) + + chunk_outputs = [] + for start, end in ((0, 3), (3, 4), (4, 5)): + output, _ = attn_chunked( + x[:, start:end], + dummy_freq, + dummy_freq, + input_pos=torch.tensor([start], dtype=torch.long), + ) + chunk_outputs.append(output) + + chunked_output = torch.cat(chunk_outputs, dim=1) + + self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5)) + self.assertTrue( + torch.allclose( + attn_chunked.recurrent_state, attn_full.recurrent_state, atol=1e-5 + ) + ) + self.assertTrue( + torch.allclose(attn_chunked.conv_state, attn_full.conv_state, atol=1e-5) + ) + + def test_gated_deltanet_custom_op_matches_fallback(self): + recurrent_op = attention_module._get_recurrent_gated_delta_rule_op() + if recurrent_op is None: + self.skipTest("llama::recurrent_gated_delta_rule is not available") + + torch.manual_seed(0) + args = self._make_args( + use_kv_cache=True, + use_q_gate=True, + linear_conv_kernel_dim=4, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_num_key_heads=2, + linear_num_value_heads=4, + ) + rope = Rope(args) + attn_custom = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_fallback = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope) + attn_fallback.load_state_dict(attn_custom.state_dict()) + + query = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim) + key = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim) + value = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_v_dim) + g = torch.randn(1, 3, attn_custom.num_v_heads) + beta = torch.sigmoid(torch.randn(1, 3, attn_custom.num_v_heads)) + + original_op = attention_module._RECURRENT_GATED_DELTA_RULE_OP + original_tried_loading = attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP + try: + attention_module._RECURRENT_GATED_DELTA_RULE_OP = recurrent_op + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True + custom_output = attn_custom._recurrent_gated_delta_rule( + query, key, value, g, beta + ) + + attention_module._RECURRENT_GATED_DELTA_RULE_OP = None + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True + fallback_output = attn_fallback._recurrent_gated_delta_rule( + query, key, value, g, beta + ) + finally: + attention_module._RECURRENT_GATED_DELTA_RULE_OP = original_op + attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = ( + original_tried_loading + ) + + self.assertTrue(torch.allclose(custom_output, fallback_output, atol=1e-5)) + self.assertTrue( + torch.allclose( + attn_custom.recurrent_state, attn_fallback.recurrent_state, atol=1e-5 + ) + ) + if __name__ == "__main__": unittest.main() diff --git a/extension/llm/custom_ops/CMakeLists.txt b/extension/llm/custom_ops/CMakeLists.txt index 2cdfe547430..07f3eeb6a07 100644 --- a/extension/llm/custom_ops/CMakeLists.txt +++ b/extension/llm/custom_ops/CMakeLists.txt @@ -18,6 +18,7 @@ endif() set(_common_compile_options $<$:/wd4996> + $<$:/Zc:__cplusplus> $<$>:-Wno-deprecated-declarations -fPIC> ) if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64|aarch64") diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 9aacded4b4c..84f5667882e 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -11,7 +11,9 @@ # pyre-unsafe import logging +import os +from pathlib import Path from typing import Tuple import torch @@ -22,32 +24,76 @@ aten = torch.ops.aten -try: - op = torch.ops.llama.sdpa_with_kv_cache.default - assert op is not None - op2 = torch.ops.llama.fast_hadamard_transform.default - assert op2 is not None -except: - # This is needed to ensure that custom ops are registered - from executorch.extension.pybindings import portable_lib # noqa # usort: skip - # Ideally package is installed in only one location but usage of - # PYATHONPATH can result in multiple locations. - # ATM this is mainly used in CI for qnn runner. Will need to revisit this - from pathlib import Path +def _is_custom_ops_registered() -> bool: + try: + torch.ops.llama.sdpa_with_kv_cache.default + torch.ops.llama.fast_hadamard_transform.default + return True + except (AttributeError, RuntimeError): + return False + + +def _get_custom_ops_library_override() -> Path | None: + override = os.environ.get("EXECUTORCH_CUSTOM_OPS_AOT_LIB") + if override is None: + return None + + lib_path = Path(override).expanduser().resolve() + assert lib_path.is_file(), ( + "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing " + f"custom_ops_aot_lib, but got {lib_path}" + ) + return lib_path + + +def _find_custom_ops_library() -> Path: + override = _get_custom_ops_library_override() + if override is not None: + return override package_path = Path(__file__).parent.resolve() - logging.info(f"Looking for libcustom_ops_aot_lib.so in {package_path}") + candidates = [] + patterns = ( + "**/custom_ops_aot_lib.dll", + "**/custom_ops_aot_lib.so", + "**/custom_ops_aot_lib.dylib", + ) + + for pattern in patterns: + candidates.extend(package_path.glob(pattern)) + + libs = sorted({path.resolve() for path in candidates if path.is_file()}) + assert libs, f"Could not find custom_ops_aot_lib under {package_path}" + return max(libs, key=lambda path: path.stat().st_mtime) + + +def _load_custom_ops_library() -> None: + try: + # This is needed to ensure that custom ops are registered when + # portable_lib is available in the current environment. + from executorch.extension.pybindings import portable_lib # noqa # usort: skip + except ImportError: + portable_lib = None + + lib_path = _find_custom_ops_library() + logging.info(f"Loading custom ops library: {lib_path}") - libs = list(package_path.glob("**/*custom_ops_aot_lib.*")) + if os.name == "nt": + os.add_dll_directory(str(lib_path.parent)) + torch_lib_dir = Path(torch.__file__).resolve().parent / "lib" + if torch_lib_dir.is_dir(): + os.add_dll_directory(str(torch_lib_dir)) - assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" - logging.info(f"Loading custom ops library: {libs[0]}") - torch.ops.load_library(libs[0]) - op = torch.ops.llama.sdpa_with_kv_cache.default - assert op is not None - op2 = torch.ops.llama.fast_hadamard_transform.default - assert op2 is not None + torch.ops.load_library(lib_path) + + # Keep the import alive to avoid lint complaints in environments where + # portable_lib is needed for symbol resolution. + _ = portable_lib + +if not _is_custom_ops_registered(): + _load_custom_ops_library() + assert _is_custom_ops_registered() custom_ops_lib = torch.library.Library("llama", "IMPL") @@ -271,6 +317,87 @@ def update_cache_with_indices_meta( return torch.empty((1,), dtype=value.dtype, device="meta") +def _validate_recurrent_gated_delta_rule_params( + query, + key, + value, + g, + beta, + recurrent_state, +): + assert ( + query.dim() == 4 + ), f"Expected query to be 4 dimensional but got {query.dim()} dimensions." + assert ( + key.dim() == 4 + ), f"Expected key to be 4 dimensional but got {key.dim()} dimensions." + assert ( + value.dim() == 4 + ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions." + assert g.dim() == 3, f"Expected g to be 3 dimensional but got {g.dim()} dimensions." + assert ( + beta.dim() == 3 + ), f"Expected beta to be 3 dimensional but got {beta.dim()} dimensions." + assert ( + recurrent_state.dim() == 4 + ), f"Expected recurrent_state to be 4 dimensional but got {recurrent_state.dim()} dimensions." + + for name, tensor in { + "query": query, + "key": key, + "value": value, + "g": g, + "beta": beta, + "recurrent_state": recurrent_state, + }.items(): + assert ( + tensor.dtype == torch.float32 + ), f"Expected {name} to be float32 but got {tensor.dtype}" + + assert ( + query.shape == key.shape + ), f"Expected query and key to have matching shapes but got {query.shape} and {key.shape}" + assert ( + query.shape[:3] == value.shape[:3] + ), f"Expected query and value to match in batch/head/sequence dims but got {query.shape} and {value.shape}" + assert ( + g.shape == query.shape[:3] + ), f"Expected g to match query batch/head/sequence dims but got {g.shape} and {query.shape}" + assert ( + beta.shape == query.shape[:3] + ), f"Expected beta to match query batch/head/sequence dims but got {beta.shape} and {query.shape}" + assert recurrent_state.shape == ( + query.size(0), + query.size(1), + query.size(3), + value.size(3), + ), ( + "Expected recurrent_state to have shape " + f"{(query.size(0), query.size(1), query.size(3), value.size(3))} " + f"but got {recurrent_state.shape}" + ) + + +@impl(custom_ops_lib, "recurrent_gated_delta_rule", "Meta") +def recurrent_gated_delta_rule_meta( + query, + key, + value, + g, + beta, + recurrent_state, +): + _validate_recurrent_gated_delta_rule_params( + query, + key, + value, + g, + beta, + recurrent_state, + ) + return torch.empty_like(value) + + def _validate_quantized_sdpa_params( query, key, diff --git a/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp b/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp index 146ac3cc298..d48f593868c 100644 --- a/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp +++ b/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp @@ -13,14 +13,40 @@ namespace torch::executor::native { namespace { +template +auto to_et_arg(AType&& value) { + return executorch::extension::internal::type_convert( + std::forward(value)); +} + +at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) { + auto converted_result = + executorch::extension::internal::type_convert( + et_result) + .call(); + at::native::resize_output(out, converted_result.sizes()); + out.copy_(converted_result); + return out; +} + Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) { executorch::aten::RuntimeContext context; return fast_hadamard_transform_out(context, vec, out); } + +at::Tensor& fast_hadamard_transform_out_aten( + const at::Tensor& vec, + at::Tensor& out) { + auto vec_et = to_et_arg(vec); + auto out_et = to_et_arg(out); + auto& et_result = + fast_hadamard_transform_out_no_context(vec_et.call(), out_et.call()); + return copy_et_result_to_out(et_result, out); +} + at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) { auto out = at::empty_like(vec); - WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1) - (vec, out); + fast_hadamard_transform_out_aten(vec, out); return out; } } // namespace @@ -38,6 +64,5 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { torch::executor::native::fast_hadamard_transform_aten); m.impl( "fast_hadamard_transform.out", - WRAP_TO_ATEN( - torch::executor::native::fast_hadamard_transform_out_no_context, 1)); + torch::executor::native::fast_hadamard_transform_out_aten); } diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 72bddce7b5b..839d4e7d3bb 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -15,6 +15,10 @@ #include // @lint-ignore CLANGTIDY facebook-unused-include-check #include +#include +#include +#include +#include #ifdef ET_USE_THREADPOOL #include @@ -178,6 +182,67 @@ bool validate_cache_params( return true; } +bool validate_recurrent_gated_delta_rule_args( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + const Tensor& recurrent_state) { + ET_CHECK_OR_RETURN_FALSE(query.dim() == 4, "query must be a 4D tensor"); + ET_CHECK_OR_RETURN_FALSE(key.dim() == 4, "key must be a 4D tensor"); + ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor"); + ET_CHECK_OR_RETURN_FALSE(g.dim() == 3, "g must be a 3D tensor"); + ET_CHECK_OR_RETURN_FALSE(beta.dim() == 3, "beta must be a 3D tensor"); + ET_CHECK_OR_RETURN_FALSE( + recurrent_state.dim() == 4, "recurrent_state must be a 4D tensor"); + + ET_CHECK_OR_RETURN_FALSE( + query.scalar_type() == ScalarType::Float, "query must be float32"); + ET_CHECK_OR_RETURN_FALSE( + key.scalar_type() == ScalarType::Float, "key must be float32"); + ET_CHECK_OR_RETURN_FALSE( + value.scalar_type() == ScalarType::Float, "value must be float32"); + ET_CHECK_OR_RETURN_FALSE(g.scalar_type() == ScalarType::Float, "g must be float32"); + ET_CHECK_OR_RETURN_FALSE( + beta.scalar_type() == ScalarType::Float, "beta must be float32"); + ET_CHECK_OR_RETURN_FALSE( + recurrent_state.scalar_type() == ScalarType::Float, + "recurrent_state must be float32"); + + ET_CHECK_OR_RETURN_FALSE( + query.size(0) == key.size(0) && query.size(1) == key.size(1) && + query.size(2) == key.size(2) && query.size(3) == key.size(3), + "query and key must have matching shapes"); + ET_CHECK_OR_RETURN_FALSE( + query.size(0) == value.size(0) && query.size(1) == value.size(1) && + query.size(2) == value.size(2), + "query and value must match in batch/head/sequence dims"); + ET_CHECK_OR_RETURN_FALSE( + g.size(0) == query.size(0) && g.size(1) == query.size(1) && + g.size(2) == query.size(2), + "g must match query batch/head/sequence dims"); + ET_CHECK_OR_RETURN_FALSE( + beta.size(0) == query.size(0) && beta.size(1) == query.size(1) && + beta.size(2) == query.size(2), + "beta must match query batch/head/sequence dims"); + ET_CHECK_OR_RETURN_FALSE( + recurrent_state.size(0) == query.size(0) && + recurrent_state.size(1) == query.size(1) && + recurrent_state.size(2) == query.size(3) && + recurrent_state.size(3) == value.size(3), + "recurrent_state shape must match [B, H, K, V]"); + + for (const Tensor* tensor : + {&query, &key, &value, &g, &beta, &recurrent_state}) { + ET_CHECK_OR_RETURN_FALSE( + is_contiguous_dim_order((*tensor).dim_order().data(), (*tensor).dim()), + "recurrent gated delta rule expects contiguous inputs"); + } + + return true; +} + // TODO: seq_length is not yet used for copy void update_cache( const Tensor& projected_value, @@ -610,6 +675,137 @@ Tensor& sdpa_with_kv_cache_out( return output; } + +Tensor& recurrent_gated_delta_rule_out( + RuntimeContext& ctx, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output) { + ET_KERNEL_CHECK_MSG( + ctx, + resize_tensor(output, value.sizes()) == Error::Ok, + InvalidArgument, + output, + "Failed to resize recurrent_gated_delta_rule output tensor."); + ET_KERNEL_CHECK( + ctx, + validate_recurrent_gated_delta_rule_args( + query, key, value, g, beta, recurrent_state), + InvalidArgument, + output); + ET_KERNEL_CHECK( + ctx, + output.scalar_type() == ScalarType::Float, + InvalidArgument, + output); + ET_KERNEL_CHECK( + ctx, + is_contiguous_dim_order(output.dim_order().data(), output.dim()), + InvalidArgument, + output); + + const auto batch_size = query.size(0); + const auto num_heads = query.size(1); + const auto sequence_length = query.size(2); + const auto k_head_dim = query.size(3); + const auto v_head_dim = value.size(3); + + const auto q_batch_stride = num_heads * sequence_length * k_head_dim; + const auto q_head_stride = sequence_length * k_head_dim; + const auto q_seq_stride = k_head_dim; + + const auto value_batch_stride = num_heads * sequence_length * v_head_dim; + const auto value_head_stride = sequence_length * v_head_dim; + const auto value_seq_stride = v_head_dim; + + const auto gv_batch_stride = num_heads * sequence_length; + const auto gv_head_stride = sequence_length; + + const auto state_batch_stride = num_heads * k_head_dim * v_head_dim; + const auto state_head_stride = k_head_dim * v_head_dim; + + const auto* query_data = query.const_data_ptr(); + const auto* key_data = key.const_data_ptr(); + const auto* value_data = value.const_data_ptr(); + const auto* g_data = g.const_data_ptr(); + const auto* beta_data = beta.const_data_ptr(); + auto* recurrent_state_data = recurrent_state.mutable_data_ptr(); + auto* output_data = output.mutable_data_ptr(); + + for (int64_t batch = 0; batch < batch_size; ++batch) { + for (int64_t head = 0; head < num_heads; ++head) { + const auto q_offset = batch * q_batch_stride + head * q_head_stride; + const auto value_offset = + batch * value_batch_stride + head * value_head_stride; + const auto gv_offset = batch * gv_batch_stride + head * gv_head_stride; + const auto state_offset = + batch * state_batch_stride + head * state_head_stride; + + const auto* q_head = query_data + q_offset; + const auto* k_head = key_data + q_offset; + const auto* value_head = value_data + value_offset; + const auto* g_head = g_data + gv_offset; + const auto* beta_head = beta_data + gv_offset; + auto* state_head = recurrent_state_data + state_offset; + auto* output_head = output_data + value_offset; + + std::vector kv_mem(v_head_dim); + std::vector delta(v_head_dim); + + for (int64_t token = 0; token < sequence_length; ++token) { + const auto* q_t = q_head + token * q_seq_stride; + const auto* k_t = k_head + token * q_seq_stride; + const auto* v_t = value_head + token * value_seq_stride; + auto* output_t = output_head + token * value_seq_stride; + + const float g_t = std::exp(g_head[token]); + const float beta_t = beta_head[token]; + + if (g_t != 1.0f) { + for (int64_t idx = 0; idx < state_head_stride; ++idx) { + state_head[idx] *= g_t; + } + } + + std::fill(kv_mem.begin(), kv_mem.end(), 0.0f); + for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) { + const float key_value = k_t[k_idx]; + const auto* state_row = state_head + k_idx * v_head_dim; + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + kv_mem[v_idx] += state_row[v_idx] * key_value; + } + } + + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + delta[v_idx] = (v_t[v_idx] - kv_mem[v_idx]) * beta_t; + } + + for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) { + const float key_value = k_t[k_idx]; + auto* state_row = state_head + k_idx * v_head_dim; + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + state_row[v_idx] += key_value * delta[v_idx]; + } + } + + std::fill(output_t, output_t + v_head_dim, 0.0f); + for (int64_t k_idx = 0; k_idx < k_head_dim; ++k_idx) { + const float query_value = q_t[k_idx]; + const auto* state_row = state_head + k_idx * v_head_dim; + for (int64_t v_idx = 0; v_idx < v_head_dim; ++v_idx) { + output_t[v_idx] += state_row[v_idx] * query_value; + } + } + } + } + } + + return output; +} } // namespace native } // namespace executor } // namespace torch @@ -628,3 +824,36 @@ EXECUTORCH_LIBRARY( llama, "custom_quantized_sdpa.out", torch::executor::native::custom_quantized_sdpa_out); + +namespace { + +void recurrent_gated_delta_rule_out_boxed( + executorch::runtime::KernelRuntimeContext& ctx, + executorch::runtime::Span stack) { + ET_KERNEL_CHECK_MSG( + ctx, + stack.size() == 7, + InvalidProgram, + /* void */, + "Expected %zu args, got %zu", + static_cast(7), + stack.size()); + + auto& query = stack[0]->toTensor(); + auto& key = stack[1]->toTensor(); + auto& value = stack[2]->toTensor(); + auto& g = stack[3]->toTensor(); + auto& beta = stack[4]->toTensor(); + auto& recurrent_state = stack[5]->toTensor(); + auto& output = stack[6]->toTensor(); + + (void)torch::executor::native::recurrent_gated_delta_rule_out( + ctx, query, key, value, g, beta, recurrent_state, output); +} + +const auto recurrent_gated_delta_rule_out_registration = + executorch::runtime::register_kernel(executorch::runtime::Kernel( + "llama::recurrent_gated_delta_rule.out", + recurrent_gated_delta_rule_out_boxed)); + +} // namespace diff --git a/extension/llm/custom_ops/op_sdpa.h b/extension/llm/custom_ops/op_sdpa.h index 9d357eb6ea1..9f029f52f31 100644 --- a/extension/llm/custom_ops/op_sdpa.h +++ b/extension/llm/custom_ops/op_sdpa.h @@ -75,6 +75,16 @@ Tensor& custom_quantized_sdpa_out( const optional& v_scales, const bool is_seq_at_dim_1, Tensor& output); + +Tensor& recurrent_gated_delta_rule_out( + RuntimeContext& ctx, + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output); } // namespace native } // namespace executor } // namespace torch diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index 5bbf22d336e..096c698e5d1 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -17,6 +17,24 @@ namespace torch { namespace executor { namespace native { +namespace { +template +auto to_et_arg(AType&& value) { + return executorch::extension::internal::type_convert( + std::forward(value)); +} + +at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) { + auto converted_result = + executorch::extension::internal::type_convert( + et_result) + .call(); + at::native::resize_output(out, converted_result.sizes()); + out.copy_(converted_result); + return out; +} +} // namespace + Tensor& sdpa_with_kv_cache_out_no_context( const Tensor& q_projected, const Tensor& k_projected, @@ -50,6 +68,20 @@ at::Tensor sdpa_with_kv_cache_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale); +at::Tensor& sdpa_with_kv_cache_out_aten( + const at::Tensor& q_projected, + const at::Tensor& k_projected, + const at::Tensor& v_projected, + at::Tensor& key_cache, + at::Tensor& value_cache, + const int64_t start_pos, + const int64_t seq_len, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output); + Tensor& custom_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -77,6 +109,17 @@ at::Tensor custom_sdpa_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale); +at::Tensor& custom_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output); + Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -118,6 +161,24 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& v_scales, const bool is_seq_at_dim_2); +at::Tensor& custom_quantized_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + const std::optional& q_zero_points, + const std::optional& q_scales, + const std::optional& k_zero_points, + const std::optional& k_scales, + const std::optional& v_zero_points, + const std::optional& v_scales, + const bool is_seq_at_dim_2, + at::Tensor& output); + Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, @@ -129,6 +190,12 @@ at::Tensor update_cache_aten( at::Tensor& cache, const int64_t start_pos); +at::Tensor& update_cache_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + at::Tensor& output); + // New functions for update_cache_with_indices Tensor& update_cache_with_indices_out_no_context( const Tensor& value, @@ -143,6 +210,39 @@ at::Tensor update_cache_with_indices_aten( const int64_t start_pos, const at::Tensor& indices); +at::Tensor& update_cache_with_indices_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + const at::Tensor& indices, + at::Tensor& output); + +Tensor& recurrent_gated_delta_rule_out_no_context( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output); + +at::Tensor recurrent_gated_delta_rule_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state); + +at::Tensor& recurrent_gated_delta_rule_out_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state, + at::Tensor& output); + Tensor& sdpa_with_kv_cache_out_no_context( const Tensor& q_projected, const Tensor& k_projected, @@ -192,22 +292,59 @@ at::Tensor sdpa_with_kv_cache_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale) { auto output = at::empty_like(q_projected); - WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11) - (q_projected, - k_projected, - v_projected, - key_cache, - value_cache, - start_pos, - seq_len, - attn_mask, - dropout_p, - is_causal, - scale, - output); + sdpa_with_kv_cache_out_aten( + q_projected, + k_projected, + v_projected, + key_cache, + value_cache, + start_pos, + seq_len, + attn_mask, + dropout_p, + is_causal, + scale, + output); return output; } +at::Tensor& sdpa_with_kv_cache_out_aten( + const at::Tensor& q_projected, + const at::Tensor& k_projected, + const at::Tensor& v_projected, + at::Tensor& key_cache, + at::Tensor& value_cache, + const int64_t start_pos, + const int64_t seq_len, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output) { + auto q_et = to_et_arg(q_projected); + auto k_et = to_et_arg(k_projected); + auto v_et = to_et_arg(v_projected); + auto key_cache_et = to_et_arg(key_cache); + auto value_cache_et = to_et_arg(value_cache); + auto attn_mask_et = to_et_arg>(attn_mask); + auto scale_et = to_et_arg>(scale); + auto output_et = to_et_arg(output); + auto& et_result = sdpa_with_kv_cache_out_no_context( + q_et.call(), + k_et.call(), + v_et.call(), + key_cache_et.call(), + value_cache_et.call(), + start_pos, + seq_len, + attn_mask_et.call(), + dropout_p, + is_causal, + scale_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + Tensor& custom_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -248,11 +385,40 @@ at::Tensor custom_sdpa_aten( // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy const std::optional scale) { auto output = at::empty(q.sizes()); - WRAP_TO_ATEN(custom_sdpa_out_no_context, 8) - (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output); + custom_sdpa_out_aten( + q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output); return output; } +at::Tensor& custom_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + at::Tensor& output) { + auto q_et = to_et_arg(q); + auto k_et = to_et_arg(k); + auto v_et = to_et_arg(v); + auto attn_mask_et = to_et_arg>(attn_mask); + auto scale_et = to_et_arg>(scale); + auto output_et = to_et_arg(output); + auto& et_result = custom_sdpa_out_no_context( + q_et.call(), + k_et.call(), + v_et.call(), + start_pos, + attn_mask_et.call(), + dropout_p, + is_causal, + scale_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + Tensor& custom_quantized_sdpa_out_no_context( const Tensor& q, const Tensor& k, @@ -314,26 +480,75 @@ at::Tensor custom_quantized_sdpa_aten( const std::optional& v_scales, const bool is_seq_at_dim_2) { auto output = at::empty(q.sizes()); - WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15) - (q, - k, - v, - start_pos, - attn_mask, - dropout_p, - is_causal, - scale, - q_zero_points, - q_scales, - k_zero_points, - k_scales, - v_zero_points, - v_scales, - is_seq_at_dim_2, - output); + custom_quantized_sdpa_out_aten( + q, + k, + v, + start_pos, + attn_mask, + dropout_p, + is_causal, + scale, + q_zero_points, + q_scales, + k_zero_points, + k_scales, + v_zero_points, + v_scales, + is_seq_at_dim_2, + output); return output; } +at::Tensor& custom_quantized_sdpa_out_aten( + const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const int64_t start_pos, + const std::optional attn_mask, + const double dropout_p, + const bool is_causal, + const std::optional scale, + const std::optional& q_zero_points, + const std::optional& q_scales, + const std::optional& k_zero_points, + const std::optional& k_scales, + const std::optional& v_zero_points, + const std::optional& v_scales, + const bool is_seq_at_dim_2, + at::Tensor& output) { + auto q_et = to_et_arg(q); + auto k_et = to_et_arg(k); + auto v_et = to_et_arg(v); + auto attn_mask_et = to_et_arg>(attn_mask); + auto scale_et = to_et_arg>(scale); + auto q_zero_points_et = to_et_arg>(q_zero_points); + auto q_scales_et = to_et_arg>(q_scales); + auto k_zero_points_et = to_et_arg>(k_zero_points); + auto k_scales_et = to_et_arg>(k_scales); + auto v_zero_points_et = to_et_arg>(v_zero_points); + auto v_scales_et = to_et_arg>(v_scales); + auto output_et = to_et_arg(output); + auto& et_result = custom_quantized_sdpa_out_no_context( + q_et.call(), + k_et.call(), + v_et.call(), + start_pos, + attn_mask_et.call(), + dropout_p, + is_causal, + scale_et.call(), + q_zero_points_et.call(), + q_scales_et.call(), + k_zero_points_et.call(), + k_scales_et.call(), + v_zero_points_et.call(), + v_scales_et.call(), + is_seq_at_dim_2, + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, @@ -349,11 +564,23 @@ at::Tensor update_cache_aten( at::Tensor& cache, const int64_t start_pos) { auto output = at::empty({1}); - WRAP_TO_ATEN(update_cache_out_no_context, 3) - (value, cache, start_pos, output); + update_cache_out_aten(value, cache, start_pos, output); return output; } +at::Tensor& update_cache_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + at::Tensor& output) { + auto value_et = to_et_arg(value); + auto cache_et = to_et_arg(cache); + auto output_et = to_et_arg(output); + auto& et_result = update_cache_out_no_context( + value_et.call(), cache_et.call(), start_pos, output_et.call()); + return copy_et_result_to_out(et_result, output); +} + // Implementations for update_cache_with_indices Tensor& update_cache_with_indices_out_no_context( const Tensor& value, @@ -372,11 +599,81 @@ at::Tensor update_cache_with_indices_aten( const int64_t start_pos, const at::Tensor& indices) { auto output = at::empty({1}); - WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4) - (value, cache, start_pos, indices, output); + update_cache_with_indices_out_aten(value, cache, start_pos, indices, output); + return output; +} + +at::Tensor& update_cache_with_indices_out_aten( + const at::Tensor& value, + at::Tensor& cache, + const int64_t start_pos, + const at::Tensor& indices, + at::Tensor& output) { + auto value_et = to_et_arg(value); + auto cache_et = to_et_arg(cache); + auto indices_et = to_et_arg(indices); + auto output_et = to_et_arg(output); + auto& et_result = update_cache_with_indices_out_no_context( + value_et.call(), + cache_et.call(), + start_pos, + indices_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + +Tensor& recurrent_gated_delta_rule_out_no_context( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& g, + const Tensor& beta, + Tensor& recurrent_state, + Tensor& output) { + executorch::aten::RuntimeContext context{}; + return torch::executor::native::recurrent_gated_delta_rule_out( + context, query, key, value, g, beta, recurrent_state, output); +} + +at::Tensor recurrent_gated_delta_rule_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state) { + auto output = at::empty_like(value); + recurrent_gated_delta_rule_out_aten( + query, key, value, g, beta, recurrent_state, output); return output; } +at::Tensor& recurrent_gated_delta_rule_out_aten( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& g, + const at::Tensor& beta, + at::Tensor& recurrent_state, + at::Tensor& output) { + auto query_et = to_et_arg(query); + auto key_et = to_et_arg(key); + auto value_et = to_et_arg(value); + auto g_et = to_et_arg(g); + auto beta_et = to_et_arg(beta); + auto recurrent_state_et = to_et_arg(recurrent_state); + auto output_et = to_et_arg(output); + auto& et_result = recurrent_gated_delta_rule_out_no_context( + query_et.call(), + key_et.call(), + value_et.call(), + g_et.call(), + beta_et.call(), + recurrent_state_et.call(), + output_et.call()); + return copy_et_result_to_out(et_result, output); +} + } // namespace native } // namespace executor } // namespace torch @@ -410,6 +707,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { m.def( "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, " "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)"); + m.def( + "recurrent_gated_delta_rule(Tensor query, Tensor key, Tensor value, Tensor g, " + "Tensor beta, Tensor(a!) recurrent_state) -> Tensor"); + m.def( + "recurrent_gated_delta_rule.out(Tensor query, Tensor key, Tensor value, Tensor g, " + "Tensor beta, Tensor(a!) recurrent_state, *, Tensor(b!) out) -> Tensor(b!)"); m.def( "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, " @@ -430,29 +733,31 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten); m.impl( "sdpa_with_kv_cache.out", - WRAP_TO_ATEN( - torch::executor::native::sdpa_with_kv_cache_out_no_context, 11)); + torch::executor::native::sdpa_with_kv_cache_out_aten); m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten); m.impl( "custom_sdpa.out", - WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8)); + torch::executor::native::custom_sdpa_out_aten); m.impl("update_cache", torch::executor::native::update_cache_aten); m.impl( "update_cache.out", - WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3)); + torch::executor::native::update_cache_out_aten); m.impl( "update_cache_with_indices", torch::executor::native::update_cache_with_indices_aten); m.impl( "update_cache_with_indices.out", - WRAP_TO_ATEN( - torch::executor::native::update_cache_with_indices_out_no_context, - 4)); + torch::executor::native::update_cache_with_indices_out_aten); + m.impl( + "recurrent_gated_delta_rule", + torch::executor::native::recurrent_gated_delta_rule_aten); + m.impl( + "recurrent_gated_delta_rule.out", + torch::executor::native::recurrent_gated_delta_rule_out_aten); m.impl( "custom_quantized_sdpa", torch::executor::native::custom_quantized_sdpa_aten); m.impl( "custom_quantized_sdpa.out", - WRAP_TO_ATEN( - torch::executor::native::custom_quantized_sdpa_out_no_context, 15)); + torch::executor::native::custom_quantized_sdpa_out_aten); } diff --git a/extension/llm/custom_ops/op_tile_crop_aot.cpp b/extension/llm/custom_ops/op_tile_crop_aot.cpp index 5aa98ee8d4a..dcbfcdd2e2d 100644 --- a/extension/llm/custom_ops/op_tile_crop_aot.cpp +++ b/extension/llm/custom_ops/op_tile_crop_aot.cpp @@ -16,10 +16,30 @@ namespace torch { namespace executor { namespace native { +namespace { +template +auto to_et_arg(AType&& value) { + return executorch::extension::internal::type_convert( + std::forward(value)); +} + +at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) { + auto converted_result = + executorch::extension::internal::type_convert( + et_result) + .call(); + at::native::resize_output(out, converted_result.sizes()); + out.copy_(converted_result); + return out; +} +} // namespace Tensor& tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out); +at::Tensor& +tile_crop_out_aten(const at::Tensor& input, int64_t tile_size, at::Tensor& out); + Tensor& tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) { executorch::aten::RuntimeContext context{}; @@ -28,12 +48,19 @@ tile_crop_out_no_context(const Tensor& input, int64_t tile_size, Tensor& out) { at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size); +at::Tensor& +tile_crop_out_aten(const at::Tensor& input, int64_t tile_size, at::Tensor& out) { + auto input_et = to_et_arg(input); + auto out_et = to_et_arg(out); + auto& et_result = + tile_crop_out_no_context(input_et.call(), tile_size, out_et.call()); + return copy_et_result_to_out(et_result, out); +} + at::Tensor tile_crop_aten(const at::Tensor& input, int64_t tile_size) { // max_num_tiles = 4, num_channels = 3. auto output = at::empty({4, 3, tile_size, tile_size}); - - WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2) - (input, tile_size, output); + tile_crop_out_aten(input, tile_size, output); return output; } @@ -51,5 +78,5 @@ TORCH_LIBRARY_IMPL(preprocess, CompositeExplicitAutograd, m) { m.impl("tile_crop", torch::executor::native::tile_crop_aten); m.impl( "tile_crop.out", - WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2)); + torch::executor::native::tile_crop_out_aten); } diff --git a/extension/llm/custom_ops/test_update_cache.py b/extension/llm/custom_ops/test_update_cache.py index 84a349c97f0..9124f56c946 100644 --- a/extension/llm/custom_ops/test_update_cache.py +++ b/extension/llm/custom_ops/test_update_cache.py @@ -431,3 +431,157 @@ def test_batched_update_kv_cache_more_updates(self): self._update_and_validate( k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) + + +class RecurrentGatedDeltaRuleTest(unittest.TestCase): + def _make_inputs( + self, + batch_size: int = 2, + num_heads: int = 3, + seq_len: int = 4, + k_head_dim: int = 5, + v_head_dim: int = 6, + ): + query = torch.randn(batch_size, num_heads, seq_len, k_head_dim) + key = torch.randn(batch_size, num_heads, seq_len, k_head_dim) + value = torch.randn(batch_size, num_heads, seq_len, v_head_dim) + g = torch.randn(batch_size, num_heads, seq_len) + beta = torch.sigmoid(torch.randn(batch_size, num_heads, seq_len)) + recurrent_state = torch.randn( + batch_size, num_heads, k_head_dim, v_head_dim + ) + return query, key, value, g, beta, recurrent_state + + def _reference_recurrent_gated_delta_rule( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + recurrent_state: torch.Tensor, + ): + state = recurrent_state.clone() + output = torch.zeros_like(value) + + for token in range(query.size(2)): + g_t = g[:, :, token].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, token].unsqueeze(-1) + k_t = key[:, :, token] + v_t = value[:, :, token] + q_t = query[:, :, token] + + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + output[:, :, token] = (state * q_t.unsqueeze(-1)).sum(dim=-2) + + return output, state + + def test_recurrent_gated_delta_rule_matches_reference(self): + torch.manual_seed(0) + + test_cases = ( + (2, 3, 4, 5, 6), + (1, 4, 7, 8, 3), + ) + + for case in test_cases: + with self.subTest(case=case): + ( + query, + key, + value, + g, + beta, + recurrent_state, + ) = self._make_inputs(*case) + + expected_output, expected_state = ( + self._reference_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + recurrent_state, + ) + ) + + actual_state = recurrent_state.clone() + actual_output = torch.ops.llama.recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + actual_state, + ) + + self.assertTrue( + torch.allclose(actual_output, expected_output, atol=1e-5) + ) + self.assertTrue(torch.allclose(actual_state, expected_state, atol=1e-5)) + + def test_recurrent_gated_delta_rule_out_matches_reference(self): + torch.manual_seed(0) + + query, key, value, g, beta, recurrent_state = self._make_inputs() + expected_output, expected_state = self._reference_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + recurrent_state, + ) + + actual_state = recurrent_state.clone() + actual_output = torch.empty_like(value) + returned_output = torch.ops.llama.recurrent_gated_delta_rule.out( + query, + key, + value, + g, + beta, + actual_state, + out=actual_output, + ) + + self.assertEqual(returned_output.data_ptr(), actual_output.data_ptr()) + self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-5)) + self.assertTrue(torch.allclose(actual_state, expected_state, atol=1e-5)) + + def test_recurrent_gated_delta_rule_chunked_matches_full_sequence(self): + torch.manual_seed(0) + + query, key, value, g, beta, recurrent_state = self._make_inputs(seq_len=6) + + full_state = recurrent_state.clone() + full_output = torch.ops.llama.recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + full_state, + ) + + chunk_state = recurrent_state.clone() + chunk_outputs = [] + for start, end in ((0, 2), (2, 5), (5, 6)): + chunk_outputs.append( + torch.ops.llama.recurrent_gated_delta_rule( + query[:, :, start:end, :], + key[:, :, start:end, :], + value[:, :, start:end, :], + g[:, :, start:end], + beta[:, :, start:end], + chunk_state, + ) + ) + + chunked_output = torch.cat(chunk_outputs, dim=2) + self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5)) + self.assertTrue(torch.allclose(chunk_state, full_state, atol=1e-5))