Skip to content

Commit 9b3a06c

Browse files
committed
Added Metropolis–Hastings sampling for the JAX backend
1 parent d277a66 commit 9b3a06c

1 file changed

Lines changed: 61 additions & 4 deletions

File tree

pyrecest/distributions/abstract_manifold_specific_distribution.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pyrecest.backend
66

77
# pylint: disable=no-name-in-module,no-member
8-
from pyrecest.backend import empty, int32, int64, log, random, squeeze
8+
from pyrecest.backend import empty, int32, int64, log, random, squeeze, where
99

1010

1111
class AbstractManifoldSpecificDistribution(ABC):
@@ -81,9 +81,15 @@ def sample_metropolis_hastings(
8181
):
8282
# jscpd:ignore-end
8383
"""Metropolis Hastings sampling algorithm."""
84-
assert (
85-
pyrecest.backend.__backend_name__ != "jax"
86-
), "Not supported on this backend"
84+
if pyrecest.backend.__backend_name__ == "jax":
85+
return sample_metropolis_hastings_jax(
86+
log_pdf=self.ln_pdf,
87+
proposal=proposal,
88+
start_point=start_point,
89+
n=n,
90+
burn_in=burn_in,
91+
skipping=skipping,
92+
)
8793
if proposal is None or start_point is None:
8894
raise NotImplementedError(
8995
"Default proposals and starting points should be set in inheriting classes."
@@ -115,3 +121,54 @@ def sample_metropolis_hastings(
115121

116122
relevant_samples = s[burn_in::skipping, :]
117123
return squeeze(relevant_samples)
124+
125+
126+
def sample_metropolis_hastings_jax(
    log_pdf,  # callable: x -> scalar log density log p(x)
    proposal,  # callable: x -> proposed next state x'
    start_point,
    n: int,
    burn_in: int = 10,
    skipping: int = 5,
    key=None,  # optional jax.random PRNG key for reproducibility
):
    """Metropolis-Hastings sampler in JAX, jittable via ``lax.scan``.

    Args:
        log_pdf: callable taking an array x and returning log p(x).
        proposal: callable taking the current state x and returning a
            proposed state x'. Assumed symmetric (q(x'|x) == q(x|x')),
            so no proposal-density correction term is applied.
        start_point: initial state x_0 (array-like).
        n: number of samples to return (after burn-in and thinning).
        burn_in: number of initial states to discard.
        skipping: thinning factor (keep every `skipping`-th state).
        key: optional ``jax.random.PRNGKey``; a fixed default key is used
            when omitted, so repeated calls are deterministic.

    Returns:
        Array of shape (n, *start_point.shape) with the retained samples.
    """
    # Local imports keep JAX an optional dependency of this module.
    import jax
    from jax import lax
    from jax import numpy as jnp

    if key is None:
        key = jax.random.PRNGKey(0)

    total_steps = burn_in + n * skipping
    start_point = jnp.asarray(start_point)

    def one_step(carry, _):
        # The PRNG key must travel in the carry: calling a stateful RNG
        # inside a scanned function would bake a single constant draw
        # into the trace and reuse it for every step.
        x, log_px, step_key = carry
        step_key, subkey = jax.random.split(step_key)

        # Propose a new state and evaluate its log density.
        x_prop = proposal(x)
        log_px_prop = log_pdf(x_prop)

        # Log acceptance ratio for a symmetric proposal. No cap at 0 is
        # needed: log(u) < 0 always holds for u ~ U(0, 1), so comparing
        # against the uncapped ratio is equivalent (and avoids Python's
        # `min`, which cannot be applied to traced abstract values).
        log_alpha = log_px_prop - log_px
        u = jax.random.uniform(subkey)
        accept = jnp.log(u) < log_alpha

        # Branchless accept/reject keeps the step jittable.
        x_new = jnp.where(accept, x_prop, x)
        log_px_new = jnp.where(accept, log_px_prop, log_px)
        return (x_new, log_px_new, step_key), x_new

    init_carry = (start_point, log_pdf(start_point), key)
    # Carry is a 3-tuple (state, log density, key); only the stacked
    # per-step outputs (the chain) are needed afterwards.
    _, chain = lax.scan(one_step, init_carry, None, length=total_steps)

    # Discard burn-in, then thin by `skipping`.
    return chain[burn_in::skipping]  # shape: (n, *start_point.shape)

0 commit comments

Comments
 (0)