From 9c2cc66cede5875489a44fa21385a72d51eedd4d Mon Sep 17 00:00:00 2001
From: Anxhelo Xhebraj
Date: Thu, 12 Mar 2026 10:09:18 -0700
Subject: [PATCH] Fix scale to 1.0 for cudnn_flash_jax

---
 src/maxtext/layers/attention_op.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py
index c0033b4bae..622bcf1e8e 100644
--- a/src/maxtext/layers/attention_op.py
+++ b/src/maxtext/layers/attention_op.py
@@ -1672,7 +1672,7 @@ def cudnn_jax_flash_attention(
         key,
         value,
         mask_type=MaskType.CAUSAL,
-        scale=1.0 / math.sqrt(head_dim),
+        scale=1.0,
         dropout_rate=self.dropout_rate,
         qkv_layout="BTNH",
         return_residual=True,