From 59e9e1388b67fd3088376151ca56b565c4a1ec94 Mon Sep 17 00:00:00 2001 From: EiffL Date: Sun, 8 Feb 2026 17:47:16 +0100 Subject: [PATCH 1/3] Code efficiency fix --- CLAUDE.md | 99 +++++++++++++++++++++++++++ jax_galsim/box.py | 12 +++- jax_galsim/convolve.py | 31 +++++++-- jax_galsim/core/draw.py | 27 +++----- jax_galsim/deltafunction.py | 17 +++-- jax_galsim/exponential.py | 10 ++- jax_galsim/gaussian.py | 10 ++- jax_galsim/gsobject.py | 26 ++++++++ jax_galsim/interpolatedimage.py | 26 ++++++++ jax_galsim/moffat.py | 14 ++-- jax_galsim/spergel.py | 10 ++- jax_galsim/sum.py | 31 +++++++-- jax_galsim/transform.py | 38 +++++++++-- tests/jax/test_convolution_grad.py | 103 +++++++++++++++++++++++++++++ 14 files changed, 398 insertions(+), 56 deletions(-) create mode 100644 CLAUDE.md create mode 100644 tests/jax/test_convolution_grad.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..ab9e5c6e --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,99 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## What is JAX-GalSim? + +JAX-GalSim is a pure-JAX reimplementation of [GalSim](https://github.com/GalSim-developers/GalSim), the modular galaxy image simulation toolkit. It provides a near-identical API to GalSim but with full support for JAX transformations (`jit`, `vmap`, `grad`), enabling differentiable and GPU-accelerated galaxy simulations. Currently ~22.5% of the GalSim API is implemented. + +## Build and Development Commands + +```bash +# Setup (requires --recurse-submodules for test submodules) +git clone --recurse-submodules +pip install -e ".[dev]" +pre-commit install + +# Run all tests +pytest + +# Run a single test file +pytest tests/jax/test_jitting.py + +# Run a specific test +pytest tests/jax/test_jitting.py::test_name -v + +# Lint and format +ruff check . --fix +ruff format . + +# Or via pre-commit +pre-commit run --all-files +``` + +## Architecture + +### GSObject System + +The core abstraction is `GSObject` (in `gsobject.py`), the base class for all galaxy/PSF profiles. Key design: + +- **Parameters**: Traced (differentiable) values stored in `self._params` dict. Static config in `self._gsparams`. +- **Pytree protocol**: Every GSObject subclass is decorated with `@register_pytree_node_class` and implements `tree_flatten()`/`tree_unflatten()` so JAX can trace through them. +- **Drawing**: Profiles implement `_xValue(pos)` for real-space and `_kValue(kpos)` for Fourier-space evaluation. Drawing is dispatched through `core/draw.py` which uses `jax.vmap` over pixel centers. +- **Composition**: `Convolution`, `Sum`, `Transform` compose GSObjects into new ones while preserving the pytree structure. + +Concrete profiles: `Gaussian`, `Exponential`, `Moffat`, `Spergel`, `Box`/`Pixel`, `DeltaFunction`, `InterpolatedImage`. + +### Key Patterns + +**Implementing a new GSObject**: Every profile follows this pattern: +```python +import galsim as _galsim +from jax.tree_util import register_pytree_node_class +from jax_galsim.core.utils import implements + +@implements(_galsim.ProfileName) +@register_pytree_node_class +class ProfileName(GSObject): + def __init__(self, ..., flux=1.0, gsparams=None): + super().__init__(param1=val1, flux=flux, gsparams=gsparams) + + def tree_flatten(self): + children = (self.params,) + aux_data = {"gsparams": self.gsparams} + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + ... +``` + +**`@implements` decorator** (in `core/utils.py`): Copies docstrings from the reference GalSim object and optionally appends a `lax_description` noting JAX-specific differences. Use this on all public classes/functions that mirror GalSim. + +**`from_galsim()`/`to_galsim()` methods**: Many classes provide conversion to/from GalSim equivalents (Position, Bounds, Image, Shear, GSParams, WCS, CelestialCoord). + +### Image System + +`Image` (in `image.py`) wraps `jnp.ndarray` with bounds and WCS metadata. Unlike GalSim, JAX arrays are immutable — all operations return new Image instances instead of modifying in place. + +### Testing Architecture + +Tests come from three sources, configured in `tests/galsim_tests_config.yaml`: +- **`tests/GalSim/tests/`** — GalSim's own test suite (git submodule), run against jax_galsim via `conftest.py` module-swapping hooks +- **`tests/Coord/tests/`** — Coord package test suite (git submodule) +- **`tests/jax/`** — JAX-specific tests (jit, vmap, grad, benchmarks) + +The `conftest.py` replaces `galsim` imports with `jax_galsim` in collected test modules. Tests for unimplemented features auto-pass via the `allowed_failures` list in the YAML config. + +CI splits tests into 4 parallel groups using `pytest-split`. + +## Key Differences from GalSim + +- JAX arrays are **immutable** — no in-place operations on images +- `jax.config.update("jax_enable_x64", True)` is required for 64-bit precision (set in conftest.py) +- Random number generation uses JAX's functional RNG, not GalSim's C++ RNG +- Reference GalSim is always imported as `_galsim` (private) to avoid confusion with the jax_galsim namespace + +## Ruff Configuration + +Selects rules: E, F, I, W. Ignores: C901 (complexity), E203, E501 (line length). `__init__.py` ignores F401 (unused imports) and I001 (import order). Ruff excludes `tests/GalSim/`, `tests/Coord/`, `tests/jax/galsim/`, `dev/notebooks/`. diff --git a/jax_galsim/box.py b/jax_galsim/box.py index 95b3d373..2e5147a1 100644 --- a/jax_galsim/box.py +++ b/jax_galsim/box.py @@ -79,17 +79,23 @@ def _max_sb(self): return self.flux / (self.width * self.height) def _xValue(self, pos): + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): norm = self.flux / (self.width * self.height) return jnp.where( - 2.0 * jnp.abs(pos.x) < self.width, - jnp.where(2.0 * jnp.abs(pos.y) < self.height, norm, 0.0), + 2.0 * jnp.abs(x) < self.width, + jnp.where(2.0 * jnp.abs(y) < self.height, norm, 0.0), 0.0, ) def _kValue(self, kpos): + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): _wo2pi = self.width / (2.0 * jnp.pi) _ho2pi = self.height / (2.0 * jnp.pi) - return self.flux * jnp.sinc(kpos.x * _wo2pi) * jnp.sinc(kpos.y * _ho2pi) + return self.flux * jnp.sinc(kx * _wo2pi) * jnp.sinc(ky * _ho2pi) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): _jac = jnp.eye(2) if jac is None else jac diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 6961ad39..506890bb 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -292,10 +292,13 @@ def _xValue(self, pos): raise NotImplementedError("Real-space convolutions are not implemented") def _kValue(self, kpos): - kv_list = [ - obj.kValue(kpos) for obj in self.obj_list - ] # In GalSim one uses obj.kValue - return jnp.prod(jnp.array(kv_list)) + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + result = self.obj_list[0]._kValue_array(kx, ky) + for obj in self.obj_list[1:]: + result = result * obj._kValue_array(kx, ky) + return result def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): raise NotImplementedError("Real-space convolutions are not implemented") @@ -313,10 +316,23 @@ def _shoot(self, photons, rng): photons.convolve(p1, rng) def _drawKImage(self, image, jac=None): + from jax_galsim import Image + image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image *= obj._drawKImage(image, jac) + # Use a fresh blank image with matching metadata to sever + # the false AD dependency on the galaxy-filled array. + # The PSF's _drawKImage only uses image metadata (bounds, + # scale, wcs), never the array data. + blank = Image(bounds=image.bounds, dtype=image.dtype, scale=image.scale) + obj_kimage = obj._drawKImage(blank, jac) + image = Image( + array=image.array * obj_kimage.array, + bounds=image.bounds, + wcs=image.wcs, + _check_bounds=False, + ) return image def tree_flatten(self): @@ -474,10 +490,13 @@ def _max_sb(self): return -self.orig_obj.max_sb / self.orig_obj.flux**2 def _kValue(self, pos): + return self._kValue_array(pos.x, pos.y) + + def _kValue_array(self, kx, ky): # Really, for very low original kvalues, this gets very high, which can be unstable # in the presence of noise. So if the original value is less than min_acc_kvalue, # we instead just return 1/min_acc_kvalue rather than the real inverse. - kval = self.orig_obj._kValue(pos) + kval = self.orig_obj._kValue_array(kx, ky) return jnp.where( jnp.abs(kval) < self._min_acc_kvalue, self._inv_min_acc_kvalue, diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index a197d151..378d23bf 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -11,7 +11,7 @@ def draw_by_xValue( ): """Utility function to draw a real-space GSObject into an Image.""" # putting the import here to avoid circular imports - from jax_galsim import Image, PositionD + from jax_galsim import Image # Applies flux scaling to compensate for pixel scale # See SBProfile.draw() @@ -29,9 +29,7 @@ def draw_by_xValue( flux_scaling *= jnp.exp(logdet) # Draw the object - im = jax.vmap(lambda *args: gsobject._xValue(PositionD(*args)))( - coords[..., 0], coords[..., 1] - ) + im = gsobject._xValue_array(coords[..., 0], coords[..., 1]) # Apply the flux scaling im = (im * flux_scaling).astype(image.dtype) @@ -42,7 +40,7 @@ def draw_by_xValue( def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): # putting the import here to avoid circular imports - from jax_galsim import Image, PositionD + from jax_galsim import Image # Create an array of coordinates coords = jnp.stack(image.get_pixel_centers(), axis=-1) @@ -50,9 +48,7 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): coords = jnp.dot(coords, jacobian) # Draw the object - im = jax.vmap(lambda *args: gsobject._kValue(PositionD(*args)))( - coords[..., 0], coords[..., 1] - ) + im = gsobject._kValue_array(coords[..., 0], coords[..., 1]) im = (im).astype(image.dtype) # Return an image @@ -61,23 +57,18 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)): # putting the import here to avoid circular imports - from jax_galsim import Image, PositionD + from jax_galsim import Image # Create an array of coordinates kcoords = jnp.stack(image.get_pixel_centers(), axis=-1) kcoords = kcoords * image.scale # Scale by the image pixel scale kcoords = jnp.dot(kcoords, jacobian) - cenx, ceny = offset.x, offset.y - # flux Exp(-i (kx cx + kxy cx + kyx cy + ky cy ) ) - # NB: seems that tere is no jax.lax.polar equivalent to c++ std::polar function - def phase(kpos): - arg = -(kpos.x * cenx + kpos.y * ceny) - return jnp.cos(arg) + 1j * jnp.sin(arg) + # flux Exp(-i (kx cx + ky cy) ) + kx, ky = kcoords[..., 0], kcoords[..., 1] + arg = -(kx * offset.x + ky * offset.y) + im_phase = jnp.cos(arg) + 1j * jnp.sin(arg) - im_phase = jax.vmap(lambda *args: phase(PositionD(*args)))( - kcoords[..., 0], kcoords[..., 1] - ) return Image( array=image.array * im_phase, bounds=image.bounds, diff --git a/jax_galsim/deltafunction.py b/jax_galsim/deltafunction.py index 71043e96..f32d3e10 100644 --- a/jax_galsim/deltafunction.py +++ b/jax_galsim/deltafunction.py @@ -1,5 +1,4 @@ import galsim as _galsim -import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class @@ -54,16 +53,16 @@ def _max_sb(self): return DeltaFunction._mock_inf def _xValue(self, pos): - return jax.lax.cond( - jnp.array(pos.x == 0.0, dtype=bool) & jnp.array(pos.y == 0.0, dtype=bool), - lambda *a: DeltaFunction._mock_inf, - lambda *a: 0.0, - ) + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + return jnp.where((x == 0.0) & (y == 0.0), DeltaFunction._mock_inf, 0.0) def _kValue(self, kpos): - # this is a wasteful and fancy way to get the shape to broadcast to - # to match the input kpos - return self.flux + kpos.x * (0.0 + 0.0j) + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + return self.flux + kx * (0.0 + 0.0j) @implements(_galsim.DeltaFunction._shoot) def _shoot(self, photons, rng): diff --git a/jax_galsim/exponential.py b/jax_galsim/exponential.py index 3d556d59..48e7631d 100644 --- a/jax_galsim/exponential.py +++ b/jax_galsim/exponential.py @@ -126,11 +126,17 @@ def _max_sb(self): return self._norm def _xValue(self, pos): - r = jnp.sqrt(pos.x**2 + pos.y**2) + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + r = jnp.sqrt(x**2 + y**2) return self._norm * jnp.exp(-r * self._inv_r0) def _kValue(self, kpos): - ksqp1 = (kpos.x**2 + kpos.y**2) * self._r0**2 + 1.0 + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + ksqp1 = (kx**2 + ky**2) * self._r0**2 + 1.0 return self.flux / (ksqp1 * jnp.sqrt(ksqp1)) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/jax_galsim/gaussian.py b/jax_galsim/gaussian.py index 3b937a11..08e887fd 100644 --- a/jax_galsim/gaussian.py +++ b/jax_galsim/gaussian.py @@ -128,11 +128,17 @@ def _max_sb(self): return self._norm def _xValue(self, pos): - rsq = pos.x**2 + pos.y**2 + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + rsq = x**2 + y**2 return self._norm * jnp.exp(-0.5 * rsq * self._inv_sigsq) def _kValue(self, kpos): - ksq = (kpos.x**2 + kpos.y**2) * self._sigsq + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + ksq = (kx**2 + ky**2) * self._sigsq return self.flux * jnp.exp(-0.5 * ksq) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 296173ff..cf014646 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -224,6 +224,32 @@ def _kValue(self, kpos): "%s does not implement kValue" % self.__class__.__name__ ) + def _kValue_array(self, kx, ky): + """Evaluate the k-space profile at arrays of kx, ky coordinates. + + By default falls back to vmap over _kValue with PositionD. Concrete + profiles should override this with direct array operations for better + performance (avoids per-pixel PositionD construction). + """ + out_shape = jnp.shape(kx) + kx = jnp.atleast_1d(jnp.asarray(kx)).ravel() + ky = jnp.atleast_1d(jnp.asarray(ky)).ravel() + result = jax.vmap(lambda x, y: self._kValue(PositionD(x, y)))(kx, ky) + return result.reshape(out_shape) + + def _xValue_array(self, x, y): + """Evaluate the real-space profile at arrays of x, y coordinates. + + By default falls back to vmap over _xValue with PositionD. Concrete + profiles should override this with direct array operations for better + performance (avoids per-pixel PositionD construction). + """ + out_shape = jnp.shape(x) + x = jnp.atleast_1d(jnp.asarray(x)).ravel() + y = jnp.atleast_1d(jnp.asarray(y)).ravel() + result = jax.vmap(lambda xx, yy: self._xValue(PositionD(xx, yy)))(x, y) + return result.reshape(out_shape) + @implements(_galsim.GSObject.withGSParams) def withGSParams(self, gsparams=None, **kwargs): if gsparams == self.gsparams: diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 1b54f80b..0d082ab2 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -763,6 +763,18 @@ def _xValue(self, pos): self._x_interpolant, )[0] + def _xValue_array(self, x, y): + return _xValue_arr( + x, + y, + self._offset.x, + self._offset.y, + self._pad_image.bounds.xmin, + self._pad_image.bounds.ymin, + self._pad_image.array, + self._x_interpolant, + ) + def _kValue(self, kpos): kx = jnp.array([kpos.x], dtype=float) ky = jnp.array([kpos.y], dtype=float) @@ -779,6 +791,20 @@ def _kValue(self, kpos): self._k_interpolant, )[0] + def _kValue_array(self, kx, ky): + return _kValue_arr( + kx, + ky, + self._offset.x, + self._offset.y, + self._kim.bounds.xmin, + self._kim.bounds.ymin, + self._kim.array, + self._kim.scale, + self._x_interpolant, + self._k_interpolant, + ) + def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): jacobian = jnp.eye(2) if jac is None else jac diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index fe596398..54a361cf 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -330,8 +330,10 @@ def _max_sb(self): @jax.jit def _xValue(self, pos): - rsq = (pos.x**2 + pos.y**2) * self._inv_r0_sq - # trunc if r>maxR with r0 scaled version + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + rsq = (x**2 + y**2) * self._inv_r0_sq return jnp.where( rsq > self._maxRrD_sq, 0.0, self._norm * jnp.power(1.0 + rsq, -self.beta) ) @@ -357,9 +359,13 @@ def _kValue(self, kpos): """computation of the Moffat response in k-space with switch of truncated/untracated case kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image) """ - k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq) + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + k = jnp.sqrt((kx**2 + ky**2) * self._r0_sq) out_shape = jnp.shape(k) - k = jnp.atleast_1d(k) + # Flatten to 1D for _hankel's vmap which only maps over axis 0 + k = jnp.atleast_1d(k).ravel() res = jax.lax.cond( self.trunc > 0, lambda x: self._kValue_trunc(x), diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index b0bf22a8..8862bccc 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -410,13 +410,19 @@ def _max_sb(self): @jax.jit def _xValue(self, pos): - r = jnp.sqrt(pos.x**2 + pos.y**2) * self._inv_r0 + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + r = jnp.sqrt(x**2 + y**2) * self._inv_r0 res = jnp.where(r == 0, self._xnorm0, fz_nu(r, self.nu)) return self._xnorm * res @jax.jit def _kValue(self, kpos): - ksq = (kpos.x**2 + kpos.y**2) * self._r0_sq + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + ksq = (kx**2 + ky**2) * self._r0_sq return self.flux * jnp.power(1.0 + ksq, -1.0 - self.nu) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 958e6bfa..f56d24b9 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -153,12 +153,22 @@ def _max_sb(self): return jnp.sum(sb_list) def _xValue(self, pos): - xv_list = jnp.array([obj._xValue(pos) for obj in self.obj_list]) - return jnp.sum(xv_list, axis=0) + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + result = self.obj_list[0]._xValue_array(x, y) + for obj in self.obj_list[1:]: + result = result + obj._xValue_array(x, y) + return result def _kValue(self, pos): - kv_list = jnp.array([obj._kValue(pos) for obj in self.obj_list]) - return jnp.sum(kv_list, axis=0) + return self._kValue_array(pos.x, pos.y) + + def _kValue_array(self, kx, ky): + result = self.obj_list[0]._kValue_array(kx, ky) + for obj in self.obj_list[1:]: + result = result + obj._kValue_array(kx, ky) + return result def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): image = self.obj_list[0]._drawReal(image, jac, offset, flux_scaling) @@ -168,10 +178,21 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): return image def _drawKImage(self, image, jac=None): + from jax_galsim import Image + image = self.obj_list[0]._drawKImage(image, jac) if len(self.obj_list) > 1: for obj in self.obj_list[1:]: - image += obj._drawKImage(image, jac) + # Use a fresh blank image with matching metadata to sever + # the false AD dependency on the previously-filled array. + blank = Image(bounds=image.bounds, dtype=image.dtype, scale=image.scale) + obj_kimage = obj._drawKImage(blank, jac) + image = Image( + array=image.array + obj_kimage.array, + bounds=image.bounds, + wcs=image.wcs, + _check_bounds=False, + ) return image @property diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index d7d23838..2f18fdf8 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -337,13 +337,41 @@ def _max_sb(self): return self._amp_scaling * self._original.max_sb def _xValue(self, pos): - pos -= self._offset - inv_pos = PositionD(self._inv(pos.x, pos.y)) - return self._original._xValue(inv_pos) * self._amp_scaling + return self._xValue_array(pos.x, pos.y) + + def _xValue_array(self, x, y): + x = x - self._offset.x + y = y - self._offset.y + inv_x, inv_y = self._inv_array(x, y) + return self._original._xValue_array(inv_x, inv_y) * self._amp_scaling def _kValue(self, kpos): - fwdT_kpos = PositionD(self._fwdT(kpos.x, kpos.y)) - return self._original._kValue(fwdT_kpos) * self._kfactor(kpos.x, kpos.y) + return self._kValue_array(kpos.x, kpos.y) + + def _kValue_array(self, kx, ky): + fwdT_kx, fwdT_ky = self._fwdT_array(kx, ky) + return self._original._kValue_array(fwdT_kx, fwdT_ky) * self._kfactor_array( + kx, ky + ) + + def _fwdT_array(self, x, y): + """Like _fwdT but works on arrays of any shape.""" + m = self._jac.T + rx = m[0, 0] * x + m[0, 1] * y + ry = m[1, 0] * x + m[1, 1] * y + return rx, ry + + def _inv_array(self, x, y): + """Like _inv but works on arrays of any shape.""" + m = self._invjac + rx = m[0, 0] * x + m[0, 1] * y + ry = m[1, 0] * x + m[1, 1] * y + return rx, ry + + def _kfactor_array(self, kx, ky): + """Like _kfactor but works on arrays of any shape without mutating locals.""" + arg = -1j * self._offset.x * kx + (-1j * self._offset.y) * ky + return self._flux_scaling * jnp.exp(arg) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): dx, dy = offset diff --git a/tests/jax/test_convolution_grad.py b/tests/jax/test_convolution_grad.py new file mode 100644 index 00000000..8b4e22b9 --- /dev/null +++ b/tests/jax/test_convolution_grad.py @@ -0,0 +1,103 @@ +"""Test that gradients through convolution rendering are correct. + +This verifies that the structural optimization (breaking false AD dependency +in Convolution._drawKImage) produces correct gradients by comparing against +finite differences. +""" + +import jax +import jax.numpy as jnp +import numpy as np + +import jax_galsim as galsim + +jax.config.update("jax_enable_x64", True) + +gsparams = galsim.GSParams(minimum_fft_size=128, maximum_fft_size=128) + + +def _draw_and_sum(half_light_radius, flux): + """Draw a convolution of Exponential * Moffat and return the sum of pixel values.""" + gal = galsim.Exponential( + half_light_radius=half_light_radius, flux=flux, gsparams=gsparams + ) + psf = galsim.Moffat(beta=3.5, fwhm=0.7, gsparams=gsparams) + obj = galsim.Convolve(gal, psf, gsparams=gsparams) + image = obj.drawImage(nx=64, ny=64, scale=0.2, dtype=float) + return jnp.sum(image.array) + + +def test_convolution_grad_vs_finite_diff(): + """Test that jax.grad through Convolve(Exponential, Moffat).drawImage() + matches finite-difference approximation.""" + hlr = 1.0 + flux = 100.0 + eps = 1e-5 + + grad_fn = jax.grad(_draw_and_sum, argnums=(0, 1)) + grad_hlr, grad_flux = grad_fn(jnp.float64(hlr), jnp.float64(flux)) + + # Finite-difference for half_light_radius + f_plus = _draw_and_sum(hlr + eps, flux) + f_minus = _draw_and_sum(hlr - eps, flux) + fd_grad_hlr = (f_plus - f_minus) / (2 * eps) + + # Finite-difference for flux + f_plus = _draw_and_sum(hlr, flux + eps) + f_minus = _draw_and_sum(hlr, flux - eps) + fd_grad_flux = (f_plus - f_minus) / (2 * eps) + + np.testing.assert_allclose( + grad_hlr, + fd_grad_hlr, + rtol=1e-3, + atol=0, + err_msg="Gradient w.r.t. half_light_radius is incorrect", + ) + np.testing.assert_allclose( + grad_flux, + fd_grad_flux, + rtol=1e-3, + atol=0, + err_msg="Gradient w.r.t. flux is incorrect", + ) + + +def test_sum_grad_vs_finite_diff(): + """Test that jax.grad through Sum.drawImage() is correct.""" + + def _draw_sum(flux1, flux2): + _gsparams = galsim.GSParams(minimum_fft_size=128, maximum_fft_size=128) + g1 = galsim.Gaussian(sigma=1.5, flux=flux1, gsparams=_gsparams) + g2 = galsim.Gaussian(sigma=2.0, flux=flux2, gsparams=_gsparams) + obj = galsim.Add(g1, g2) + image = obj.drawImage(nx=64, ny=64, scale=0.2, method="no_pixel", dtype=float) + return jnp.sum(image.array) + + flux1, flux2 = jnp.float64(50.0), jnp.float64(80.0) + eps = 1e-5 + + grad_fn = jax.grad(_draw_sum, argnums=(0, 1)) + grad_f1, grad_f2 = grad_fn(flux1, flux2) + + fd_grad_f1 = (_draw_sum(flux1 + eps, flux2) - _draw_sum(flux1 - eps, flux2)) / ( + 2 * eps + ) + fd_grad_f2 = (_draw_sum(flux1, flux2 + eps) - _draw_sum(flux1, flux2 - eps)) / ( + 2 * eps + ) + + np.testing.assert_allclose( + grad_f1, + fd_grad_f1, + rtol=1e-3, + atol=0, + err_msg="Sum gradient w.r.t. flux1 is incorrect", + ) + np.testing.assert_allclose( + grad_f2, + fd_grad_f2, + rtol=1e-3, + atol=0, + err_msg="Sum gradient w.r.t. flux2 is incorrect", + ) From 3bbefeac2ab056fba141b8052958a22d2d3e4add Mon Sep 17 00:00:00 2001 From: EiffL Date: Sun, 8 Feb 2026 18:22:44 +0100 Subject: [PATCH 2/3] minor fixes --- jax_galsim/core/draw.py | 2 +- jax_galsim/gsobject.py | 8 ++++---- jax_galsim/moffat.py | 5 ----- jax_galsim/spergel.py | 2 -- jax_galsim/transform.py | 8 ++++---- 5 files changed, 9 insertions(+), 16 deletions(-) diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index 378d23bf..f8d4d79a 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -49,7 +49,7 @@ def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)): # Draw the object im = gsobject._kValue_array(coords[..., 0], coords[..., 1]) - im = (im).astype(image.dtype) + im = im.astype(image.dtype) # Return an image return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index cf014646..c1fdcfbb 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -232,8 +232,8 @@ def _kValue_array(self, kx, ky): performance (avoids per-pixel PositionD construction). """ out_shape = jnp.shape(kx) - kx = jnp.atleast_1d(jnp.asarray(kx)).ravel() - ky = jnp.atleast_1d(jnp.asarray(ky)).ravel() + kx = jnp.atleast_1d(kx).ravel() + ky = jnp.atleast_1d(ky).ravel() result = jax.vmap(lambda x, y: self._kValue(PositionD(x, y)))(kx, ky) return result.reshape(out_shape) @@ -245,8 +245,8 @@ def _xValue_array(self, x, y): performance (avoids per-pixel PositionD construction). """ out_shape = jnp.shape(x) - x = jnp.atleast_1d(jnp.asarray(x)).ravel() - y = jnp.atleast_1d(jnp.asarray(y)).ravel() + x = jnp.atleast_1d(x).ravel() + y = jnp.atleast_1d(y).ravel() result = jax.vmap(lambda xx, yy: self._xValue(PositionD(xx, yy)))(x, y) return result.reshape(out_shape) diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 54a361cf..f6fab369 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -328,7 +328,6 @@ def _has_hard_edges(self): def _max_sb(self): return self._norm - @jax.jit def _xValue(self, pos): return self._xValue_array(pos.x, pos.y) @@ -354,11 +353,7 @@ def _kValue_trunc(self, k): 0.0, ) - @jax.jit def _kValue(self, kpos): - """computation of the Moffat response in k-space with switch of truncated/untracated case - kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image) - """ return self._kValue_array(kpos.x, kpos.y) def _kValue_array(self, kx, ky): diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index 8862bccc..e9d021f4 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -408,7 +408,6 @@ def _max_sb(self): # from SBSpergelImpl.h return jnp.abs(self._xnorm) * self._xnorm0 - @jax.jit def _xValue(self, pos): return self._xValue_array(pos.x, pos.y) @@ -417,7 +416,6 @@ def _xValue_array(self, x, y): res = jnp.where(r == 0, self._xnorm0, fz_nu(r, self.nu)) return self._xnorm * res - @jax.jit def _kValue(self, kpos): return self._kValue_array(kpos.x, kpos.y) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 2f18fdf8..a4a0aece 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -355,22 +355,22 @@ def _kValue_array(self, kx, ky): ) def _fwdT_array(self, x, y): - """Like _fwdT but works on arrays of any shape.""" + """Apply the transposed Jacobian without constructing intermediate arrays.""" m = self._jac.T rx = m[0, 0] * x + m[0, 1] * y ry = m[1, 0] * x + m[1, 1] * y return rx, ry def _inv_array(self, x, y): - """Like _inv but works on arrays of any shape.""" + """Apply the inverse Jacobian without constructing intermediate arrays.""" m = self._invjac rx = m[0, 0] * x + m[0, 1] * y ry = m[1, 0] * x + m[1, 1] * y return rx, ry def _kfactor_array(self, kx, ky): - """Like _kfactor but works on arrays of any shape without mutating locals.""" - arg = -1j * self._offset.x * kx + (-1j * self._offset.y) * ky + """Compute the k-space phase factor on arrays without mutating locals.""" + arg = -1j * (self._offset.x * kx + self._offset.y * ky) return self._flux_scaling * jnp.exp(arg) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): From 14364ef1921e55ec0fffa2eecddc8207a510d0ea Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sun, 8 Feb 2026 18:28:15 +0100 Subject: [PATCH 3/3] Delete CLAUDE.md --- CLAUDE.md | 99 ------------------------------------------------------- 1 file changed, 99 deletions(-) delete mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index ab9e5c6e..00000000 --- a/CLAUDE.md +++ /dev/null @@ -1,99 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## What is JAX-GalSim? - -JAX-GalSim is a pure-JAX reimplementation of [GalSim](https://github.com/GalSim-developers/GalSim), the modular galaxy image simulation toolkit. It provides a near-identical API to GalSim but with full support for JAX transformations (`jit`, `vmap`, `grad`), enabling differentiable and GPU-accelerated galaxy simulations. Currently ~22.5% of the GalSim API is implemented. - -## Build and Development Commands - -```bash -# Setup (requires --recurse-submodules for test submodules) -git clone --recurse-submodules -pip install -e ".[dev]" -pre-commit install - -# Run all tests -pytest - -# Run a single test file -pytest tests/jax/test_jitting.py - -# Run a specific test -pytest tests/jax/test_jitting.py::test_name -v - -# Lint and format -ruff check . --fix -ruff format . - -# Or via pre-commit -pre-commit run --all-files -``` - -## Architecture - -### GSObject System - -The core abstraction is `GSObject` (in `gsobject.py`), the base class for all galaxy/PSF profiles. Key design: - -- **Parameters**: Traced (differentiable) values stored in `self._params` dict. Static config in `self._gsparams`. -- **Pytree protocol**: Every GSObject subclass is decorated with `@register_pytree_node_class` and implements `tree_flatten()`/`tree_unflatten()` so JAX can trace through them. -- **Drawing**: Profiles implement `_xValue(pos)` for real-space and `_kValue(kpos)` for Fourier-space evaluation. Drawing is dispatched through `core/draw.py` which uses `jax.vmap` over pixel centers. -- **Composition**: `Convolution`, `Sum`, `Transform` compose GSObjects into new ones while preserving the pytree structure. - -Concrete profiles: `Gaussian`, `Exponential`, `Moffat`, `Spergel`, `Box`/`Pixel`, `DeltaFunction`, `InterpolatedImage`. - -### Key Patterns - -**Implementing a new GSObject**: Every profile follows this pattern: -```python -import galsim as _galsim -from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import implements - -@implements(_galsim.ProfileName) -@register_pytree_node_class -class ProfileName(GSObject): - def __init__(self, ..., flux=1.0, gsparams=None): - super().__init__(param1=val1, flux=flux, gsparams=gsparams) - - def tree_flatten(self): - children = (self.params,) - aux_data = {"gsparams": self.gsparams} - return children, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, children): - ... -``` - -**`@implements` decorator** (in `core/utils.py`): Copies docstrings from the reference GalSim object and optionally appends a `lax_description` noting JAX-specific differences. Use this on all public classes/functions that mirror GalSim. - -**`from_galsim()`/`to_galsim()` methods**: Many classes provide conversion to/from GalSim equivalents (Position, Bounds, Image, Shear, GSParams, WCS, CelestialCoord). - -### Image System - -`Image` (in `image.py`) wraps `jnp.ndarray` with bounds and WCS metadata. Unlike GalSim, JAX arrays are immutable — all operations return new Image instances instead of modifying in place. - -### Testing Architecture - -Tests come from three sources, configured in `tests/galsim_tests_config.yaml`: -- **`tests/GalSim/tests/`** — GalSim's own test suite (git submodule), run against jax_galsim via `conftest.py` module-swapping hooks -- **`tests/Coord/tests/`** — Coord package test suite (git submodule) -- **`tests/jax/`** — JAX-specific tests (jit, vmap, grad, benchmarks) - -The `conftest.py` replaces `galsim` imports with `jax_galsim` in collected test modules. Tests for unimplemented features auto-pass via the `allowed_failures` list in the YAML config. - -CI splits tests into 4 parallel groups using `pytest-split`. - -## Key Differences from GalSim - -- JAX arrays are **immutable** — no in-place operations on images -- `jax.config.update("jax_enable_x64", True)` is required for 64-bit precision (set in conftest.py) -- Random number generation uses JAX's functional RNG, not GalSim's C++ RNG -- Reference GalSim is always imported as `_galsim` (private) to avoid confusion with the jax_galsim namespace - -## Ruff Configuration - -Selects rules: E, F, I, W. Ignores: C901 (complexity), E203, E501 (line length). `__init__.py` ignores F401 (unused imports) and I001 (import order). Ruff excludes `tests/GalSim/`, `tests/Coord/`, `tests/jax/galsim/`, `dev/notebooks/`.