Add Gemma4 layer-wise unit tests #3905
Conversation
Codecov Report: ✅ All modified and coverable lines are covered by tests.
The Pull Request introduces comprehensive layer-wise unit tests for the Gemma 4 vision components, comparing the MaxText (JAX) implementation against the PyTorch reference from the transformers library. The tests cover all key layers, including VisionEntry, RotaryEmbedding, Attention, EncoderBlock, and the full VisionEncoderLayer, ensuring numerical parity across frameworks.
🔍 General Feedback
- Completeness: The coverage is excellent, including both individual components and end-to-end encoder tests.
- Numerical Parity: Most individual layer tests use a tight tolerance (1e-3), which is a strong indicator of implementation correctness.
- Standardized Helpers: The use of shared multimodal test utilities and clear weight-copying functions makes the tests easy to follow and maintain.
- CI Status: Note that the new test file is added to pytest.ini's ignore list, which is consistent with the stated TODO to transform these into PyTorch-free tests for CI compatibility.
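The parity checks described above follow a simple pattern. A minimal numpy sketch of the comparison semantics (the helper name mirrors the PR's `assert_all_close_jax_torch`, but this stand-in assumes away the dtype/device conversion the real helper presumably handles):

```python
import numpy as np

# Hypothetical stand-in for the shared assert helper: both framework outputs
# are converted to numpy and compared elementwise with
# |actual - desired| <= atol + rtol * |desired|  (np.testing semantics).
def assert_all_close(actual, desired, rtol, atol, error_msg=""):
    np.testing.assert_allclose(
        np.asarray(actual), np.asarray(desired),
        rtol=rtol, atol=atol, err_msg=error_msg,
    )

# A per-layer check at the tight 1e-3 tolerance passes for a 1e-4-scale mismatch:
ref = np.array([1.0, 2.0, 3.0], dtype=np.float32)
out = ref + 1e-4
assert_all_close(out, ref, rtol=1e-3, atol=1e-3)
```

A mismatch larger than the combined tolerance would raise an `AssertionError` carrying `error_msg`, which is what makes the per-layer failures in such tests easy to localize.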
```python
    float32_logits=True,
    float32_qk_product=True,
)
```
🟢 This variable is defined but not used in the pyconfig.initialize call below. Consider removing it if it's not needed.
```python
base_config_path = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "base.yml")
jax_config = pyconfig.initialize(
```
```python
jax_inputs, torch_inputs = create_random_jax_torch(batch_size, seq_len, self.config.hidden_size_for_vit)

torch_output = torch_model(torch_inputs)
jax_output = jax_model(jax_inputs)
```
🟡 The relative tolerance rtol=5e-2 (5%) is quite high for a unit test comparing float32 implementations. While the cumulative error in a 27-layer model and the scaling by sqrt(d_model) in VisionExit might justify some variance, it would be ideal to see if this can be tightened (e.g., to 1e-3 or 1e-2) to ensure higher precision matches.
If this tolerance is the tightest possible due to framework differences, consider adding a brief comment explaining the reason.
```python
jax_output = jax_model(jax_inputs)
assert_all_close_jax_torch(
    jax_output_squeezed,
    torch_lhs,
    rtol=5e-2,
    atol=5e-2,
    error_msg="Gemma4VisionEncoderLayer end-to-end outputs differ",
)
```
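One way to act on the tolerance suggestion above is to measure the observed error before asserting, so the chosen `rtol` can be justified (or tightened) with data. A small numpy sketch of such a diagnostic (hypothetical helper, not part of the PR):

```python
import numpy as np

def max_rel_error(actual, desired, eps=1e-12):
    """Largest elementwise relative error between two outputs; logging this
    next to the assert shows how much slack a tolerance like rtol=5e-2 has."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    return float(np.max(np.abs(actual - desired) / (np.abs(desired) + eps)))

# Example: a uniform 1e-3 relative mismatch sits far below a 5e-2 ceiling,
# which would suggest the end-to-end tolerance could be tightened.
desired = np.ones((4, 8))
actual = desired * (1.0 + 1e-3)
err = max_rel_error(actual, desired)
```

If the measured error across seeds stays near the per-layer 1e-3 level, the end-to-end tolerance can likely be reduced; if not, the measured value itself is the comment-worthy justification for keeping 5e-2.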
aireenmei left a comment:
Thanks for adding tests!
```python
# =============================================================================

def copy_rmsnorm_weights(torch_norm, jax_norm):
```
Seems not very specific to gemma4? Should we move to tests.utils.multimodal_test_utils for re-use?
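To illustrate why this helper generalizes, here is a framework-neutral sketch of what a shared version might look like (dict layouts and key names are assumptions; the PR's helper operates on actual module objects rather than plain dicts):

```python
import numpy as np

def copy_rmsnorm_weights(torch_state, jax_params, torch_key="weight", jax_key="scale"):
    """Copy an RMSNorm scale from a PyTorch-style state dict into a Flax-style
    params dict. The parameter is a 1-D vector with identical layout in both
    frameworks, so no transpose or reshape is needed -- which is why the
    helper is not Gemma4-specific and could live in shared test utilities."""
    w = np.asarray(torch_state[torch_key])
    assert w.ndim == 1, "RMSNorm scale is expected to be a 1-D vector"
    jax_params[jax_key] = w.copy()
    return jax_params

# Usage with plain numpy stand-ins for the real tensors:
state = {"weight": np.arange(4, dtype=np.float32)}
params = copy_rmsnorm_weights(state, {})
```

Because nothing here depends on Gemma4's shapes or naming, moving it to `tests.utils.multimodal_test_utils` as suggested would let other model families reuse it unchanged.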
```python
seq_len = 42 * 60
dummy_shape = (batch_size, seq_len, self.config.hidden_size_for_vit)

jax_model = JaxGemma4Attention(
```
How about JaxGemma4VisionAttention?
Description
Add layer-wise unit tests comparing MaxText and PyTorch implementations for Gemma 4 vision components, including
VisionEntry, Gemma4VisionRotaryEmbedding, Gemma4Attention, Gemma4EncoderBlock, VisionExit, Gemma4VisionEncoderLayer, and Gemma4VisionProjector (still offline).

TODO: A follow-up PR to transform these layer-wise tests to be PyTorch-free and runnable on CI.
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.