diff --git a/jax_galsim/box.py b/jax_galsim/box.py index 95b3d373..2e5147a1 100644 --- a/jax_galsim/box.py +++ b/jax_galsim/box.py @@ -79,17 +79,23 @@ def _max_sb(self): return self.flux / (self.width * self.height) def _xValue(self, pos): + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): norm = self.flux / (self.width * self.height) return jnp.where( - 2.0 * jnp.abs(pos.x) < self.width, - jnp.where(2.0 * jnp.abs(pos.y) < self.height, norm, 0.0), + 2.0 * jnp.abs(x) < self.width, + jnp.where(2.0 * jnp.abs(y) < self.height, norm, 0.0), 0.0, ) def _kValue(self, kpos): + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): _wo2pi = self.width / (2.0 * jnp.pi) _ho2pi = self.height / (2.0 * jnp.pi) - return self.flux * jnp.sinc(kpos.x * _wo2pi) * jnp.sinc(kpos.y * _ho2pi) + return self.flux * jnp.sinc(kx * _wo2pi) * jnp.sinc(ky * _ho2pi) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): _jac = jnp.eye(2) if jac is None else jac diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 6961ad39..506890bb 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -292,10 +292,13 @@ def _xValue(self, pos): raise NotImplementedError("Real-space convolutions are not implemented") def _kValue(self, kpos): - kv_list = [ - obj.kValue(kpos) for obj in self.obj_list - ] # In GalSim one uses obj.kValue - return jnp.prod(jnp.array(kv_list)) + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + result = self.obj_list[0]._kValue_array(kx, ky) + for obj in self.obj_list[1:]: + result = result * obj._kValue_array(kx, ky) + return result def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): raise NotImplementedError("Real-space convolutions are not implemented") @@ -313,10 +316,23 @@ def _shoot(self, photons, rng): photons.convolve(p1, rng) def _drawKImage(self, image, jac=None): + from jax_galsim import Image + image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image *= obj._drawKImage(image, jac) + # Use a fresh blank image with matching metadata to sever + # the false AD dependency on the galaxy-filled array. + # The PSF's _drawKImage only uses image metadata (bounds, + # scale, wcs), never the array data. + blank = Image(bounds=image.bounds, dtype=image.dtype, scale=image.scale) + obj_kimage = obj._drawKImage(blank, jac) + image = Image( + array=image.array * obj_kimage.array, + bounds=image.bounds, + wcs=image.wcs, + _check_bounds=False, + ) return image def tree_flatten(self): @@ -474,10 +490,13 @@ def _max_sb(self): return -self.orig_obj.max_sb / self.orig_obj.flux**2 def _kValue(self, pos): + return self._kValue_array(pos.x, pos.y) + + def _kValue_array(self, kx, ky): # Really, for very low original kvalues, this gets very high, which can be unstable # in the presence of noise. So if the original value is less than min_acc_kvalue, # we instead just return 1/min_acc_kvalue rather than the real inverse. - kval = self.orig_obj._kValue(pos) + kval = self.orig_obj._kValue_array(kx, ky) return jnp.where( jnp.abs(kval) < self._min_acc_kvalue, self._inv_min_acc_kvalue, diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index a197d151..f8d4d79a 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -11,7 +11,7 @@ def draw_by_xValue( ): """Utility function to draw a real-space GSObject into an Image.""" # putting the import here to avoid circular imports - from jax_galsim import Image, PositionD + from jax_galsim import Image # Applies flux scaling to compensate for pixel scale # See SBProfile.draw() @@ -29,9 +29,7 @@ def draw_by_xValue( flux_scaling *= jnp.exp(logdet) # Draw the object - im = jax.vmap(lambda *args: gsobject._xValue(PositionD(*args)))( - coords[..., 0], coords[..., 1] - ) + im = gsobject._xValue_array(coords[..., 0], coords[..., 1]) # Apply the flux scaling im = (im * flux_scaling).astype(image.dtype) @@ -42,7 +40,7 @@ def draw_by_xValue( def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): # putting the import here to avoid circular imports - from jax_galsim import Image, PositionD + from jax_galsim import Image # Create an array of coordinates coords = jnp.stack(image.get_pixel_centers(), axis=-1) @@ -50,10 +48,8 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): coords = jnp.dot(coords, jacobian) # Draw the object - im = jax.vmap(lambda *args: gsobject._kValue(PositionD(*args)))( - coords[..., 0], coords[..., 1] - ) - im = (im).astype(image.dtype) + im = gsobject._kValue_array(coords[..., 0], coords[..., 1]) + im = im.astype(image.dtype) # Return an image return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False) @@ -61,23 +57,18 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)): # putting the import here to avoid circular imports - from jax_galsim import Image, PositionD + from jax_galsim import Image # Create an array of coordinates kcoords = jnp.stack(image.get_pixel_centers(), axis=-1) kcoords = kcoords * image.scale # Scale by the image pixel scale kcoords = jnp.dot(kcoords, jacobian) - cenx, ceny = offset.x, offset.y - # flux Exp(-i (kx cx + kxy cx + kyx cy + ky cy ) ) - # NB: seems that tere is no jax.lax.polar equivalent to c++ std::polar function - def phase(kpos): - arg = -(kpos.x * cenx + kpos.y * ceny) - return jnp.cos(arg) + 1j * jnp.sin(arg) + # flux Exp(-i (kx cx + ky cy) ) + kx, ky = kcoords[..., 0], kcoords[..., 1] + arg = -(kx * offset.x + ky * offset.y) + im_phase = jnp.cos(arg) + 1j * jnp.sin(arg) - im_phase = jax.vmap(lambda *args: phase(PositionD(*args)))( - kcoords[..., 0], kcoords[..., 1] - ) return Image( array=image.array * im_phase, bounds=image.bounds, diff --git a/jax_galsim/deltafunction.py b/jax_galsim/deltafunction.py index 71043e96..f32d3e10 100644 --- a/jax_galsim/deltafunction.py +++ b/jax_galsim/deltafunction.py @@ -1,5 +1,4 @@ import galsim as _galsim -import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class @@ -54,16 +53,16 @@ def _max_sb(self): return DeltaFunction._mock_inf def _xValue(self, pos): - return jax.lax.cond( - jnp.array(pos.x == 0.0, dtype=bool) & jnp.array(pos.y == 0.0, dtype=bool), - lambda *a: DeltaFunction._mock_inf, - lambda *a: 0.0, - ) + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + return jnp.where((x == 0.0) & (y == 0.0), DeltaFunction._mock_inf, 0.0) def _kValue(self, kpos): - # this is a wasteful and fancy way to get the shape to broadcast to - # to match the input kpos - return self.flux + kpos.x * (0.0 + 0.0j) + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + return self.flux + kx * (0.0 + 0.0j) @implements(_galsim.DeltaFunction._shoot) def _shoot(self, photons, rng): diff --git a/jax_galsim/exponential.py b/jax_galsim/exponential.py index 3d556d59..48e7631d 100644 --- a/jax_galsim/exponential.py +++ b/jax_galsim/exponential.py @@ -126,11 +126,17 @@ def _max_sb(self): return self._norm def _xValue(self, pos): - r = jnp.sqrt(pos.x**2 + pos.y**2) + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + r = jnp.sqrt(x**2 + y**2) return self._norm * jnp.exp(-r * self._inv_r0) def _kValue(self, kpos): - ksqp1 = (kpos.x**2 + kpos.y**2) * self._r0**2 + 1.0 + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + ksqp1 = (kx**2 + ky**2) * self._r0**2 + 1.0 return self.flux / (ksqp1 * jnp.sqrt(ksqp1)) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/jax_galsim/gaussian.py b/jax_galsim/gaussian.py index 3b937a11..08e887fd 100644 --- a/jax_galsim/gaussian.py +++ b/jax_galsim/gaussian.py @@ -128,11 +128,17 @@ def _max_sb(self): return self._norm def _xValue(self, pos): - rsq = pos.x**2 + pos.y**2 + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + rsq = x**2 + y**2 return self._norm * jnp.exp(-0.5 * rsq * self._inv_sigsq) def _kValue(self, kpos): - ksq = (kpos.x**2 + kpos.y**2) * self._sigsq + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + ksq = (kx**2 + ky**2) * self._sigsq return self.flux * jnp.exp(-0.5 * ksq) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 296173ff..c1fdcfbb 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -224,6 +224,32 @@ def _kValue(self, kpos): "%s does not implement kValue" % self.__class__.__name__ ) + def _kValue_array(self, kx, ky): + """Evaluate the k-space profile at arrays of kx, ky coordinates. + + By default falls back to vmap over _kValue with PositionD. Concrete + profiles should override this with direct array operations for better + performance (avoids per-pixel PositionD construction). + """ + out_shape = jnp.shape(kx) + kx = jnp.atleast_1d(kx).ravel() + ky = jnp.atleast_1d(ky).ravel() + result = jax.vmap(lambda x, y: self._kValue(PositionD(x, y)))(kx, ky) + return result.reshape(out_shape) + + def _xValue_array(self, x, y): + """Evaluate the real-space profile at arrays of x, y coordinates. + + By default falls back to vmap over _xValue with PositionD. Concrete + profiles should override this with direct array operations for better + performance (avoids per-pixel PositionD construction). + """ + out_shape = jnp.shape(x) + x = jnp.atleast_1d(x).ravel() + y = jnp.atleast_1d(y).ravel() + result = jax.vmap(lambda xx, yy: self._xValue(PositionD(xx, yy)))(x, y) + return result.reshape(out_shape) + @implements(_galsim.GSObject.withGSParams) def withGSParams(self, gsparams=None, **kwargs): if gsparams == self.gsparams: diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 1b54f80b..0d082ab2 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -763,6 +763,18 @@ def _xValue(self, pos): self._x_interpolant, )[0] + def _xValue_array(self, x, y): + return _xValue_arr( + x, + y, + self._offset.x, + self._offset.y, + self._pad_image.bounds.xmin, + self._pad_image.bounds.ymin, + self._pad_image.array, + self._x_interpolant, + ) + def _kValue(self, kpos): kx = jnp.array([kpos.x], dtype=float) ky = jnp.array([kpos.y], dtype=float) @@ -779,6 +791,20 @@ def _kValue(self, kpos): self._k_interpolant, )[0] + def _kValue_array(self, kx, ky): + return _kValue_arr( + kx, + ky, + self._offset.x, + self._offset.y, + self._kim.bounds.xmin, + self._kim.bounds.ymin, + self._kim.array, + self._kim.scale, + self._x_interpolant, + self._k_interpolant, + ) + def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): jacobian = jnp.eye(2) if jac is None else jac diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index fe596398..f6fab369 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -328,10 +328,11 @@ def _has_hard_edges(self): def _max_sb(self): return self._norm - @jax.jit def _xValue(self, pos): - rsq = (pos.x**2 + pos.y**2) * self._inv_r0_sq - # trunc if r>maxR with r0 scaled version + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + rsq = (x**2 + y**2) * self._inv_r0_sq return jnp.where( rsq > self._maxRrD_sq, 0.0, self._norm * jnp.power(1.0 + rsq, -self.beta) ) @@ -352,14 +353,14 @@ def _kValue_trunc(self, k): 0.0, ) - @jax.jit def _kValue(self, kpos): - """computation of the Moffat response in k-space with switch of truncated/untracated case - kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image) - """ - k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq) + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + k = jnp.sqrt((kx**2 + ky**2) * self._r0_sq) out_shape = jnp.shape(k) - k = jnp.atleast_1d(k) + # Flatten to 1D for _hankel's vmap which only maps over axis 0 + k = jnp.atleast_1d(k).ravel() res = jax.lax.cond( self.trunc > 0, lambda x: self._kValue_trunc(x), diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index b0bf22a8..e9d021f4 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -408,15 +408,19 @@ def _max_sb(self): # from SBSpergelImpl.h return jnp.abs(self._xnorm) * self._xnorm0 - @jax.jit def _xValue(self, pos): - r = jnp.sqrt(pos.x**2 + pos.y**2) * self._inv_r0 + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + r = jnp.sqrt(x**2 + y**2) * self._inv_r0 res = jnp.where(r == 0, self._xnorm0, fz_nu(r, self.nu)) return self._xnorm * res - @jax.jit def _kValue(self, kpos): - ksq = (kpos.x**2 + kpos.y**2) * self._r0_sq + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + ksq = (kx**2 + ky**2) * self._r0_sq return self.flux * jnp.power(1.0 + ksq, -1.0 - self.nu) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 958e6bfa..f56d24b9 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -153,12 +153,22 @@ def _max_sb(self): return jnp.sum(sb_list) def _xValue(self, pos): - xv_list = jnp.array([obj._xValue(pos) for obj in self.obj_list]) - return jnp.sum(xv_list, axis=0) + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + result = self.obj_list[0]._xValue_array(x, y) + for obj in self.obj_list[1:]: + result = result + obj._xValue_array(x, y) + return result def _kValue(self, pos): - kv_list = jnp.array([obj._kValue(pos) for obj in self.obj_list]) - return jnp.sum(kv_list, axis=0) + return self._kValue_array(pos.x, pos.y) + + def _kValue_array(self, kx, ky): + result = self.obj_list[0]._kValue_array(kx, ky) + for obj in self.obj_list[1:]: + result = result + obj._kValue_array(kx, ky) + return result def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): image = self.obj_list[0]._drawReal(image, jac, offset, flux_scaling) @@ -168,10 +178,21 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): return image def _drawKImage(self, image, jac=None): + from jax_galsim import Image + image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image += obj._drawKImage(image, jac) + # Use a fresh blank image with matching metadata to sever + # the false AD dependency on the previously-filled array. + blank = Image(bounds=image.bounds, dtype=image.dtype, scale=image.scale) + obj_kimage = obj._drawKImage(blank, jac) + image = Image( + array=image.array + obj_kimage.array, + bounds=image.bounds, + wcs=image.wcs, + _check_bounds=False, + ) return image @property diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index d7d23838..a4a0aece 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -337,13 +337,41 @@ def _max_sb(self): return self._amp_scaling * self._original.max_sb def _xValue(self, pos): - pos -= self._offset - inv_pos = PositionD(self._inv(pos.x, pos.y)) - return self._original._xValue(inv_pos) * self._amp_scaling + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + x = x - self._offset.x + y = y - self._offset.y + inv_x, inv_y = self._inv_array(x, y) + return self._original._xValue_array(inv_x, inv_y) * self._amp_scaling def _kValue(self, kpos): - fwdT_kpos = PositionD(self._fwdT(kpos.x, kpos.y)) - return self._original._kValue(fwdT_kpos) * self._kfactor(kpos.x, kpos.y) + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + fwdT_kx, fwdT_ky = self._fwdT_array(kx, ky) + return self._original._kValue_array(fwdT_kx, fwdT_ky) * self._kfactor_array( + kx, ky + ) + + def _fwdT_array(self, x, y): + """Apply the transposed Jacobian without constructing intermediate arrays.""" + m = self._jac.T + rx = m[0, 0] * x + m[0, 1] * y + ry = m[1, 0] * x + m[1, 1] * y + return rx, ry + + def _inv_array(self, x, y): + """Apply the inverse Jacobian without constructing intermediate arrays.""" + m = self._invjac + rx = m[0, 0] * x + m[0, 1] * y + ry = m[1, 0] * x + m[1, 1] * y + return rx, ry + + def _kfactor_array(self, kx, ky): + """Compute the k-space phase factor on arrays without mutating locals.""" + arg = -1j * (self._offset.x * kx + self._offset.y * ky) + return self._flux_scaling * jnp.exp(arg) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): dx, dy = offset diff --git a/tests/jax/test_convolution_grad.py b/tests/jax/test_convolution_grad.py new file mode 100644 index 00000000..8b4e22b9 --- /dev/null +++ b/tests/jax/test_convolution_grad.py @@ -0,0 +1,103 @@ +"""Test that gradients through convolution rendering are correct. + +This verifies that the structural optimization (breaking false AD dependency +in Convolution._drawKImage) produces correct gradients by comparing against +finite differences. +""" + +import jax +import jax.numpy as jnp +import numpy as np + +import jax_galsim as galsim + +jax.config.update("jax_enable_x64", True) + +gsparams = galsim.GSParams(minimum_fft_size=128, maximum_fft_size=128) + + +def _draw_and_sum(half_light_radius, flux): + """Draw a convolution of Exponential * Moffat and return the sum of pixel values.""" + gal = galsim.Exponential( + half_light_radius=half_light_radius, flux=flux, gsparams=gsparams + ) + psf = galsim.Moffat(beta=3.5, fwhm=0.7, gsparams=gsparams) + obj = galsim.Convolve(gal, psf, gsparams=gsparams) + image = obj.drawImage(nx=64, ny=64, scale=0.2, dtype=float) + return jnp.sum(image.array) + + +def test_convolution_grad_vs_finite_diff(): + """Test that jax.grad through Convolve(Exponential, Moffat).drawImage() + matches finite-difference approximation.""" + hlr = 1.0 + flux = 100.0 + eps = 1e-5 + + grad_fn = jax.grad(_draw_and_sum, argnums=(0, 1)) + grad_hlr, grad_flux = grad_fn(jnp.float64(hlr), jnp.float64(flux)) + + # Finite-difference for half_light_radius + f_plus = _draw_and_sum(hlr + eps, flux) + f_minus = _draw_and_sum(hlr - eps, flux) + fd_grad_hlr = (f_plus - f_minus) / (2 * eps) + + # Finite-difference for flux + f_plus = _draw_and_sum(hlr, flux + eps) + f_minus = _draw_and_sum(hlr, flux - eps) + fd_grad_flux = (f_plus - f_minus) / (2 * eps) + + np.testing.assert_allclose( + grad_hlr, + fd_grad_hlr, + rtol=1e-3, + atol=0, + err_msg="Gradient w.r.t. half_light_radius is incorrect", + ) + np.testing.assert_allclose( + grad_flux, + fd_grad_flux, + rtol=1e-3, + atol=0, + err_msg="Gradient w.r.t. flux is incorrect", + ) + + +def test_sum_grad_vs_finite_diff(): + """Test that jax.grad through Sum.drawImage() is correct.""" + + def _draw_sum(flux1, flux2): + _gsparams = galsim.GSParams(minimum_fft_size=128, maximum_fft_size=128) + g1 = galsim.Gaussian(sigma=1.5, flux=flux1, gsparams=_gsparams) + g2 = galsim.Gaussian(sigma=2.0, flux=flux2, gsparams=_gsparams) + obj = galsim.Add(g1, g2) + image = obj.drawImage(nx=64, ny=64, scale=0.2, method="no_pixel", dtype=float) + return jnp.sum(image.array) + + flux1, flux2 = jnp.float64(50.0), jnp.float64(80.0) + eps = 1e-5 + + grad_fn = jax.grad(_draw_sum, argnums=(0, 1)) + grad_f1, grad_f2 = grad_fn(flux1, flux2) + + fd_grad_f1 = (_draw_sum(flux1 + eps, flux2) - _draw_sum(flux1 - eps, flux2)) / ( + 2 * eps + ) + fd_grad_f2 = (_draw_sum(flux1, flux2 + eps) - _draw_sum(flux1, flux2 - eps)) / ( + 2 * eps + ) + + np.testing.assert_allclose( + grad_f1, + fd_grad_f1, + rtol=1e-3, + atol=0, + err_msg="Sum gradient w.r.t. flux1 is incorrect", + ) + np.testing.assert_allclose( + grad_f2, + fd_grad_f2, + rtol=1e-3, + atol=0, + err_msg="Sum gradient w.r.t. flux2 is incorrect", + )