Fix enable_validation #2201
@fehiepsi: shall we keep

We decided to enable validation by default. I think setting the attribute to the global one might fix the issue.
juanitorduz left a comment
Minor point, right @Qazalbash ?
from . import constraints

_VALIDATION_ENABLED = True
_VALIDATION_ENABLED = False
Based on the comment above, we should keep _VALIDATION_ENABLED = True
This is fixed. Note that this changes the behavior by default and might be a breaking change for some codebases. However, I find that it is much better to have validation enabled by default, as it was quite unintuitive in the first place.
Qazalbash left a comment
@juanitorduz Yes, the global variable value should be reverted.
Qazalbash left a comment
Thanks @renecotyfanboy

I think CI will trigger after @juanitorduz approves. Let's merge this PR only when the latest CI passes.

@renecotyfanboy thanks! Can you push an empty commit (e.g., git commit -m "empty" --allow-empty) to try to trigger the CI? 🙏

Enabling validation by default broke some tests in the CI. Should I simply wrap them around

Let's leave things as they were. What do you say @juanitorduz?
This reverts commit c66e2fb.

Could we fix the tests instead? This indicates that we might not use correct params for those tests.

Indeed, that is a better approach! I could help fix these tests if you feel stuck @renecotyfanboy :)

I started fixing a bunch of tests here.
However, it is taking me a very long time to go through each case and each distribution and work out what is ill-defined, whether in the parameter values used or in a batch / event shape mismatch. If I am to continue fixing the tests, I will probably switch to LLMs to iterate faster.

Thank you @renecotyfanboy! I think Claude (or similar) can be super useful here! If you need some help please let me know :)

I just tried to run copilot+claude for this and it spent half my monthly credit just to figure out that it should disable validation at runtime while explicitly prompted not to do so. LMAO. This task is too hard for me, and I won't have time to work on this in the coming months. Maybe we should disable validation by default (as is currently the case), and set as a long-term goal to fix the tests using proper and valid test inputs. WDYT?

Thanks for the effort @renecotyfanboy! Not an easy one ;) Let me try to see if I can make it work (please do not close this PR)
With validation defaulting to True, fix the remaining failures using valid parameters where possible, and skip validation only where a test or library path intentionally uses invalid / off-manifold values:

- test_distribution_constraints: build out-of-bounds params under validation_enabled(False) (the kwarg does not reach compound dists' internals).
- test_log_prob_gradient / test_sample_gradient: scope validation off around the finite-difference reference, which perturbs params off their constraint manifold.
- AutoLaplaceApproximation.get_posterior: build its MVN with validate_args=False (singular-Hessian fallback intentionally zeroes scale_tril).
- EulerMaruyama / GaussianRandomWalk log_prob: internal Normals built from the input value use validate_args=False; the public log_prob still validates value.
- HSGP centered approximation: Normal(0, spd) with validate_args=False since spd underflows to 0 for high-frequency basis functions.
- compat.param: project a concrete constraint-violating init onto its constraint (matches Pyro semantics), fixing the unnormalized-simplex pyroapi case.
- enum_elbo: fix unnormalized guide_probs_a; drop validation_enabled(False) wraps.
- HSGP tests: drop validation_enabled(False) wraps.
- test_subposterior_structure: use non-degenerate subposteriors.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
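To illustrate why a finite-difference reference needs validation scoped off: nudging one entry of a probability vector by ±eps moves it off the simplex, so a validating distribution would reject the perturbed input even though the difference quotient is still meaningful. The following is a minimal pure-Python sketch; the helper names are illustrative assumptions and are not NumPyro's test utilities.

```python
import math


def is_simplex(probs, tol=1e-9):
    """True if all entries are non-negative and sum to one (within tol)."""
    return all(p >= 0.0 for p in probs) and abs(sum(probs) - 1.0) < tol


def categorical_log_prob(probs, index):
    """Log-probability of `index`; deliberately performs no validation."""
    return math.log(probs[index])


def central_difference(fn, params, i, eps=1e-4):
    """Central finite difference of fn with respect to params[i]."""
    up, down = list(params), list(params)
    up[i] += eps
    down[i] -= eps
    return (fn(up) - fn(down)) / (2.0 * eps), up, down


probs = [0.2, 0.3, 0.5]
grad, up, down = central_difference(lambda p: categorical_log_prob(p, 1), probs, i=1)

print(is_simplex(probs))                 # the unperturbed params are valid
print(is_simplex(up), is_simplex(down))  # both perturbed vectors leave the simplex
print(grad)                              # close to the analytic value 1 / 0.3
```

The perturbed vectors only exist inside the reference computation, which is why turning validation off around that computation alone does not weaken the test's inputs.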
pmap-NUTS landed just over rtol=0.06 on the mean (0.937 vs 1.0) because pmap and vmap accumulate floats differently on a borderline estimate. Bump num_samples to 15000 (keeping num_warmup=2000) to shrink the Monte Carlo error; all map_fn/algo combinations now pass with a comfortable margin (worst relative error ~0.023). Note: higher warmup destabilizes the fixed-trajectory HMC chains, so only the sample count is increased. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
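As a rough guide to why a larger sample count helps here: for independent draws, the standard error of a sample mean scales as sigma / sqrt(n). MCMC draws are autocorrelated, so the effective n is smaller than the nominal one. The sketch below is illustrative only; sigma = 1.0 and the sample counts other than 15000 are assumptions, not values taken from the test.

```python
import math


def standard_error(sigma, num_samples):
    """Standard error of the mean for independent draws with std `sigma`."""
    return sigma / math.sqrt(num_samples)


for n in (1_000, 4_000, 15_000):
    # Quadrupling the sample count halves the standard error.
    print(n, standard_error(1.0, n))
```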

OK @renecotyfanboy, it seems it's working now :) FYI: @Qazalbash @fehiepsi
    and not np.all(constraint(val))
):
    transform = biject_to(constraint)
    val = transform(transform.inv(val))
Why is this needed? I don't quite understand the comment above.
This is something I found while talking with Claude and doing some tests:
Pyro treats a constrained param's init value as a constrained value and, if it does not satisfy the constraint, projects it through the constraint's bijector. NumPyro's param instead returns the raw init at trace time (the constraint is only used later by SVI to build the unconstrained optimizer). So during SVI.init's first guide trace the raw init reaches the model/guide; e.g. pyroapi's test_constraints inits a simplex param with an unnormalized exp(randn(3)), which now trips validation-by-default in dist.Categorical.
Project it here (only when concrete and not already valid) to match Pyro.
It could be I am missing something 🙈 , but it seems this helps make the test pass while having the validation.
It is unclear why Pyro is related here. Does this imply the test is not robust? I think transform(transform.inv(val)) is the same as val for a bijector. Do we want to address non-bijective transforms?
> It is unclear why Pyro is related here.

I just compared it with what Pyro was doing XD to see if there was an issue here.

> Does this imply the test is not robust? I think transform(transform.inv(val)) is the same as val for a bijector. Do we want to address non-bijective transforms?

Let me investigate... 🤔
Ok! This is how far I got:
- On transform(transform.inv(val)) == val: You're right, for a true bijector it's the identity. I confirmed it's a no-op for positive, real, unit_interval, and greater_than. It only changes val for non-bijective, dimension-changing transforms like simplex (StickBreaking maps R^(n-1) ↔ the simplex).
- Remark: NumPyro represents a constrained param by its unconstrained value, so the constrained value SVI actually optimizes from is transform(transform.inv(init)). I verified this directly: for a simplex param initialized with exp(randn(3)), the value SVI uses is [9.99e-1, 1.19e-7, 1.42e-14]: exactly the round-trip (and notably not val/sum). NumPyro's param primitive returns the raw init at trace time, so during SVI.init's first guide trace, Categorical receives the unnormalized vector and trips validation-by-default. The projection just makes the traced value consistent with what SVI uses one step later. It's the same thing Pyro's param store does.

Does it imply the test isn't robust? Somewhat: test_constraints is a smoke test that constrained params don't crash SVI; it never checks the actual values, so the near-one-hot projection "passes" trivially. It's the only test that needs this (enum_elbo's probs are all valid, so its pyro.param calls never enter the projection branch; the not np.all(constraint(val)) guard skips them, which also protects valid boundary values like [[0,1],[1,0]] from being perturbed).
We then have two options:
- Keep the change (it's consistent with SVI's own init, just needs a clearer comment)
- xfail the single vendored test_constraints if you'd prefer no compat-layer behavior change.
Both keep CI green.
WDYT?
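The round-trip behaviour discussed in this comment can be sketched in pure Python. This is a simplified stick-breaking map written for illustration, not NumPyro's StickBreakingTransform: it leaves out any offset or dtype handling a real implementation may apply, TINY is an assumed stand-in for the dtype's smallest positive float, and the inputs are assumed to be strictly positive.

```python
import math

TINY = 1e-30  # assumption: stand-in for the smallest positive float of the dtype


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def to_unconstrained(y):
    """Map a length-n vector with positive entries to n-1 reals."""
    xs, used = [], 0.0
    for y_i in y[:-1]:
        used += y_i
        # Clip the remaining stick length so the log stays defined off the simplex.
        remainder = max(1.0 - used, TINY)
        xs.append(math.log(y_i / remainder))
    return xs


def to_simplex(xs):
    """Map n-1 reals to a length-n point on the simplex (stick-breaking)."""
    ys, remaining = [], 1.0
    for x in xs:
        z = sigmoid(x)
        ys.append(z * remaining)
        remaining *= 1.0 - z
    ys.append(remaining)
    return ys


def round_trip(y):
    return to_simplex(to_unconstrained(y))


valid = [0.4, 0.3, 0.3]
unnormalized = [2.0, 0.5, 1.0]

print(round_trip(valid))         # reproduces the input up to rounding
print(round_trip(unnormalized))  # a near-one-hot simplex point, not unnormalized / sum
```

In this sketch the round trip is the identity on valid simplex points and only changes inputs that are already off the simplex, which is the distinction the guard in the diff relies on.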
I see. Seems like a test issue where the value is not valid. Could we fix the test instead? Just ensure we provide valid params there. If it is too difficult, just apply this double transform there. I still don't understand why we can't fix this in the test
ok! Let me try this out (I first needed to understand this edge case)
We can change the simplex param to something other than exp(randn(3)).
Replace the validate_args=False fallbacks with root-cause fixes:

- HSGP _centered_approximation: clip spd (sqrt spectral density) to a tiny positive floor so the scale stays valid when it underflows to 0.
- EulerMaruyama / GaussianRandomWalk log_prob: use Normal's location invariance, evaluating the residual / increments under a zero-mean Normal so the loc stays valid for out-of-support values. The public log_prob's @validate_sample still warns about out-of-support values.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
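The underflow and the floor can be sketched with the 1-D squared-exponential spectral density, S(ω) = α² · sqrt(2π) · ℓ · exp(-ℓ²ω²/2). This is a pure-Python illustration under stated assumptions: the function names are hypothetical, and the floor value (the smallest positive normal float) is a choice made for this sketch rather than the one used in the commit.

```python
import math
import sys

FLOOR = sys.float_info.min  # assumption: smallest positive normal float as the floor


def sqrt_spectral_density(omega, amplitude=1.0, length_scale=1.0):
    """Square root of the 1-D squared-exponential spectral density at omega."""
    density = (
        amplitude**2
        * math.sqrt(2.0 * math.pi)
        * length_scale
        * math.exp(-0.5 * (length_scale * omega) ** 2)
    )
    return math.sqrt(density)


def clipped_scale(omega, **kwargs):
    """Same quantity, clipped from below so it can serve as a positive scale."""
    return max(sqrt_spectral_density(omega, **kwargs), FLOOR)


raw = sqrt_spectral_density(40.0)   # exp(-800) underflows to exactly 0.0
safe = clipped_scale(40.0)

print(raw, safe)
print(clipped_scale(1.0) == sqrt_spectral_density(1.0))  # low frequencies are untouched
```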
# zero-mean Normal. This keeps the loc valid even for out-of-support
# ``value`` (where ``mu`` would be NaN); the public log_prob's
# @validate_sample still warns about out-of-support values.
sde_log_prob = Normal(0.0, sigma).to_event(self.event_dim).log_prob(xt - mu)
Nice solution to address NaN mu.
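The location invariance used in the line above is the identity log N(x; μ, σ) = log N(x − μ; 0, σ). A small pure-Python check, with a hand-written log-density rather than NumPyro's Normal (the function name and the numbers are illustrative assumptions):

```python
import math


def normal_log_prob(value, loc, scale):
    """Log-density of a Normal(loc, scale) at value; no parameter validation."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


x, mu, sigma = 1.3, 0.7, 2.0
direct = normal_log_prob(x, mu, sigma)
shifted = normal_log_prob(x - mu, 0.0, sigma)
print(direct, shifted)  # equal up to rounding

# If mu is NaN, the shifted form carries the NaN in the value argument,
# while the loc parameter stays at the valid constant 0.0.
via_value = normal_log_prob(x - math.nan, 0.0, sigma)
print(via_value)
```

Because the NaN now arrives through the value rather than through loc, a check on the distribution's parameters has nothing invalid to flag, and the result is still NaN for the out-of-support input.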
Revert the numpyro/compat/pyro.py constraint projection and instead override the vendored pyroapi test_constraints with a valid simplex init for the guide param `q` (was an unnormalized exp(randn(3))), so it passes with validation enabled by default. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
# pyroapi's test_constraints inits the simplex param `q` with an unnormalized
# exp(randn(3)); use a valid simplex so it passes with validation enabled.
exp(randn(3)) for q lives in the vendored upstream pyroapi package (pyroapi/tests/test_svi.py), see https://github.com/pyro-ppl/pyro-api/blob/master/pyroapi/tests/test_svi.py#L168
To provide a valid simplex (as suggested), we therefore redefine test_constraints in our own test/pyroapi/test_pyroapi.py, identical to upstream except that q is initialized to a valid simplex [0.4, 0.3, 0.3] instead of the unnormalized exp(randn(3)). Since the test name is imported via from pyroapi.tests import *, redefining it shadows the upstream one; it does not add a test.
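The shadowing behaviour can be checked in isolation. The sketch below builds a hypothetical in-memory module named upstream_tests to stand in for the upstream package, then executes a stand-in for the local test file that star-imports it and redefines one name; none of these names come from pyroapi.

```python
import sys
import types

# Hypothetical stand-in for an upstream test module.
upstream = types.ModuleType("upstream_tests")
exec(
    "def test_constraints():\n"
    "    return 'upstream'\n"
    "def test_other():\n"
    "    return 'upstream other'\n",
    upstream.__dict__,
)
sys.modules["upstream_tests"] = upstream

# Stand-in for the local test file: star-import, then redefine one name.
local_source = (
    "from upstream_tests import *\n"
    "def test_constraints():\n"
    "    return 'local override'\n"
)
local_namespace = {}
exec(local_source, local_namespace)

# Only one test_constraints exists in the local namespace after the redefinition.
collected = sorted(name for name in local_namespace if name.startswith("test_"))
print(collected)
print(local_namespace["test_constraints"]())
```

A test collector that looks up names in the local module therefore sees the same number of tests as before, with the local definition in place of the imported one.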

@fehiepsi is this one good to go? :)
Hi there,
Here is a super small fix for validation. Currently, the validation context is handled using the _VALIDATION_ENABLED global variable. By default, _VALIDATION_ENABLED=True and _validate_args=False, which triggers weird behaviors when enabling / disabling the validation through contexts. But since everything is handled through mutating _VALIDATION_ENABLED, it sometimes produces weird behaviors.

The simplest repro is:
The current fix produces the expected behavior.
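For readers unfamiliar with the pattern, here is a toy sketch of a global validation flag, a context manager that temporarily changes it, and an object that resolves its own flag from the global one at construction time. It is not the repro mentioned above and not NumPyro's implementation; ToyScale and is_rejected are hypothetical names, and the flag and context-manager names are reused from the description only for readability.

```python
from contextlib import contextmanager

_VALIDATION_ENABLED = True  # toy module-level default


@contextmanager
def validation_enabled(is_validate=True):
    """Temporarily set the global flag and always restore the previous value."""
    global _VALIDATION_ENABLED
    previous = _VALIDATION_ENABLED
    _VALIDATION_ENABLED = is_validate
    try:
        yield
    finally:
        _VALIDATION_ENABLED = previous


class ToyScale:
    """Hypothetical distribution-like object with one positive parameter."""

    def __init__(self, scale, validate_args=None):
        # An explicit keyword wins; otherwise read the global flag at construction.
        if validate_args is None:
            validate_args = _VALIDATION_ENABLED
        self._validate_args = validate_args
        if self._validate_args and not scale > 0:
            raise ValueError(f"scale must be positive, got {scale}")
        self.scale = scale


def is_rejected(scale, **kwargs):
    """True if constructing ToyScale raises a validation error."""
    try:
        ToyScale(scale, **kwargs)
    except ValueError:
        return True
    return False


print(is_rejected(-1.0))          # validated under the default flag
with validation_enabled(False):
    print(is_rejected(-1.0))      # accepted while the context is active
print(is_rejected(-1.0))          # flag restored after the context exits
```

Resolving the per-object flag from the global one in a single place keeps the two from disagreeing, which is the kind of inconsistency the description attributes to the previous defaults.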