Add vmap reset() regression test for nnx.metrics (#5483)#5491
Open
chenkuanliao wants to merge 1 commit into
Open
Add vmap reset() regression test for nnx.metrics (#5483)#5491chenkuanliao wants to merge 1 commit into
chenkuanliao wants to merge 1 commit into
Conversation
db1487c to
a71a8f3
Compare
Construct a MultiMetric(loss=Average) under nnx.vmap so its state carries a leading batch axis, call reset(), then run a vmapped update and assert compute() returns the expected per-batch means. Locks in the shape-preserving reset() behavior tracked in google#5483.
a71a8f3 to
ca14973
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
This adds a regression test for the
nnx.metricsbehavior reported in #5483.nnx.metrics.Average.reset()(andWelford.reset()) re-zero their state by assigning a scalar (jnp.array(0, ...)). When a metric is built undernnx.vmap, its state carries a leading batch axis of shape(N,). On the reporter's versions (flax 0.12.0 / jax 0.7.2) the scalar assignment replaced the whole array, collapsing(N,)to(), so a later vmappedupdatefailed with:While investigating I found the crash no longer reproduces on
main: theVariablestate layer has since changed so a full-index assignment broadcasts a scalar into the existing array in place instead of replacing it, soreset()preserves the batch axis today. I confirmed it still reproduces on flax 0.12.0 / jax 0.7.2 and is gone onmainindependent of any single commit.Since there is no longer a live crash to fix on
main, this PR is scoped as a regression test only, to lock in the current vmapreset()behavior so it can't silently break again.What's in this PR
tests/nnx/metrics_test.py—test_vmap_reset_preserves_shape: construct aMultiMetric(loss=Average('loss'))undernnx.vmap(state shape(N,)), callreset(), assert the state stays shape(N,), then run a vmappedupdateand assertcompute()['loss']returns the expected per-batch values.Checklist
metricsdoes not work well withvmap#5483