diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index fe596398..5dfcb34a 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -72,7 +72,13 @@ def body(i, xcur): return jax.lax.fori_loop(0, 100, body, re, unroll=True) -@implements(_galsim.Moffat) +@implements( + _galsim.Moffat, + lax_description="""\ +The JAX version of the Moffat profile uses `None` as the default value for the +`trunc` parameter to indicate no truncation of the profile. +""", +) @register_pytree_node_class class Moffat(GSObject): _is_axisymmetric = True @@ -85,7 +91,7 @@ def __init__( scale_radius=None, half_light_radius=None, fwhm=None, - trunc=0.0, + trunc=None, flux=1.0, gsparams=None, ): @@ -103,16 +109,16 @@ def __init__( fwhm=fwhm, ) else: + if trunc is not None: + sr_val = _MoffatCalculateSRFromHLR(half_light_radius, trunc, beta) + else: + sr_val = half_light_radius / jnp.sqrt( + jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0 + ) + super().__init__( beta=beta, - scale_radius=( - jax.lax.select( - trunc > 0, - _MoffatCalculateSRFromHLR(half_light_radius, trunc, beta), - half_light_radius - / jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0), - ) - ), + scale_radius=sr_val, trunc=trunc, flux=flux, gsparams=gsparams, @@ -183,13 +189,12 @@ def _inv_r0_sq(self): @property def _maxRrD(self): """maxR/rd ; fluxFactor Integral of total flux in terms of 'rD' units.""" - return jax.lax.select( - self.trunc > 0.0, - self.trunc * self._inv_r0, - jnp.sqrt( + if self.trunc is not None: + return self.trunc * self._inv_r0 + else: + return jnp.sqrt( jnp.power(self.gsparams.xvalue_accuracy, 1.0 / (1.0 - self.beta)) - 1.0 - ), - ) + ) @property def _maxR(self): @@ -202,11 +207,10 @@ def _maxRrD_sq(self): @property def _fluxFactor(self): - return jax.lax.select( - self.trunc > 0.0, - 1.0 - jnp.power(1 + self._maxRrD * self._maxRrD, (1.0 - self.beta)), - 1.0, - ) + if self.trunc is not None: + return 1.0 - jnp.power(1 + self._maxRrD * self._maxRrD, (1.0 - self.beta)) + else: + return 1.0 @property @implements(_galsim.moffat.Moffat.half_light_radius) @@ -360,12 +364,11 @@ def _kValue(self, kpos): k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) - res = jax.lax.cond( - self.trunc > 0, - lambda x: self._kValue_trunc(x), - lambda x: self._kValue_untrunc(x), - k, - ) + if self.trunc is not None: + res = self._kValue_trunc(k) + else: + res = self._kValue_untrunc(k) + return res.reshape(out_shape) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/tests/GalSim b/tests/GalSim index 3251a393..f2daf0aa 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 3251a393bf7ea94fe9ccda3508bc7db722eca1cf +Subproject commit f2daf0aacfd6751a2fe68ac7fb0400255f6dbaa2 diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 04376cd3..02972caa 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -35,13 +35,13 @@ def test_moffat_comp_galsim_maxk(): beta=psf.beta, scale_radius=psf.scale_radius, flux=psf.flux, - trunc=psf.trunc, + trunc=psf.trunc or 0.0, ) gpsf = gpsf.withGSParams(maxk_threshold=thresh) fk = psf.kValue(psf.maxk, 0).real / psf.flux print( - f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}" + f"{psf.beta} \t {int(psf.trunc or 0.0)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}" ) np.testing.assert_allclose( psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5 diff --git a/tests/jax/test_vmapping.py b/tests/jax/test_vmapping.py index 28db011e..a41d81e3 100644 --- a/tests/jax/test_vmapping.py +++ b/tests/jax/test_vmapping.py @@ -7,7 +7,7 @@ def duplicate(params): - return {x: y * e for x, y in params.items()} + return {x: y * e if y is not None else [None] * 2 for x, y in params.items()} def test_gaussian_vmapping():