Commit e446b03

Fix scale to 1.0 for cudnn_flash_jax

1 parent: 441bc95

1 file changed: src/maxtext/layers/attention_op.py (1 addition, 1 deletion)
```diff
@@ -1678,7 +1678,7 @@ def cudnn_jax_flash_attention(
             key,
             value,
             mask_type=MaskType.CAUSAL,
-            scale=1.0 / math.sqrt(head_dim),
+            scale=1.0,
             dropout_rate=self.dropout_rate,
             qkv_layout="BTNH",
             return_residual=True,
```
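The commit message does not state why `scale=1.0` is correct here, but a common reason for such a change is that the query has already been scaled by `1/sqrt(head_dim)` before the kernel is called, so scaling again inside the attention op would double-apply the factor. The sketch below (plain NumPy, hypothetical shapes, not MaxText's actual code path) illustrates that equivalence: scaling the scores inside attention is the same as pre-scaling the query and passing `scale=1.0`.

```python
import numpy as np

def attention(q, k, v, scale):
    # Plain softmax attention: softmax(scale * q @ k^T) @ v
    logits = scale * (q @ k.T)
    logits -= logits.max(axis=-1, keepdims=True)  # for numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
head_dim = 64  # hypothetical head dimension for illustration
q = rng.standard_normal((8, head_dim))
k = rng.standard_normal((8, head_dim))
v = rng.standard_normal((8, head_dim))

# Scaling inside the attention op ...
a = attention(q, k, v, scale=1.0 / np.sqrt(head_dim))
# ... equals pre-scaling the query and passing scale=1.0.
b = attention(q / np.sqrt(head_dim), k, v, scale=1.0)
print(np.allclose(a, b))  # True
```

If the surrounding MaxText code already folds `1/sqrt(head_dim)` into the query (an assumption here), then the previous `scale=1.0 / math.sqrt(head_dim)` would have applied the factor twice, and `scale=1.0` restores the intended attention scores.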
