Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions jax_galsim/bessel.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,9 @@ def _temme_series_kve(v, z):
z_sq = z * z
logzo2 = jnp.log(z / 2.0)
mu = -v * logzo2
sinc_v = jnp.where(v == 0.0, 1.0, jnp.sin(jnp.pi * v) / (jnp.pi * v))
sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu) / mu)
sinc_v = jnp.sinc(v)
mu_safe = jnp.where(mu != 0, mu, 1.0)
sinhc_mu = jnp.where(mu == 0.0, 1.0, jnp.sinh(mu_safe) / mu_safe)

initial_f = (coeff1 * jnp.cosh(mu) + coeff2 * (-logzo2) * sinhc_mu) / sinc_v
initial_p = 0.5 * jnp.exp(mu) / gamma1pv_inv
Expand Down
6 changes: 3 additions & 3 deletions jax_galsim/core/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def akima_interp(x, xp, yp, coeffs, fixed_spacing=False):
The values of the Akima cubic spline at the points x.
"""
xp = jnp.asarray(xp)
# yp = jnp.array(yp) # unused
yp = jnp.asarray(yp)
if fixed_spacing:
dxp = xp[1] - xp[0]
i = jnp.floor((x - xp[0]) / dxp).astype(jnp.int32)
Expand All @@ -160,6 +160,6 @@ def akima_interp(x, xp, yp, coeffs, fixed_spacing=False):
dx3 = dx2 * dx
xval = a[i] + b[i] * dx + c[i] * dx2 + d[i] * dx3

xval = jnp.where(x < xp[0], 0, xval)
xval = jnp.where(x > xp[-1], 0, xval)
xval = jnp.where(x < xp[0], yp[0], xval)
xval = jnp.where(x > xp[-1], yp[-1], xval)
return xval
94 changes: 82 additions & 12 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from jax_galsim.bessel import j0, kv
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral
from jax_galsim.core.interpolate import akima_interp, akima_interp_coeffs
from jax_galsim.core.utils import bisect_for_root, ensure_hashable, implements
from jax_galsim.gsobject import GSObject
from jax_galsim.position import PositionD
from jax_galsim.random import UniformDeviate


Expand Down Expand Up @@ -103,12 +103,13 @@ def __init__(
fwhm=fwhm,
)
else:
trunc_ = jnp.where(trunc > 0, trunc, 100.0)
Comment thread
beckermr marked this conversation as resolved.
Outdated
super().__init__(
beta=beta,
scale_radius=(
jax.lax.select(
trunc > 0,
_MoffatCalculateSRFromHLR(half_light_radius, trunc, beta),
_MoffatCalculateSRFromHLR(half_light_radius, trunc_, beta),
half_light_radius
/ jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0),
)
Expand Down Expand Up @@ -281,7 +282,19 @@ def _prefactor(self):
@jax.jit
def _maxk_func(self, k):
return (
jnp.abs(self._kValue(PositionD(x=k, y=0)).real / self.flux)
jnp.abs(
self._kValue_func(
self.beta,
jnp.atleast_1d(k),
self._knorm_bis,
self._knorm,
self._prefactor,
self._maxRrD,
self.trunc,
self._r0,
)[0].real
/ self.flux
)
- self.gsparams.maxk_threshold
)

Expand Down Expand Up @@ -336,19 +349,79 @@ def _xValue(self, pos):
rsq > self._maxRrD_sq, 0.0, self._norm * jnp.power(1.0 + rsq, -self.beta)
)

@staticmethod
@jax.jit
def _kValue_untrunc_func(beta, k, _knorm_bis, _knorm):
"""Non truncated version of _kValue"""
k_ = jnp.where(k > 0, k, 1.0)
return jnp.where(
k > 0,
_knorm_bis * jnp.power(k_, beta - 1.0) * _Knu(beta - 1.0, k_),
_knorm,
)

@staticmethod
@jax.jit
def _kValue_trunc_func(beta, k, _knorm, _prefactor, _maxRrD):
"""Truncated version of _kValue"""
k_ = jnp.where(k <= 50.0, k, 50.0)
return jnp.where(
k <= 50.0,
_knorm * _prefactor * _hankel(k_, beta, _maxRrD),
0.0,
)

@staticmethod
@jax.jit
def _kValue_func(beta, k, _knorm_bis, _knorm, _prefactor, _maxRrD, trunc, _r0):
return jax.lax.cond(
trunc > 0,
lambda x: Moffat._kValue_trunc_func(beta, x, _knorm, _prefactor, _maxRrD),
lambda x: Moffat._kValue_untrunc_func(beta, x, _knorm_bis, _knorm),
k * _r0,
)

@jax.jit
def _kValue_interp_coeffs(self):
# this number of points gets the tests to pass
# I did not investigate further.
n_pts = 5000
k_min = 0
# this is a fudge factor to help numerical convergnce in the tests
# it should not be needed in principle since the profile is not
# evaluated above maxk, but it appears to be needed anyway and
# IDK why
k_max = jnp.minimum(self._maxk * 2, 50.0)
k = jnp.linspace(k_min, k_max, n_pts)
vals = self._kValue_func(
self.beta,
k,
self._knorm_bis,
self._knorm,
self._prefactor,
self._maxRrD,
self.trunc,
self._r0,
)
return k, vals, akima_interp_coeffs(k, vals)

def _kValue_untrunc(self, k):
"""Non truncated version of _kValue"""
k_ = jnp.where(k > 0, k, 1.0)
return jnp.where(
k > 0,
self._knorm_bis * jnp.power(k, self.beta - 1.0) * _Knu(self.beta - 1.0, k),
self._knorm_bis
* jnp.power(k_, self.beta - 1.0)
* _Knu(self.beta - 1.0, k_),
self._knorm,
)

def _kValue_trunc(self, k):
"""Truncated version of _kValue"""
k_ = jnp.where(k <= 50.0, k, 50.0)
return jnp.where(
k <= 50.0,
self._knorm * self._prefactor * _hankel(k, self.beta, self._maxRrD),
self._knorm * self._prefactor * _hankel(k_, self.beta, self._maxRrD),
0.0,
)

Expand All @@ -357,15 +430,12 @@ def _kValue(self, kpos):
"""computation of the Moffat response in k-space with switch of truncated/untracated case
kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image)
"""
k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq)
k = jnp.sqrt((kpos.x**2 + kpos.y**2))
out_shape = jnp.shape(k)
k = jnp.atleast_1d(k)
res = jax.lax.cond(
self.trunc > 0,
lambda x: self._kValue_trunc(x),
lambda x: self._kValue_untrunc(x),
k,
)
k_, vals_, coeffs = self._kValue_interp_coeffs()
res = akima_interp(k, k_, vals_, coeffs, fixed_spacing=True)

return res.reshape(out_shape)

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
Expand Down
43 changes: 43 additions & 0 deletions tests/jax/test_derivs_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest

import jax_galsim as jgs


@pytest.mark.parametrize(
"params,gsobj,args",
[
(["scale_radius", "half_light_radius"], jgs.Spergel, [1.0]),
(["scale_radius", "half_light_radius"], jgs.Exponential, []),
(["sigma", "fwhm", "half_light_radius"], jgs.Gaussian, []),
(["scale_radius", "half_light_radius", "fwhm"], jgs.Moffat, [2.0]),
],
)
def test_deriv_params_gsobject(params, gsobj, args):
val = 2.0
eps = 1e-5

for param in params:
print("\nparam:", param, flush=True)

def _run(val_):
kwargs = {param: val_}
return jnp.max(
gsobj(
*args,
**kwargs,
gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64),
)
.drawImage(nx=48, ny=48, scale=0.2)
.array[24, 24]
** 2
)

gfunc = jax.jit(jax.grad(_run))
gval = gfunc(val)

gfdiff = (_run(val + eps) - _run(val - eps)) / 2.0 / eps

np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6)
79 changes: 47 additions & 32 deletions tests/jax/test_moffat_comp_galsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import jax_galsim as galsim


def test_moffat_comp_galsim_maxk():
psfs = [
@pytest.mark.parametrize(
"psf",
[
# Make sure to include all the specialized betas we have in C++ layer.
# The scale_radius and flux don't matter, but vary themm too.
# Note: We also specialize beta=1, but that seems to be impossible to realize,
Expand All @@ -25,37 +26,51 @@ def test_moffat_comp_galsim_maxk():
galsim.Moffat(beta=1.22, scale_radius=7, flux=23, trunc=30),
galsim.Moffat(beta=3.6, scale_radius=9, flux=23, trunc=50),
galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000),
]
threshs = [1.0e-3, 1.0e-4, 0.03]
print("\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk")
for psf in psfs:
for thresh in threshs:
psf = psf.withGSParams(maxk_threshold=thresh)
gpsf = _galsim.Moffat(
beta=psf.beta,
scale_radius=psf.scale_radius,
flux=psf.flux,
trunc=psf.trunc,
)
gpsf = gpsf.withGSParams(maxk_threshold=thresh)
fk = psf.kValue(psf.maxk, 0).real / psf.flux
],
)
@pytest.mark.parametrize("thresh", [1.0e-4, 1.0e-3, 0.03])
def test_moffat_comp_galsim_maxk(psf, thresh):
print(
"\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk", flush=True
)
psf = psf.withGSParams(maxk_threshold=thresh)
gpsf = _galsim.Moffat(
beta=psf.beta,
scale_radius=psf.scale_radius,
flux=psf.flux,
trunc=psf.trunc,
)
gpsf = gpsf.withGSParams(maxk_threshold=thresh)
fk = psf.kValue(psf.maxk, 0).real / psf.flux
maxk_test_val_one = jnp.minimum(1.0, psf.maxk)
maxk_test_val_pone = maxk_test_val_one / 10.0

print(
f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}"
)
np.testing.assert_allclose(
psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5
)
np.testing.assert_allclose(
psf.kValue(0.0, 0.1), gpsf.kValue(0.0, 0.1), rtol=1e-5
)
np.testing.assert_allclose(
psf.kValue(-1.0, 0.0), gpsf.kValue(-1.0, 0.0), rtol=1e-5
)
np.testing.assert_allclose(
psf.kValue(1.0, 0.0), gpsf.kValue(1.0, 0.0), rtol=1e-5
)
np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0)
print(
f"{psf.beta} \t {int(psf.trunc)} \t {thresh:.1e} \t {fk:.3e} \t {psf.maxk:.3e} \t {gpsf.maxk:.3e}",
flush=True,
)
np.testing.assert_allclose(gpsf.maxk, psf.maxk, rtol=0.25, atol=0)
np.testing.assert_allclose(
psf.kValue(0.0, 0.0), gpsf.kValue(0.0, 0.0), rtol=1e-5, atol=1e-8
)
np.testing.assert_allclose(
psf.kValue(0.0, maxk_test_val_pone),
gpsf.kValue(0.0, maxk_test_val_pone),
rtol=1e-5,
atol=1e-8,
)
np.testing.assert_allclose(
psf.kValue(-maxk_test_val_one, 0.0),
gpsf.kValue(-maxk_test_val_one, 0.0),
rtol=1e-5,
atol=1e-8,
)
np.testing.assert_allclose(
psf.kValue(maxk_test_val_one, 0.0),
gpsf.kValue(maxk_test_val_one, 0.0),
rtol=1e-5,
atol=1e-8,
)


@pytest.mark.test_in_float32
Expand Down
Loading