Force layout on Q for MLA when use_jax_splash=True
This helps non-Pallas splash attention and removes copies when num_heads is 128.
Major-to-minor layout:
  original query: 1, 2, 192, 1024, 128
  attention expectation: 1, 2, 128, 192, 1024
This change speeds up non-Pallas forward splash attention by up to 14%.
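
As a rough illustration of the reordering described above, here is a minimal JAX sketch. It is not the code from this change. It models the major-to-minor order as a logical axis permutation with `jnp.transpose`, whereas the commit forces a physical layout, so treat the mapping as an assumption. The shape is a scaled-down, hypothetical stand-in for 1, 2, 192, 1024, 128, and the name `to_attention_order` is made up for this example.

```python
import jax.numpy as jnp

# Hypothetical, scaled-down stand-in for the query dimensions listed in the
# commit message (1, 2, 192, 1024, 128). Axis meanings are assumptions.
ORIGINAL_SHAPE = (1, 2, 3, 8, 4)

# Permutation that moves the minor-most axis (128 in the commit message,
# 4 here) to position 2, matching the order 1, 2, 128, 192, 1024.
PERM = (0, 1, 4, 2, 3)


def to_attention_order(q):
    """Reorder axes so the last axis of `q` becomes the third axis."""
    return jnp.transpose(q, PERM)


# Build a small tensor with distinct values so the reordering is checkable.
q = jnp.arange(1 * 2 * 3 * 8 * 4, dtype=jnp.float32).reshape(ORIGINAL_SHAPE)
q_t = to_attention_order(q)

print(q_t.shape)  # (1, 2, 4, 3, 8)
```

If the query already arrives in the order the attention computation expects, no such reordering copy is needed, which is the saving the message refers to.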
PiperOrigin-RevId: 861206515