Add Gemma4 layer-wise unit tests#3905

Open
hengtaoguo wants to merge 1 commit into main from hengtaoguo-gemma4
Conversation

@hengtaoguo
Collaborator

@hengtaoguo hengtaoguo commented May 14, 2026

Description

Add layer-wise unit tests comparing MaxText and PyTorch implementations for Gemma 4 vision components, including VisionEntry, Gemma4VisionRotaryEmbedding, Gemma4Attention, Gemma4EncoderBlock, VisionExit, Gemma4VisionEncoderLayer, and Gemma4VisionProjector (still offline).

TODO: A follow-up PR will make these layer-wise tests PyTorch-free so they can run on CI.
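The layer-wise parity pattern described above can be sketched framework-agnostically. This toy version uses numpy stand-ins for the PyTorch reference and the MaxText/JAX candidate (all names here are hypothetical, not this PR's actual helpers), but shows the shape of each test: build both implementations, feed them identical random inputs, and assert numerical closeness at a tight tolerance.

```python
import numpy as np

def reference_rmsnorm(x, weight, eps=1e-6):
    # Stand-in for the PyTorch reference implementation.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

def candidate_rmsnorm(x, weight, eps=1e-6):
    # Stand-in for the implementation under test: mathematically
    # equivalent, but with a different operation ordering.
    inv_rms = 1.0 / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x * inv_rms * weight

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 8)).astype(np.float32)
w = rng.standard_normal(8).astype(np.float32)

# Per-layer tests in this PR reportedly use a tight tolerance (rtol=1e-3).
np.testing.assert_allclose(candidate_rmsnorm(x, w), reference_rmsnorm(x, w),
                           rtol=1e-3, atol=1e-3)
print("rmsnorm parity: PASSED")
```

The real tests follow the same recipe, with the extra step of copying the PyTorch module's weights into the JAX module before comparing outputs.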

Tests

python -m pytest tests/unit/gemma4_layers_test.py -vv -s
collected 7 items                                               

tests/unit/gemma4_layers_test.py::TestGemma4VisionEntry::test_vision_entry_matches_torch W0514 05:57:15.485060 2196772 pjrt_executable.cc:642] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
PASSED
tests/unit/gemma4_layers_test.py::TestGemma4VisionRotaryEmbedding::test_rotary_embedding_matches_torch W0514 05:57:20.120236 2196772 pjrt_executable.cc:642] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
PASSED
tests/unit/gemma4_layers_test.py::TestGemma4VisionAttention::test_attention_matches_torch W0514 05:57:23.553000 2196772 pjrt_executable.cc:642] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
PASSED
tests/unit/gemma4_layers_test.py::TestGemma4VisionEncoderBlock::test_encoder_block_matches_torch W0514 05:57:29.738181 2196772 pjrt_executable.cc:642] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
PASSED
tests/unit/gemma4_layers_test.py::TestGemma4VisionExit::test_vision_exit_matches_torch W0514 05:57:31.237544 2196772 pjrt_executable.cc:642] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
PASSED
tests/unit/gemma4_layers_test.py::TestGemma4VisionEncoderEndToEnd::test_vision_encoder_matches_torch W0514 05:57:44.092117 2196772 pjrt_executable.cc:642] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
PASSED
tests/unit/gemma4_layers_test.py::TestGemma4VisionProjector::test_vision_projector_matches_torch W0514 05:57:45.259719 2196772 pjrt_executable.cc:642] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
PASSED

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented May 14, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


@github-actions

🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

The Pull Request introduces comprehensive layer-wise unit tests for the Gemma 4 vision components, comparing the MaxText (JAX) implementation against the PyTorch reference from the transformers library. The tests cover all key layers, including VisionEntry, RotaryEmbedding, Attention, EncoderBlock, and the full VisionEncoderLayer, ensuring numerical parity across frameworks.

🔍 General Feedback

  • Completeness: The coverage is excellent, including both individual components and end-to-end encoder tests.
  • Numerical Parity: Most individual layer tests use a tight tolerance (1e-3), which is a strong indicator of implementation correctness.
  • Standardized Helpers: The use of shared multimodal test utilities and clear weight-copying functions makes the tests easy to follow and maintain.
  • CI Status: Note that the new test file is added to pytest.ini's ignore list, which is consistent with the stated TODO to transform these into PyTorch-free tests for CI compatibility.

float32_logits=True,
float32_qk_product=True,
)



🟢 This variable is defined but not used in the pyconfig.initialize call below. Consider removing it if it's not needed.

Suggested change
base_config_path = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "base.yml")
jax_config = pyconfig.initialize(

jax_inputs, torch_inputs = create_random_jax_torch(batch_size, seq_len, self.config.hidden_size_for_vit)

torch_output = torch_model(torch_inputs)
jax_output = jax_model(jax_inputs)


🟡 The relative tolerance rtol=5e-2 (5%) is quite high for a unit test comparing float32 implementations. While the cumulative error in a 27-layer model and the scaling by sqrt(d_model) in VisionExit might justify some variance, it would be ideal to see if this can be tightened (e.g., to 1e-3 or 1e-2) to ensure higher precision matches.

If this tolerance is the tightest possible due to framework differences, consider adding a brief comment explaining the reason.

Suggested change
jax_output = jax_model(jax_inputs)
assert_all_close_jax_torch(
jax_output_squeezed,
torch_lhs,
rtol=5e-2,
atol=5e-2,
error_msg="Gemma4VisionEncoderLayer end-to-end outputs differ",
)
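To illustrate why a deep float32 stack can drift past a 1e-3 tolerance even when each layer matches tightly, here is a toy numpy sketch (hypothetical, not this PR's model) of error accumulation: two trajectories apply mathematically identical linear layers with different reduction orderings, standing in for JAX-versus-PyTorch kernel differences, over 27 layers.

```python
import numpy as np

rng = np.random.default_rng(42)
d = 64
x_a = rng.standard_normal((1, d)).astype(np.float32)
x_b = x_a.copy()  # identical starting point for both trajectories

for _ in range(27):  # same depth as the vision encoder discussed above
    w = (rng.standard_normal((d, d)) / np.sqrt(d)).astype(np.float32)
    # Mathematically equal contractions with different reduction orders,
    # mimicking framework-level kernel differences in float32.
    x_a = np.tanh(x_a @ w)
    x_b = np.tanh((x_b[:, :, None] * w[None, :, :]).sum(axis=1))

drift = float(np.max(np.abs(x_a - x_b)))
print(f"max absolute drift after 27 layers: {drift:.2e}")
```

Per-layer rounding differences are tiny, but they compound across depth, which is the kind of effect that can motivate a looser end-to-end tolerance than the per-layer one; a comment in the test file documenting the tightest tolerance that passes would capture that.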

Collaborator

@aireenmei aireenmei left a comment


Thanks for adding tests!

# =============================================================================


def copy_rmsnorm_weights(torch_norm, jax_norm):
Collaborator


Seems not very specific to gemma4? Should we move to tests.utils.multimodal_test_utils for re-use?
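A minimal numpy-based sketch of the kind of generic helper the reviewer suggests, assuming the torch module exposes a .weight tensor and the JAX side stores its scale in a params dict (both layouts are hypothetical here):

```python
import numpy as np

class TorchNormStub:
    """Stand-in for a torch RMSNorm module exposing a .weight tensor."""
    def __init__(self, dim):
        self.weight = np.ones(dim, dtype=np.float32)

def copy_rmsnorm_weights(torch_norm, jax_params):
    """Copy the torch scale into a JAX-style params dict (hypothetical layout).

    Nothing here depends on Gemma4, so a helper like this could live in a
    shared module such as tests.utils.multimodal_test_utils.
    """
    jax_params = dict(jax_params)  # avoid mutating the caller's params
    jax_params["scale"] = np.asarray(torch_norm.weight, dtype=np.float32)
    return jax_params

torch_norm = TorchNormStub(8)
torch_norm.weight = np.arange(8, dtype=np.float32)
params = copy_rmsnorm_weights(torch_norm, {"scale": np.zeros(8, np.float32)})
print(params["scale"])
```

Because the helper only maps a single weight tensor between naming conventions, relocating it to the shared test utilities should be a mechanical change.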

seq_len = 42 * 60
dummy_shape = (batch_size, seq_len, self.config.hidden_size_for_vit)

jax_model = JaxGemma4Attention(
Collaborator


How about JaxGemma4VisionAttention?
