
Conversation

@copybara-service copybara-service bot commented Jan 21, 2026

Force layout on Q for MLA when use_jax_splash=True

This helps non-Pallas splash attention and removes copies when num_heads is 128.

Major-to-minor layout:
original query: (1, 2, 192, 1024, 128)
attention expectation: (1, 2, 128, 192, 1024)

This change speeds up non-Pallas forward splash attention by up to 14%.
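For illustration only, here is a minimal, hypothetical sketch of this kind of query reordering in JAX, assuming it can be expressed as a logical axis move that puts the size-128 num_heads axis where the splash-attention kernel expects it. The actual change may instead constrain the physical (major-to-minor) layout directly; the function name and shapes below are illustrative, not from the MaxText code.

```python
# Hypothetical sketch -- not the actual MaxText implementation.
import jax.numpy as jnp

def reorder_query_for_splash(query):
  """Move the trailing size-128 axis of the query to position 2.

  Original query:        (1, 2, 192, 1024, 128)
  Attention expectation: (1, 2, 128, 192, 1024)
  """
  # A single axis move lets XLA materialize the query once in the order the
  # attention kernel wants, instead of copying it inside the kernel.
  return jnp.moveaxis(query, -1, 2)

q = jnp.zeros((1, 2, 192, 1024, 128), dtype=jnp.bfloat16)
print(reorder_query_for_splash(q).shape)  # (1, 2, 128, 192, 1024)
```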

codecov bot commented Jan 21, 2026

Codecov Report

❌ Patch coverage is 50.00000% with 5 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/MaxText/layers/attention_mla.py | 55.55% | 2 Missing and 2 partials ⚠️ |
| src/MaxText/kernels/jax_flash_attention.py | 0.00% | 1 Missing ⚠️ |


@copybara-service copybara-service bot force-pushed the test_855382451 branch 2 times, most recently from b0ed511 to 4386083 on January 23, 2026 19:46
@copybara-service copybara-service bot changed the title from "Force layout on Q for MLA." to "Force layout on Q for MLA when use_jax_splash=True" on January 23, 2026
@copybara-service copybara-service bot force-pushed the test_855382451 branch 2 times, most recently from f7d8cf3 to 762ac72 on January 26, 2026 15:56
@copybara-service copybara-service bot closed this on January 26, 2026
@copybara-service copybara-service bot deleted the test_855382451 branch on January 26, 2026 16:35