|
5 | 5 | import pyrecest.backend |
6 | 6 |
|
7 | 7 | # pylint: disable=no-name-in-module,no-member |
8 | | -from pyrecest.backend import empty, int32, int64, log, random, squeeze |
| 8 | +from pyrecest.backend import empty, int32, int64, log, random, squeeze, where |
9 | 9 |
|
10 | 10 |
|
11 | 11 | class AbstractManifoldSpecificDistribution(ABC): |
@@ -81,9 +81,15 @@ def sample_metropolis_hastings( |
81 | 81 | ): |
82 | 82 | # jscpd:ignore-end |
83 | 83 | """Metropolis Hastings sampling algorithm.""" |
84 | | - assert ( |
85 | | - pyrecest.backend.__backend_name__ != "jax" |
86 | | - ), "Not supported on this backend" |
| 84 | + if pyrecest.backend.__backend_name__ == "jax": |
| 85 | + return sample_metropolis_hastings_jax( |
| 86 | + log_pdf=self.ln_pdf, |
| 87 | + proposal=proposal, |
| 88 | + start_point=start_point, |
| 89 | + n=n, |
| 90 | + burn_in=burn_in, |
| 91 | + skipping=skipping, |
| 92 | + ) |
87 | 93 | if proposal is None or start_point is None: |
88 | 94 | raise NotImplementedError( |
89 | 95 | "Default proposals and starting points should be set in inheriting classes." |
@@ -115,3 +121,54 @@ def sample_metropolis_hastings( |
115 | 121 |
|
116 | 122 | relevant_samples = s[burn_in::skipping, :] |
117 | 123 | return squeeze(relevant_samples) |
| 124 | + |
| 125 | + |
def sample_metropolis_hastings_jax(
    log_pdf,  # function: R^d -> scalar log density
    proposal,  # function: x -> x_proposed (assumed symmetric; see docstring)
    start_point,
    n: int,
    burn_in: int = 10,
    skipping: int = 5,
):
    """
    Metropolis-Hastings sampler in JAX, jittable and vectorizable.

    Args:
        log_pdf: callable taking an array x and returning log p(x)
        proposal: callable taking the current state x and returning a
            proposed state x'. Assumed symmetric (q(x'|x) == q(x|x')),
            since no Hastings correction term is applied below.
        start_point: initial state x_0 (array)
        n: number of samples to return (after burn-in and thinning)
        burn_in: number of initial states to discard
        skipping: thinning factor (keep every `skipping`-th state)

    Returns:
        Array of shape (n, *start_point.shape) with the retained samples.
    """
    from jax import lax

    # Total chain length so that exactly n states survive burn-in + thinning.
    total_steps = burn_in + n * skipping

    def one_step(carry, _):
        x, log_px = carry

        # Propose new state.
        x_prop = proposal(x)
        log_px_prop = log_pdf(x_prop)

        # Acceptance probability in log-space (symmetric proposal assumed).
        log_alpha = log_px_prop - log_px
        # Cap at log(1) = 0. Use elementwise `where` rather than Python's
        # builtin `min`: `min` forces a concrete boolean comparison, which
        # fails on traced arrays inside jit/scan.
        log_alpha = where(log_alpha < 0.0, log_alpha, 0.0)
        # NOTE(review): assumes the backend's `random.uniform()` yields a
        # fresh draw on every scan step; with an explicitly keyed PRNG the
        # key would have to be threaded through the carry — confirm the
        # backend's semantics under jit.
        u = random.uniform()  # in (0, 1)
        accept = log(u) < log_alpha

        # If accepted, move to the proposal; otherwise stay put.
        x_new = where(accept, x_prop, x)
        log_px_new = where(accept, log_px_prop, log_px)

        return (x_new, log_px_new), x_new

    init_carry = (start_point, log_pdf(start_point))
    # The carry is a 2-tuple (state, log density); unpacking it as a
    # 3-tuple would raise a ValueError on every call.
    (_, _), chain = lax.scan(one_step, init_carry, None, length=total_steps)

    # Discard burn-in, then keep every `skipping`-th state.
    return chain[burn_in::skipping]  # shape: (n, *start_point.shape)
0 commit comments