diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index c0033b4bae..622bcf1e8e 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1672,7 +1672,7 @@ def cudnn_jax_flash_attention( key, value, mask_type=MaskType.CAUSAL, - scale=1.0 / math.sqrt(head_dim), + scale=1.0, dropout_rate=self.dropout_rate, qkv_layout="BTNH", return_residual=True,