Fix enable_validation #2201

Merged
fehiepsi merged 15 commits into
pyro-ppl:master from
renecotyfanboy:master
Jun 8, 2026

Conversation

@renecotyfanboy

Contributor

Hi there,
Here is a very small fix for validation. Currently, the validation context is handled through the _VALIDATION_ENABLED global variable. By default _VALIDATION_ENABLED=True while _validate_args=False, and because everything is handled by mutating _VALIDATION_ENABLED, this mismatch triggers weird behavior when enabling / disabling validation through contexts.

The simplest repro is:

import numpyro
import numpyro.distributions as dist

print(dist.Uniform(0, 1).log_prob(-0.5)) # -0.0 (expected)

with numpyro.validation_enabled(True):
    print(dist.Uniform(0, 1).log_prob(-0.5)) # -inf (expected)

print(dist.Uniform(0, 1).log_prob(-0.5)) # -inf (expect -0.0)

The current fix produces the expected behavior.

@juanitorduz

juanitorduz commented May 26, 2026

Collaborator

@fehiepsi : shall we keep _VALIDATION_ENABLED and _validate_args equal to True?

@fehiepsi

Member

We decided to enable validation by default. I think setting the attribute to the global one might fix the issue.

@juanitorduz juanitorduz left a comment

Collaborator

Minor point, right @Qazalbash ?

Comment thread numpyro/distributions/distribution.py Outdated
from . import constraints

- _VALIDATION_ENABLED = True
+ _VALIDATION_ENABLED = False

Collaborator

Based on the comment above, we should keep _VALIDATION_ENABLED = True

Contributor Author

This is fixed. Note that this changes the behavior by default and might be a breaking change for some codebases. However, I find that it is much better to have validation enabled by default, as it was quite unintuitive in the first place.

@Qazalbash Qazalbash left a comment

Collaborator

@juanitorduz Yes, the global variable value should be reverted.

@Qazalbash Qazalbash left a comment

Collaborator

@Qazalbash

Collaborator

I think CI will trigger after @juanitorduz approves. Let's merge this PR only when the latest CI passes.

@juanitorduz

Collaborator

@renecotyfanboy thanks! Can you push an empty commit (e.g., git commit -m"empty" --allow-empty ) to try to trigger the CI? 🙏

@renecotyfanboy

Contributor Author

Enabling validation by default broke some tests in the CI. Should I simply wrap them with numpyro.validation_enabled(False)? I am not sure I am willing to fix them properly 😆

@Qazalbash

Collaborator

Let's leave things as they were. What do you say @juanitorduz?

@fehiepsi

Member

Could we fix the tests instead? This indicates that we might not use correct params for those tests.

@juanitorduz

juanitorduz commented May 29, 2026

Collaborator

Indeed, that is a better approach! I could help fix these tests if you feel stuck @renecotyfanboy :)

@renecotyfanboy

Contributor Author

I started fixing a bunch of tests here.

  • Disabled validation for the categorical variables since most of them are not normalized in the tests
  • Disabled validation for the Gaussian process contributed tests
  • Used better MCMC params to ensure convergence for the functional map tests
  • Used better random params for the negative binomial log prob tests

However, it is taking me a very long time to go through each case and each distribution to find what is ill-defined, whether in the parameter values used or in a batch / event shape mismatch. If I am to continue fixing the tests, I will probably switch to LLMs to iterate faster.

@juanitorduz

Collaborator

Thank you @renecotyfanboy ! I think Claude (or similar) can be super useful here! If you need some help please let me know :)

@renecotyfanboy

renecotyfanboy commented Jun 4, 2026

Contributor Author

I just tried to run copilot+claude for this and it spent half my monthly credit just to figure out that it should disable validation at runtime while explicitly prompted not to do so. LMAO. This task is too hard for me, and I won't have time to work on this in the next months. Maybe we should disable validation by default (as it is currently the case), and set as a long term goal to fix the tests using proper and valid test inputs. WDYT?

@juanitorduz

Collaborator

> I just tried to run copilot+claude for this and it spent half my monthly credit just to figure out that it should disable validation at runtime while explicitly prompted not to do so. LMAO. This task is too hard for me, and I won't have time to work on this in the next months. Maybe we should disable validation by default (as it is currently the case), and set as a long term goal to fix the tests using proper and valid test inputs. WDYT?

Thanks for the effort @renecotyfanboy ! Not an easy one ;)

Let me try to see if I can make it work (please do not close this PR)

With validation defaulting to True, fix the remaining failures using valid
parameters where possible, and skip validation only where a test or library
path intentionally uses invalid / off-manifold values:

- test_distribution_constraints: build out-of-bounds params under
  validation_enabled(False) (the kwarg does not reach compound dists' internals).
- test_log_prob_gradient / test_sample_gradient: scope validation off around the
  finite-difference reference, which perturbs params off their constraint manifold.
- AutoLaplaceApproximation.get_posterior: build its MVN with validate_args=False
  (singular-Hessian fallback intentionally zeroes scale_tril).
- EulerMaruyama / GaussianRandomWalk log_prob: internal Normals built from the
  input value use validate_args=False; the public log_prob still validates value.
- HSGP centered approximation: Normal(0, spd) with validate_args=False since spd
  underflows to 0 for high-frequency basis functions.
- compat.param: project a concrete constraint-violating init onto its constraint
  (matches Pyro semantics), fixing the unnormalized-simplex pyroapi case.
- enum_elbo: fix unnormalized guide_probs_a; drop validation_enabled(False) wraps.
- HSGP tests: drop validation_enabled(False) wraps.
- test_subposterior_structure: use non-degenerate subposteriors.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
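
As an illustration of why a finite-difference reference can trip validation, the sketch below perturbs one entry of a probability vector, which moves it off the simplex. All names here (is_simplex, categorical_log_prob, finite_difference) are hypothetical helpers written for this example in plain Python; they are not the test suite's code.

```python
import math


def is_simplex(probs, tol=1e-6):
    # A vector is on the simplex if it is non-negative and sums to one.
    return all(p >= 0 for p in probs) and abs(sum(probs) - 1.0) < tol


def categorical_log_prob(probs, k, validate=True):
    # Toy log-probability with an optional parameter check.
    if validate and not is_simplex(probs):
        raise ValueError("probs must lie on the simplex")
    return math.log(probs[k])


def finite_difference(f, probs, i, eps=1e-3):
    # Central difference in coordinate i; the perturbed vectors sum to 1 +/- eps.
    hi = list(probs)
    lo = list(probs)
    hi[i] += eps
    lo[i] -= eps
    return (f(hi) - f(lo)) / (2 * eps)


probs = [0.2, 0.3, 0.5]

# With validation on, the perturbed inputs are rejected.
try:
    finite_difference(lambda p: categorical_log_prob(p, 1), probs, 1)
    rejected = False
except ValueError:
    rejected = True

# With validation scoped off, the reference gradient is computed as intended.
reference = finite_difference(
    lambda p: categorical_log_prob(p, 1, validate=False), probs, 1
)
print(rejected, reference)
```

The design point is that the perturbation is intentional, so turning validation off only around the reference computation keeps the check active everywhere else.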
@juanitorduz juanitorduz added the enhancement New feature or request label Jun 5, 2026
@juanitorduz juanitorduz self-assigned this Jun 5, 2026
@juanitorduz juanitorduz self-requested a review June 5, 2026 09:38
pmap-NUTS landed just over rtol=0.06 on the mean (0.937 vs 1.0) because pmap
and vmap accumulate floats differently on a borderline estimate. Bump
num_samples to 15000 (keeping num_warmup=2000) to shrink the Monte Carlo error;
all map_fn/algo combinations now pass with a comfortable margin (worst relative
error ~0.023). Note: higher warmup destabilizes the fixed-trajectory HMC chains,
so only the sample count is increased.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
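
A rough way to reason about the sample-count bump: for independent draws, the standard error of a sample mean scales as 1/sqrt(N). The sketch below uses hypothetical numbers only to show that scaling; the baseline sample count is an assumption, and real MCMC draws are autocorrelated, so the effective sample size would be the relevant N.

```python
import math


def monte_carlo_standard_error(sample_std, num_effective_samples):
    """Standard error of a sample mean from independent draws."""
    return sample_std / math.sqrt(num_effective_samples)


# Hypothetical numbers chosen only to show the scaling.
baseline = monte_carlo_standard_error(1.0, 5_000)
larger = monte_carlo_standard_error(1.0, 15_000)
shrink_factor = larger / baseline

print(f"standard error shrinks by a factor of {shrink_factor:.3f}")
```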
@juanitorduz

Collaborator

OK @renecotyfanboy, it seems it's working now :)

FYI: @Qazalbash @fehiepsi

@juanitorduz juanitorduz requested review from Qazalbash and fehiepsi June 5, 2026 11:37
Comment thread numpyro/compat/pyro.py Outdated
    and not np.all(constraint(val))
):
    transform = biject_to(constraint)
    val = transform(transform.inv(val))

Member

Why is this needed? I don't quite understand the comment above.

@juanitorduz juanitorduz Jun 5, 2026

Collaborator

This is something I found while talking with Claude and doing some tests:

Pyro treats a constrained param's init value as a constrained value and, if it does not satisfy the constraint, projects it through the constraint's bijector. NumPyro's param instead returns the raw init at trace time (the constraint is only used later by SVI to build the unconstrained optimizer). So during SVI.init's first guide trace the raw init reaches the model/guide; e.g. pyroapi's test_constraints inits a simplex param with an unnormalized exp(randn(3)), which now trips validation-by-default in dist.Categorical.
Project it here (only when concrete and not already valid) to match Pyro.

It could be I am missing something 🙈 , but it seems this helps make the test pass while having the validation.

Member

It is unclear why Pyro is related here. Does this imply the test is not robust? I think transform(transform.inv(val)) is the same as val for a bijector. Do we want to address non-bijective transforms?

@juanitorduz juanitorduz Jun 5, 2026

Collaborator

> It is unclear why pyro is related here.

I just compared it with what Pyro was doing XD to see if there was an issue here.

> Does this imply the test is not robust? I think transform(transform.inv(val)) is the same as val for bijector. Do we want to address non-bijective transforms?

Let me investigate . . . 🤔

Collaborator

Ok! This is how far I got:

  • On transform(transform.inv(val)) == val: You're right, for a true bijector it's the identity. I confirmed it's a no-op for positive, real, unit_interval, and greater_than. It only changes val for non-bijective, dimension-changing transforms like simplex (StickBreaking maps R^(n-1) ↔ the simplex).

  • Remark: NumPyro represents a constrained param by its unconstrained value, so the constrained value SVI actually optimizes from is transform(transform.inv(init)). I verified this directly: for a simplex param initialized with exp(randn(3)), the value SVI uses is [9.99e-1, 1.19e-7, 1.42e-14]: exactly the round-trip (and notably not val/sum). NumPyro's param primitive returns the raw init at trace time, so during SVI.init's first guide trace, Categorical receives the unnormalized vector and trips validation-by-default. The projection just makes the traced value consistent with what SVI uses one step later. It's the same thing Pyro's param store does.

Does it imply the test isn't robust? Somewhat: test_constraints is a smoke test that constrained params don't crash SVI; it never checks the actual values, so the near-one-hot projection "passes" trivially. It's the only test that needs this (enum_elbo's probs are all valid, so its pyro.param calls never enter the projection branch: the not np.all(constraint(val)) guard skips them, which also protects valid boundary values like [[0,1],[1,0]] from being perturbed).

We then have two options:

  • Keep the change (it's consistent with SVI's own init, just needs a clearer comment)
  • xfail the single vendored test_constraints if you'd prefer no compat-layer behavior change.

Both keep CI green.

WDYT?
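
The round-trip behavior discussed in this thread can be sketched in plain Python. The code below is loosely modeled on a stick-breaking parameterization; the offsets, the clipping floor TINY, and the example vectors are assumptions made for illustration, and this is not NumPyro's implementation. It shows that the round trip leaves a valid simplex point essentially unchanged, while an unnormalized vector is projected to a near-one-hot point rather than to val / sum(val).

```python
import math

TINY = 1e-30  # illustrative floor for the remaining stick length (an assumption)


def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def stick_breaking(x):
    """Map n-1 unconstrained reals to a point on the n-simplex."""
    n = len(x)
    y, remainder = [], 1.0
    for i, xi in enumerate(x):
        z = sigmoid(xi - math.log(n - i))  # fraction of the remaining stick
        y.append(z * remainder)
        remainder *= 1.0 - z
    return y + [remainder]


def stick_breaking_inv(y):
    """Inverse map; reads only y[:-1] and clips the remaining stick."""
    n = len(y) - 1
    x, used = [], 0.0
    for i in range(n):
        used += y[i]
        remainder = max(1.0 - used, TINY)
        x.append(math.log(y[i] / remainder) + math.log(n - i))
    return x


def round_trip(val):
    return stick_breaking(stick_breaking_inv(val))


valid = [0.4, 0.3, 0.3]
unnormalized = [2.0, 0.5, 1.5]  # hypothetical stand-in for exp(randn(3))

print(round_trip(valid))
print(round_trip(unnormalized))
print([v / sum(unnormalized) for v in unnormalized])
```

Because the inverse ignores the last coordinate and clips the remainder, the first entry that exceeds the available stick absorbs nearly all the mass, which is consistent with the near-one-hot value reported above.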

Member

I see. This seems like a test issue where the value is not valid. Could we fix the test instead? Just ensure we provide valid params there. If it is too difficult, just apply this double transform there. I still don't understand why we can't fix this in the test.

Collaborator

ok! Let me try this out (I first needed to understand this edge case)

Member

We can change simplex param to something other than exp(randn(3))

Collaborator

done in e03f20e

Comment thread numpyro/contrib/hsgp/approximation.py Outdated
Comment thread numpyro/distributions/continuous.py Outdated
Comment thread numpyro/distributions/continuous.py Outdated
Replace the validate_args=False fallbacks with root-cause fixes:

- HSGP _centered_approximation: clip spd (sqrt spectral density) to a tiny
  positive floor so the scale stays valid when it underflows to 0.
- EulerMaruyama / GaussianRandomWalk log_prob: use Normal's location
  invariance, evaluating the residual / increments under a zero-mean Normal so
  the loc stays valid for out-of-support values. The public log_prob's
  @validate_sample still warns about out-of-support values.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
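
The location-invariance identity used in this commit message can be checked with a small plain-Python sketch. The helper normal_log_prob and the numeric values are hypothetical and chosen only for illustration; the point is that evaluating the residual under a zero-mean Normal gives the same log density as evaluating the value under a Normal centered at mu.

```python
import math


def normal_log_prob(value, loc, scale):
    # Log density of a univariate Normal(loc, scale) at value.
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


xt, mu, sigma = 1.3, 0.9, 0.5  # hypothetical values

direct = normal_log_prob(xt, mu, sigma)          # Normal(mu, sigma) at xt
shifted = normal_log_prob(xt - mu, 0.0, sigma)   # Normal(0, sigma) at xt - mu

print(direct, shifted)
```

With the zero-mean form, the loc parameter is a constant, so it stays valid even when mu is computed from an out-of-support value.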

@fehiepsi fehiepsi left a comment

Member

Looks great. Thanks both!

Comment thread numpyro/contrib/hsgp/approximation.py Outdated
# zero-mean Normal. This keeps the loc valid even for out-of-support
# ``value`` (where ``mu`` would be NaN); the public log_prob's
# @validate_sample still warns about out-of-support values.
sde_log_prob = Normal(0.0, sigma).to_event(self.event_dim).log_prob(xt - mu)

Member

Nice solution to address NaN mu.

@juanitorduz juanitorduz requested a review from fehiepsi June 5, 2026 14:51
juanitorduz and others added 2 commits June 5, 2026 16:54
Revert the numpyro/compat/pyro.py constraint projection and instead override
the vendored pyroapi test_constraints with a valid simplex init for the guide
param `q` (was an unnormalized exp(randn(3))), so it passes with validation
enabled by default.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>


# pyroapi's test_constraints inits the simplex param `q` with an unnormalized
# exp(randn(3)); use a valid simplex so it passes with validation enabled.

Collaborator

exp(randn(3)) for q lives in the vendored upstream pyroapi package (pyroapi/tests/test_svi.py), see https://github.com/pyro-ppl/pyro-api/blob/master/pyroapi/tests/test_svi.py#L168

To provide a valid simplex (as suggested) we therefore redefine test_constraints in our own test/pyroapi/test_pyroapi.py, identical to upstream except q is initialized to a valid simplex [0.4, 0.3, 0.3] instead of the unnormalized exp(randn(3)). Since the test name is imported via from pyroapi.tests import *, redefining it shadows the upstream one; it does not add a test.
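
The shadowing behavior can be sketched without any test framework. The snippet below imitates a star import by copying public names from a hypothetical upstream module into a namespace and then rebinding one of them; the module name upstream_tests and the function bodies are made up for illustration.

```python
import types

# Hypothetical upstream module that defines a test function.
upstream = types.ModuleType("upstream_tests")
exec("def test_constraints():\n    return 'upstream'\n", upstream.__dict__)

# Roughly what a star import does when no __all__ is defined:
# copy every public name into the importing namespace.
namespace = {
    name: obj for name, obj in vars(upstream).items() if not name.startswith("_")
}


def test_constraints():
    # Local override with a valid input; rebinding replaces the imported name.
    return "local override"


namespace["test_constraints"] = test_constraints

collected = [name for name in namespace if name.startswith("test_")]
print(collected, namespace["test_constraints"]())
```

Since a namespace holds one binding per name, a collector that walks it sees a single test_constraints, the local one.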

@juanitorduz juanitorduz left a comment

Collaborator

Small comment

@juanitorduz

Collaborator

@fehiepsi is this one good to go :) ?

@fehiepsi fehiepsi merged commit 6b6caf6 into pyro-ppl:master Jun 8, 2026
9 checks passed

Labels

enhancement New feature or request


4 participants