From c1109e7c41a3ce677ec6e33a1612c127f8408476 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Fri, 20 Feb 2026 17:52:52 +0100 Subject: [PATCH 01/10] docs: add Material for MkDocs documentation site Set up full documentation infrastructure with auto-generated API reference from docstrings, narrative guides, and GitHub Pages deployment via CI. - Add mkdocs-material and mkdocstrings[python] as docs dependencies - Create mkdocs.yml with Material theme, left sidebar nav, Mermaid, MathJax - Add narrative pages: home, getting started, architecture, notable differences - Add API reference stubs for all public modules targeting specific classes - Add GitHub Actions workflow for build on PR / deploy on push to main Co-Authored-By: Claude Opus 4.6 --- .github/workflows/docs.yml | 35 ++++++ docs/api-coverage.md | 131 +++++++++++++++++++ docs/api/composition/convolve.md | 7 ++ docs/api/composition/sum.md | 5 + docs/api/composition/transform.md | 5 + docs/api/config/errors.md | 5 + docs/api/config/gsparams.md | 5 + docs/api/config/utilities.md | 5 + docs/api/coordinates/angle.md | 7 ++ docs/api/coordinates/bounds.md | 7 ++ docs/api/coordinates/celestial.md | 5 + docs/api/coordinates/position.md | 7 ++ docs/api/coordinates/shear.md | 5 + docs/api/core/draw.md | 5 + docs/api/core/interpolate.md | 5 + docs/api/core/math.md | 5 + docs/api/core/utils.md | 5 + docs/api/image.md | 5 + docs/api/interpolation/interpolant.md | 17 +++ docs/api/interpolation/interpolatedimage.md | 5 + docs/api/math/bessel.md | 5 + docs/api/math/integ.md | 5 + docs/api/noise/noise.md | 15 +++ docs/api/noise/random.md | 19 +++ docs/api/photons/photon_array.md | 5 + docs/api/photons/sensor.md | 5 + docs/api/profiles/box.md | 7 ++ docs/api/profiles/deltafunction.md | 5 + docs/api/profiles/exponential.md | 5 + docs/api/profiles/gaussian.md | 5 + docs/api/profiles/gsobject.md | 5 + docs/api/profiles/moffat.md | 5 + docs/api/profiles/spergel.md | 5 + docs/api/wcs/fits.md | 5 + docs/api/wcs/fitswcs.md | 5 + docs/api/wcs/wcs.md | 17 +++ docs/architecture/drawing.md | 79 ++++++++++++ docs/architecture/gsobject.md | 91 ++++++++++++++ docs/architecture/implements.md | 60 +++++++++ docs/architecture/index.md | 112 +++++++++++++++++ docs/architecture/pytree.md | 77 ++++++++++++ docs/getting-started/index.md | 7 ++ docs/getting-started/installation.md | 69 ++++++++++ docs/getting-started/key-concepts.md | 96 ++++++++++++++ docs/getting-started/quickstart.md | 103 +++++++++++++++ docs/index.md | 80 ++++++++++++ docs/javascripts/mathjax.js | 16 +++ docs/notable-differences.md | 73 +++++++++++ mkdocs.yml | 132 ++++++++++++++++++++ pyproject.toml | 5 + 50 files changed, 1394 insertions(+) create mode 100644 .github/workflows/docs.yml create mode 100644 docs/api-coverage.md create mode 100644 docs/api/composition/convolve.md create mode 100644 docs/api/composition/sum.md create mode 100644 docs/api/composition/transform.md create mode 100644 docs/api/config/errors.md create mode 100644 docs/api/config/gsparams.md create mode 100644 docs/api/config/utilities.md create mode 100644 docs/api/coordinates/angle.md create mode 100644 docs/api/coordinates/bounds.md create mode 100644 docs/api/coordinates/celestial.md create mode 100644 docs/api/coordinates/position.md create mode 100644 docs/api/coordinates/shear.md create mode 100644 docs/api/core/draw.md create mode 100644 docs/api/core/interpolate.md create mode 100644 docs/api/core/math.md create mode 100644 docs/api/core/utils.md create mode 100644 docs/api/image.md create mode 100644 docs/api/interpolation/interpolant.md create mode 100644 docs/api/interpolation/interpolatedimage.md create mode 100644 docs/api/math/bessel.md create mode 100644 docs/api/math/integ.md create mode 100644 docs/api/noise/noise.md create mode 100644 docs/api/noise/random.md create mode 100644 docs/api/photons/photon_array.md create mode 100644 docs/api/photons/sensor.md create mode 100644 docs/api/profiles/box.md create mode 100644 docs/api/profiles/deltafunction.md create mode 100644 docs/api/profiles/exponential.md create mode 100644 docs/api/profiles/gaussian.md create mode 100644 docs/api/profiles/gsobject.md create mode 100644 docs/api/profiles/moffat.md create mode 100644 docs/api/profiles/spergel.md create mode 100644 docs/api/wcs/fits.md create mode 100644 docs/api/wcs/fitswcs.md create mode 100644 docs/api/wcs/wcs.md create mode 100644 docs/architecture/drawing.md create mode 100644 docs/architecture/gsobject.md create mode 100644 docs/architecture/implements.md create mode 100644 docs/architecture/index.md create mode 100644 docs/architecture/pytree.md create mode 100644 docs/getting-started/index.md create mode 100644 docs/getting-started/installation.md create mode 100644 docs/getting-started/key-concepts.md create mode 100644 docs/getting-started/quickstart.md create mode 100644 docs/index.md create mode 100644 docs/javascripts/mathjax.js create mode 100644 docs/notable-differences.md create mode 100644 mkdocs.yml diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 00000000..0162acc1 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,35 @@ +name: Documentation + +on: + push: + branches: [main] + pull_request: + branches: [main] + +permissions: + contents: write + +concurrency: + group: docs-${{ github.ref }} + cancel-in-progress: true + +jobs: + docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Install dependencies + run: pip install -e ".[docs]" + + - name: Build documentation + if: github.event_name == 'pull_request' + run: mkdocs build --strict + + - name: Deploy documentation + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: mkdocs gh-deploy --force diff --git a/docs/api-coverage.md b/docs/api-coverage.md new file mode 100644 index 00000000..48ab74a2 --- /dev/null +++ b/docs/api-coverage.md @@ -0,0 +1,131 @@ +# API Coverage + +JAX-GalSim has implemented **22.5%** of the GalSim API. The project focuses on +the most commonly used profiles and operations, with coverage expanding over time. + +## Supported APIs + +??? note "Click to expand the full list of implemented APIs" + + - `galsim.Add` + - `galsim.AffineTransform` + - `galsim.Angle` + - `galsim.AngleUnit` + - `galsim.BaseDeviate` + - `galsim.BaseNoise` + - `galsim.BaseWCS` + - `galsim.BinomialDeviate` + - `galsim.Bounds` + - `galsim.BoundsD` + - `galsim.BoundsI` + - `galsim.Box` + - `galsim.CCDNoise` + - `galsim.CelestialCoord` + - `galsim.Chi2Deviate` + - `galsim.Convolution` + - `galsim.Convolve` + - `galsim.Cubic` + - `galsim.Deconvolution` + - `galsim.Deconvolve` + - `galsim.Delta` + - `galsim.DeltaFunction` + - `galsim.DeviateNoise` + - `galsim.Exponential` + - `galsim.FitsHeader` + - `galsim.FitsWCS` + - `galsim.GSFitsWCS` + - `galsim.GSObject` + - `galsim.GSParams` + - `galsim.GalSimBoundsError` + - `galsim.GalSimConfigError` + - `galsim.GalSimConfigValueError` + - `galsim.GalSimDeprecationWarning` + - `galsim.GalSimError` + - `galsim.GalSimFFTSizeError` + - `galsim.GalSimHSMError` + - `galsim.GalSimImmutableError` + - `galsim.GalSimIncompatibleValuesError` + - `galsim.GalSimIndexError` + - `galsim.GalSimKeyError` + - `galsim.GalSimNotImplementedError` + - `galsim.GalSimRangeError` + - `galsim.GalSimSEDError` + - `galsim.GalSimUndefinedBoundsError` + - `galsim.GalSimValueError` + - `galsim.GalSimWarning` + - `galsim.GammaDeviate` + - `galsim.Gaussian` + - `galsim.GaussianDeviate` + - `galsim.GaussianNoise` + - `galsim.Image` + - `galsim.ImageCD` + - `galsim.ImageCF` + - `galsim.ImageD` + - `galsim.ImageF` + - `galsim.ImageI` + - `galsim.ImageS` + - `galsim.ImageUI` + - `galsim.ImageUS` + - `galsim.Interpolant` + - `galsim.InterpolatedImage` + - `galsim.JacobianWCS` + - `galsim.Lanczos` + - `galsim.Linear` + - `galsim.Moffat` + - `galsim.Nearest` + - `galsim.OffsetShearWCS` + - `galsim.OffsetWCS` + - `galsim.PhotonArray` + - `galsim.Pixel` + - `galsim.PixelScale` + - `galsim.PoissonDeviate` + - `galsim.PoissonNoise` + - `galsim.Position` + - `galsim.PositionD` + - `galsim.PositionI` + - `galsim.Quintic` + - `galsim.Sensor` + - `galsim.Shear` + - `galsim.ShearWCS` + - `galsim.SincInterpolant` + - `galsim.Spergel` + - `galsim.Sum` + - `galsim.TanWCS` + - `galsim.Transform` + - `galsim.Transformation` + - `galsim.UniformDeviate` + - `galsim.VariableGaussianNoise` + - `galsim.WeibullDeviate` + - `galsim.bessel.j0` + - `galsim.bessel.kv` + - `galsim.bessel.si` + - `galsim.fits.closeHDUList` + - `galsim.fits.readCube` + - `galsim.fits.readFile` + - `galsim.fits.readMulti` + - `galsim.fits.write` + - `galsim.fits.writeFile` + - `galsim.fitswcs.CelestialWCS` + - `galsim.integ.int1d` + - `galsim.noise.addNoise` + - `galsim.noise.addNoiseSNR` + - `galsim.random.permute` + - `galsim.utilities.g1g2_to_e1e2` + - `galsim.utilities.horner` + - `galsim.utilities.printoptions` + - `galsim.utilities.unweighted_moments` + - `galsim.utilities.unweighted_shape` + - `galsim.wcs.EuclideanWCS` + - `galsim.wcs.LocalWCS` + - `galsim.wcs.UniformWCS` + +## Updating Coverage + +The coverage list is generated automatically: + +```bash +python scripts/update_api_coverage.py +``` + +This script compares GalSim's public API against `jax_galsim`'s implementations +and updates the coverage percentage and list in `README.md`. diff --git a/docs/api/composition/convolve.md b/docs/api/composition/convolve.md new file mode 100644 index 00000000..edfe1c3e --- /dev/null +++ b/docs/api/composition/convolve.md @@ -0,0 +1,7 @@ +# Convolution & Deconvolution + +Convolve profiles together (e.g., galaxy with PSF) or deconvolve. + +::: jax_galsim.convolve.Convolution + +::: jax_galsim.convolve.Deconvolution diff --git a/docs/api/composition/sum.md b/docs/api/composition/sum.md new file mode 100644 index 00000000..5106d845 --- /dev/null +++ b/docs/api/composition/sum.md @@ -0,0 +1,5 @@ +# Sum (Add) + +Add surface brightness profiles together. + +::: jax_galsim.sum.Sum diff --git a/docs/api/composition/transform.md b/docs/api/composition/transform.md new file mode 100644 index 00000000..51b89e14 --- /dev/null +++ b/docs/api/composition/transform.md @@ -0,0 +1,5 @@ +# Transform & Transformation + +Affine transformations of surface brightness profiles (shear, shift, rotation, flux scaling). + +::: jax_galsim.transform.Transformation diff --git a/docs/api/config/errors.md b/docs/api/config/errors.md new file mode 100644 index 00000000..3105921d --- /dev/null +++ b/docs/api/config/errors.md @@ -0,0 +1,5 @@ +# Errors & Warnings + +Exception and warning classes for JAX-GalSim error handling. + +::: jax_galsim.errors diff --git a/docs/api/config/gsparams.md b/docs/api/config/gsparams.md new file mode 100644 index 00000000..8ea867b5 --- /dev/null +++ b/docs/api/config/gsparams.md @@ -0,0 +1,5 @@ +# GSParams + +Numerical configuration parameters controlling accuracy and performance trade-offs. + +::: jax_galsim.gsparams.GSParams diff --git a/docs/api/config/utilities.md b/docs/api/config/utilities.md new file mode 100644 index 00000000..009851cf --- /dev/null +++ b/docs/api/config/utilities.md @@ -0,0 +1,5 @@ +# Utilities + +General utility functions. + +::: jax_galsim.utilities diff --git a/docs/api/coordinates/angle.md b/docs/api/coordinates/angle.md new file mode 100644 index 00000000..b49d6e74 --- /dev/null +++ b/docs/api/coordinates/angle.md @@ -0,0 +1,7 @@ +# Angle & AngleUnit + +Angle representation and unit conversion (radians, degrees, arcminutes, arcseconds, hours). + +::: jax_galsim.angle.Angle + +::: jax_galsim.angle.AngleUnit diff --git a/docs/api/coordinates/bounds.md b/docs/api/coordinates/bounds.md new file mode 100644 index 00000000..8bd03672 --- /dev/null +++ b/docs/api/coordinates/bounds.md @@ -0,0 +1,7 @@ +# Bounds + +Rectangular bounding box types for real-valued (`BoundsD`) and integer (`BoundsI`) coordinates. + +::: jax_galsim.bounds.BoundsD + +::: jax_galsim.bounds.BoundsI diff --git a/docs/api/coordinates/celestial.md b/docs/api/coordinates/celestial.md new file mode 100644 index 00000000..610ca0ad --- /dev/null +++ b/docs/api/coordinates/celestial.md @@ -0,0 +1,5 @@ +# CelestialCoord + +Celestial coordinate (RA, Dec) representation and operations. + +::: jax_galsim.celestial.CelestialCoord diff --git a/docs/api/coordinates/position.md b/docs/api/coordinates/position.md new file mode 100644 index 00000000..28510902 --- /dev/null +++ b/docs/api/coordinates/position.md @@ -0,0 +1,7 @@ +# Position + +2D position types for real-valued (`PositionD`) and integer (`PositionI`) coordinates. + +::: jax_galsim.position.PositionD + +::: jax_galsim.position.PositionI diff --git a/docs/api/coordinates/shear.md b/docs/api/coordinates/shear.md new file mode 100644 index 00000000..75f9442a --- /dev/null +++ b/docs/api/coordinates/shear.md @@ -0,0 +1,5 @@ +# Shear + +Gravitational shear representation with multiple parametrizations (g1/g2, e1/e2, eta1/eta2). + +::: jax_galsim.shear.Shear diff --git a/docs/api/core/draw.md b/docs/api/core/draw.md new file mode 100644 index 00000000..2897eb20 --- /dev/null +++ b/docs/api/core/draw.md @@ -0,0 +1,5 @@ +# Core: Drawing + +Internal drawing utilities for rendering profiles to pixel grids. + +::: jax_galsim.core.draw diff --git a/docs/api/core/interpolate.md b/docs/api/core/interpolate.md new file mode 100644 index 00000000..e400e8db --- /dev/null +++ b/docs/api/core/interpolate.md @@ -0,0 +1,5 @@ +# Core: Interpolation + +Internal interpolation utilities (Akima splines, coefficient computation). + +::: jax_galsim.core.interpolate diff --git a/docs/api/core/math.md b/docs/api/core/math.md new file mode 100644 index 00000000..a2188f04 --- /dev/null +++ b/docs/api/core/math.md @@ -0,0 +1,5 @@ +# Core: Math + +Gradient-safe mathematical utilities (`safe_sqrt`, etc.). + +::: jax_galsim.core.math diff --git a/docs/api/core/utils.md b/docs/api/core/utils.md new file mode 100644 index 00000000..6d2a4ff9 --- /dev/null +++ b/docs/api/core/utils.md @@ -0,0 +1,5 @@ +# Core: Utilities + +Core utilities: `@implements` decorator, `has_tracers()`, type casting helpers. + +::: jax_galsim.core.utils diff --git a/docs/api/image.md b/docs/api/image.md new file mode 100644 index 00000000..c965b13d --- /dev/null +++ b/docs/api/image.md @@ -0,0 +1,5 @@ +# Image + +Immutable JAX array wrapper with WCS and bounds metadata. + +::: jax_galsim.image.Image diff --git a/docs/api/interpolation/interpolant.md b/docs/api/interpolation/interpolant.md new file mode 100644 index 00000000..23bcd68d --- /dev/null +++ b/docs/api/interpolation/interpolant.md @@ -0,0 +1,17 @@ +# Interpolants + +Interpolation kernels for image resampling. + +::: jax_galsim.interpolant.Interpolant + +::: jax_galsim.interpolant.Nearest + +::: jax_galsim.interpolant.SincInterpolant + +::: jax_galsim.interpolant.Linear + +::: jax_galsim.interpolant.Cubic + +::: jax_galsim.interpolant.Quintic + +::: jax_galsim.interpolant.Lanczos diff --git a/docs/api/interpolation/interpolatedimage.md b/docs/api/interpolation/interpolatedimage.md new file mode 100644 index 00000000..0cb4b52e --- /dev/null +++ b/docs/api/interpolation/interpolatedimage.md @@ -0,0 +1,5 @@ +# InterpolatedImage + +Surface brightness profile defined by interpolation over a given image. + +::: jax_galsim.interpolatedimage.InterpolatedImage diff --git a/docs/api/math/bessel.md b/docs/api/math/bessel.md new file mode 100644 index 00000000..8b59134f --- /dev/null +++ b/docs/api/math/bessel.md @@ -0,0 +1,5 @@ +# Bessel Functions + +Bessel and related special functions. + +::: jax_galsim.bessel diff --git a/docs/api/math/integ.md b/docs/api/math/integ.md new file mode 100644 index 00000000..50278f00 --- /dev/null +++ b/docs/api/math/integ.md @@ -0,0 +1,5 @@ +# Integration + +Numerical integration utilities. + +::: jax_galsim.integ diff --git a/docs/api/noise/noise.md b/docs/api/noise/noise.md new file mode 100644 index 00000000..d15ad5fb --- /dev/null +++ b/docs/api/noise/noise.md @@ -0,0 +1,15 @@ +# Noise Models + +Noise classes for adding realistic noise to images. + +::: jax_galsim.noise.BaseNoise + +::: jax_galsim.noise.GaussianNoise + +::: jax_galsim.noise.PoissonNoise + +::: jax_galsim.noise.CCDNoise + +::: jax_galsim.noise.DeviateNoise + +::: jax_galsim.noise.VariableGaussianNoise diff --git a/docs/api/noise/random.md b/docs/api/noise/random.md new file mode 100644 index 00000000..88042531 --- /dev/null +++ b/docs/api/noise/random.md @@ -0,0 +1,19 @@ +# Random Deviates + +Random number generators. + +::: jax_galsim.random.BaseDeviate + +::: jax_galsim.random.UniformDeviate + +::: jax_galsim.random.GaussianDeviate + +::: jax_galsim.random.PoissonDeviate + +::: jax_galsim.random.Chi2Deviate + +::: jax_galsim.random.GammaDeviate + +::: jax_galsim.random.WeibullDeviate + +::: jax_galsim.random.BinomialDeviate diff --git a/docs/api/photons/photon_array.md b/docs/api/photons/photon_array.md new file mode 100644 index 00000000..30d34dfc --- /dev/null +++ b/docs/api/photons/photon_array.md @@ -0,0 +1,5 @@ +# PhotonArray + +Array of photon positions, fluxes, and other properties for photon shooting. + +::: jax_galsim.photon_array.PhotonArray diff --git a/docs/api/photons/sensor.md b/docs/api/photons/sensor.md new file mode 100644 index 00000000..6c45d17d --- /dev/null +++ b/docs/api/photons/sensor.md @@ -0,0 +1,5 @@ +# Sensor + +Sensor model for converting photons to pixel counts. + +::: jax_galsim.sensor.Sensor diff --git a/docs/api/profiles/box.md b/docs/api/profiles/box.md new file mode 100644 index 00000000..cecaa943 --- /dev/null +++ b/docs/api/profiles/box.md @@ -0,0 +1,7 @@ +# Box & Pixel + +Box (uniform rectangular) and Pixel (unit-width box) surface brightness profiles. + +::: jax_galsim.box.Box + +::: jax_galsim.box.Pixel diff --git a/docs/api/profiles/deltafunction.md b/docs/api/profiles/deltafunction.md new file mode 100644 index 00000000..69ce9f05 --- /dev/null +++ b/docs/api/profiles/deltafunction.md @@ -0,0 +1,5 @@ +# DeltaFunction + +Delta function (point source) surface brightness profile. + +::: jax_galsim.deltafunction.DeltaFunction diff --git a/docs/api/profiles/exponential.md b/docs/api/profiles/exponential.md new file mode 100644 index 00000000..9e93cdde --- /dev/null +++ b/docs/api/profiles/exponential.md @@ -0,0 +1,5 @@ +# Exponential + +Exponential surface brightness profile, commonly used for galaxy disk components. + +::: jax_galsim.exponential.Exponential diff --git a/docs/api/profiles/gaussian.md b/docs/api/profiles/gaussian.md new file mode 100644 index 00000000..99360444 --- /dev/null +++ b/docs/api/profiles/gaussian.md @@ -0,0 +1,5 @@ +# Gaussian + +Circular or elliptical Gaussian surface brightness profile. + +::: jax_galsim.gaussian.Gaussian diff --git a/docs/api/profiles/gsobject.md b/docs/api/profiles/gsobject.md new file mode 100644 index 00000000..9c3cdda9 --- /dev/null +++ b/docs/api/profiles/gsobject.md @@ -0,0 +1,5 @@ +# GSObject + +Base class for all surface brightness profiles. + +::: jax_galsim.gsobject.GSObject diff --git a/docs/api/profiles/moffat.md b/docs/api/profiles/moffat.md new file mode 100644 index 00000000..b7c1f34d --- /dev/null +++ b/docs/api/profiles/moffat.md @@ -0,0 +1,5 @@ +# Moffat + +Moffat surface brightness profile, commonly used for PSF modeling. + +::: jax_galsim.moffat.Moffat diff --git a/docs/api/profiles/spergel.md b/docs/api/profiles/spergel.md new file mode 100644 index 00000000..13c19314 --- /dev/null +++ b/docs/api/profiles/spergel.md @@ -0,0 +1,5 @@ +# Spergel + +Spergel surface brightness profile, a flexible model for galaxy light distributions. + +::: jax_galsim.spergel.Spergel diff --git a/docs/api/wcs/fits.md b/docs/api/wcs/fits.md new file mode 100644 index 00000000..07cb05d5 --- /dev/null +++ b/docs/api/wcs/fits.md @@ -0,0 +1,5 @@ +# FITS I/O + +FITS file reading, writing, and header handling. + +::: jax_galsim.fits diff --git a/docs/api/wcs/fitswcs.md b/docs/api/wcs/fitswcs.md new file mode 100644 index 00000000..b5721c38 --- /dev/null +++ b/docs/api/wcs/fitswcs.md @@ -0,0 +1,5 @@ +# FITS WCS + +FITS-based World Coordinate Systems. + +::: jax_galsim.fitswcs.GSFitsWCS diff --git a/docs/api/wcs/wcs.md b/docs/api/wcs/wcs.md new file mode 100644 index 00000000..b9ecf8e5 --- /dev/null +++ b/docs/api/wcs/wcs.md @@ -0,0 +1,17 @@ +# WCS Base Classes + +World Coordinate System hierarchy. + +::: jax_galsim.wcs.BaseWCS + +::: jax_galsim.wcs.PixelScale + +::: jax_galsim.wcs.ShearWCS + +::: jax_galsim.wcs.JacobianWCS + +::: jax_galsim.wcs.OffsetWCS + +::: jax_galsim.wcs.OffsetShearWCS + +::: jax_galsim.wcs.AffineTransform diff --git a/docs/architecture/drawing.md b/docs/architecture/drawing.md new file mode 100644 index 00000000..02b43235 --- /dev/null +++ b/docs/architecture/drawing.md @@ -0,0 +1,79 @@ +# Drawing Pipeline + +When you call `obj.drawImage()`, JAX-GalSim converts a surface brightness profile +into a pixel array. This page describes the rendering pipeline. + +## Overview + +```mermaid +flowchart TD + A["obj.drawImage(scale, nx, ny)"] --> B{Method?} + B -->|"auto / fft"| C["draw_by_kValue"] + B -->|"real_space"| D["draw_by_xValue"] + B -->|"phot"| E["Photon shooting"] + + C --> F["Evaluate _kValue on k-grid"] + F --> G["Inverse FFT"] + G --> H["Apply pixel response"] + + D --> I["Evaluate _xValue on pixel grid"] + I --> J["Direct summation"] + + E --> K["Generate PhotonArray"] + K --> L["Accumulate on Image"] + + H --> M["Image"] + J --> M + L --> M +``` + +## Drawing Methods + +### FFT Drawing (`draw_by_kValue`) + +The default method for most profiles. Steps: + +1. Determine image size from `_stepk` and `_maxk` (or user-specified `nx`, `ny`) +2. Build a grid of k-space positions +3. Evaluate `_kValue(kpos)` at each grid point +4. Inverse FFT to get the real-space image +5. Apply pixel convolution (if not already included) + +This is efficient for smooth profiles and convolutions, since convolution in +real space is multiplication in Fourier space. + +### Real-Space Drawing (`draw_by_xValue`) + +Evaluates the profile directly at each pixel center: + +1. Build a grid of real-space positions from image bounds and WCS +2. Evaluate `_xValue(pos)` at each grid point +3. Multiply by pixel area + +Implemented in `jax_galsim/core/draw.py`. This is simpler but slower for +large images. + +### Photon Shooting + +Generates random photon positions from the profile and accumulates them on an +image grid. Uses `PhotonArray` and `Sensor` for the photon-to-pixel mapping. + +## Convolution + +`Convolve([gal, psf])` doesn't immediately compute anything — it creates a lazy +`Convolution` object. The actual convolution happens at draw time: + +- **FFT method**: Multiply the k-space representations, then inverse FFT +- **Real-space method**: Only supported for simple profile combinations + +## The Role of `_maxk` and `_stepk` + +These properties control automatic image sizing: + +- **`_maxk`**: The maximum spatial frequency where the profile has significant + power. Determines the pixel scale (Nyquist sampling). +- **`_stepk`**: The spacing in k-space. Determines the image size (field of view). + +For a `Convolution`, `_maxk` is the minimum of the components' `_maxk` values +(the most compact Fourier representation wins), and `_stepk` is the minimum +`_stepk` (the largest real-space extent sets the field of view). diff --git a/docs/architecture/gsobject.md b/docs/architecture/gsobject.md new file mode 100644 index 00000000..c7e3aae6 --- /dev/null +++ b/docs/architecture/gsobject.md @@ -0,0 +1,91 @@ +# GSObject Hierarchy + +`GSObject` is the base class for all surface brightness profiles in JAX-GalSim. +Every galaxy model, PSF, and optical component inherits from it. + +## Class Hierarchy + +```mermaid +classDiagram + GSObject <|-- Gaussian + GSObject <|-- Moffat + GSObject <|-- Spergel + GSObject <|-- Exponential + GSObject <|-- DeltaFunction + GSObject <|-- Box + Box <|-- Pixel + GSObject <|-- InterpolatedImage + GSObject <|-- Sum + GSObject <|-- Convolution + GSObject <|-- Transformation + + class GSObject { + +_params: dict + +_gsparams: GSParams + +drawImage(scale, nx, ny) Image + +shear(shear) GSObject + +shift(dx, dy) GSObject + +_xValue(pos) float + +_kValue(kpos) complex + +_maxk: float + +_stepk: float + } + + class Gaussian { + +sigma: float + +flux: float + } + + class Moffat { + +beta: float + +scale_radius: float + +flux: float + } + + class Spergel { + +nu: float + +half_light_radius: float + +flux: float + } +``` + +## The Base Class Contract + +Subclasses must implement these methods and properties: + +| Member | Type | Purpose | +|--------|------|---------| +| `_xValue(pos)` | Method | Surface brightness at real-space position | +| `_kValue(kpos)` | Method | Fourier-space amplitude at frequency | +| `_maxk` | Property | Maximum k beyond which the profile is negligible | +| `_stepk` | Property | Sampling interval in k-space | + +The base class provides the public API built on these primitives: + +- `drawImage()` — Renders the profile to a pixel grid (delegates to `core.draw`) +- `shear()`, `shift()`, `rotate()`, `dilate()` — Return transformed copies via `Transformation` +- `__add__` — Returns a `Sum` of profiles +- `withFlux()`, `withGSParams()` — Return copies with modified parameters + +## The `_params` Dict + +All traced parameters live in a `_params` dictionary. This is the canonical +storage for values that JAX can differentiate through: + +```python +gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) +gal._params # {"flux": Array(1e5), "sigma": Array(2.0)} +``` + +Properties like `gal.sigma` and `gal.flux` are thin accessors into `_params`. + +## Composition Objects + +Profiles can be combined through three composition types: + +- **`Sum`** (`Add`): Adds surface brightness profiles together +- **`Convolution`** (`Convolve`): Convolves profiles (e.g., galaxy with PSF) +- **`Transformation`** (`Transform`): Applies affine transformations (shear, + shift, rotation, flux scaling) + +These are themselves GSObjects and can be nested arbitrarily. diff --git a/docs/architecture/implements.md b/docs/architecture/implements.md new file mode 100644 index 00000000..f4244540 --- /dev/null +++ b/docs/architecture/implements.md @@ -0,0 +1,60 @@ +# The @implements Decorator + +JAX-GalSim reuses GalSim's docstrings rather than duplicating them. The +`@implements` decorator handles this automatically. + +## Usage + +```python +import galsim as _galsim +from jax_galsim.core.utils import implements + +@implements(_galsim.Gaussian, + lax_description="LAX: Does not support ChromaticObject.") +class Gaussian(GSObject): + ... +``` + +This does three things: + +1. **Copies the docstring** from `_galsim.Gaussian` to `jax_galsim.Gaussian` +2. **Adds a LAX note** (the `lax_description`) documenting JAX-specific + differences or limitations +3. **Sets `__galsim_wrapped__`** on the class, linking back to the original + GalSim implementation + +## The `lax_description` Parameter + +Use this to document any differences from the reference GalSim implementation: + +```python +@implements(_galsim.Add, + lax_description="Does not support ChromaticObject at this point.") +def Add(*args, **kwargs): + return Sum(*args, **kwargs) +``` + +The description is inserted at the top of the docstring, after the summary line, +as a "LAX-backend implementation" note. + +## How It Works + +The decorator (defined in `jax_galsim/core/utils.py`) does the following: + +1. Retrieves the original function's docstring +2. Parses the GalSim-style numpydoc format (summary, parameters, etc.) +3. Reconstructs the docstring with: + - The original summary line + - A "LAX-backend implementation of `:func:original.name`" note + - The `lax_description` text (if provided) + - The rest of the original docstring body +4. Assigns the combined docstring to the wrapped function + +## When to Use It + +- **Always** when implementing a GalSim class or function — this is the standard + pattern in JAX-GalSim +- The `lax_description` should note any restricted functionality, different + behavior, or missing parameters compared to GalSim +- If there is no corresponding GalSim function (e.g., JAX-specific utilities), + write a normal docstring instead diff --git a/docs/architecture/index.md b/docs/architecture/index.md new file mode 100644 index 00000000..eeb45141 --- /dev/null +++ b/docs/architecture/index.md @@ -0,0 +1,112 @@ +# Architecture Overview + +JAX-GalSim mirrors GalSim's architecture while replacing NumPy/C++ internals +with pure JAX. This page provides a high-level map of the major components. + +## Component Map + +```mermaid +graph TB + subgraph "User API" + Profiles["Profiles
(Gaussian, Moffat, ...)"] + Comp["Composition
(Convolve, Sum, Transform)"] + Img["Image"] + Noise["Noise & Random"] + end + + subgraph "Coordinate System" + Pos["Position"] + Bounds["Bounds"] + WCS["WCS"] + Shear["Shear"] + Angle["Angle"] + end + + subgraph "Rendering Pipeline" + Draw["core.draw"] + FFT["FFT Convolution"] + Photon["Photon Shooting"] + Interp["Interpolation"] + end + + subgraph "JAX Infrastructure" + PyTree["PyTree Registration"] + Impl["@implements Decorator"] + Utils["Type Casting & Tracers"] + end + + Profiles --> Comp + Comp --> Draw + Draw --> Img + Photon --> Img + Noise --> Img + WCS --> Img + Pos --> Bounds + Bounds --> Img + Shear --> Profiles + Interp --> Draw + PyTree --> Profiles + PyTree --> Img + Impl --> Profiles + Utils --> Draw +``` + +## Module Layout + +``` +jax_galsim/ +├── core/ # JAX internals +│ ├── draw.py # Real-space & k-space drawing +│ ├── integrate.py # Numerical integration +│ ├── interpolate.py # Spline interpolation +│ ├── math.py # Gradient-safe math (safe_sqrt, etc.) +│ ├── utils.py # @implements, has_tracers(), type casting +│ └── wrap_image.py # Image wrapping utilities +├── gaussian.py # Gaussian profile +├── moffat.py # Moffat profile +├── spergel.py # Spergel profile +├── exponential.py # Exponential profile +├── box.py # Box & Pixel profiles +├── deltafunction.py # Delta function profile +├── gsobject.py # Base GSObject class +├── sum.py # Sum (Add) composition +├── convolve.py # Convolution & Deconvolution +├── transform.py # Affine transformations +├── interpolatedimage.py # InterpolatedImage +├── image.py # Image class hierarchy +├── wcs.py # WCS base classes +├── fitswcs.py # FITS WCS implementations +├── fits.py # FITS I/O +├── noise.py # Noise models +├── random.py # Random deviates +├── interpolant.py # Interpolant functions +├── photon_array.py # Photon arrays +├── sensor.py # Sensor model +├── position.py # Position types +├── bounds.py # Bounds types +├── angle.py # Angle & AngleUnit +├── shear.py # Shear +├── celestial.py # CelestialCoord +├── gsparams.py # GSParams configuration +├── errors.py # Exception classes +├── bessel.py # Bessel functions +├── integ.py # Integration (int1d) +└── utilities.py # General utilities +``` + +## Design Principles + +1. **Drop-in replacement**: Match GalSim's public API so that `import jax_galsim as galsim` works for supported features. + +2. **PyTree everywhere**: Every object is a registered JAX PyTree, separating traced parameters (children) from static configuration (auxiliary data). + +3. **Docstring inheritance**: The `@implements` decorator copies docstrings from GalSim, appending JAX-specific notes via `lax_description`. + +4. **Pure functions**: All operations are pure (no side effects), enabling `jit`, `grad`, and `vmap` compatibility. + +## Deep Dives + +- [PyTree Registration](pytree.md) — How objects become JAX-compatible +- [The @implements Decorator](implements.md) — Docstring inheritance from GalSim +- [GSObject Hierarchy](gsobject.md) — Base class contract and profile implementations +- [Drawing Pipeline](drawing.md) — How profiles become pixel images diff --git a/docs/architecture/pytree.md b/docs/architecture/pytree.md new file mode 100644 index 00000000..b9476c76 --- /dev/null +++ b/docs/architecture/pytree.md @@ -0,0 +1,77 @@ +# PyTree Registration + +JAX transformations (`jit`, `grad`, `vmap`) decompose Python objects into a flat +list of arrays (leaves) and a static structure (treedef). JAX-GalSim objects must +be registered as PyTrees for this to work. + +## The Pattern + +Every JAX-GalSim class uses `@register_pytree_node_class` and implements two methods: + +```python +from jax.tree_util import register_pytree_node_class + +@register_pytree_node_class +class Gaussian(GSObject): + def tree_flatten(self): + # Children: values that JAX can trace (differentiate through) + children = (self.params,) + # Aux data: static values that trigger recompilation if changed + aux_data = (self.gsparams,) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + # Reconstruct the object from its parts + return cls(params=children[0], gsparams=aux_data[0]) +``` + +## Children vs Auxiliary Data + +| Component | Role | Examples | Effect of changing | +|-----------|------|----------|--------------------| +| **Children** (traced) | Values JAX differentiates through | `flux`, `sigma`, `half_light_radius` | Triggers re-evaluation, not recompilation | +| **Auxiliary data** (static) | Structure and configuration | `GSParams`, enum flags | Triggers full recompilation under `jit` | + +In practice, profile parameters live in a `_params` dict (children) and numerical +configuration lives in `_gsparams` (auxiliary): + +```python +gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) +# gal._params = {"flux": 1e5, "sigma": 2.0} — traced +# gal._gsparams = GSParams(...) — static +``` + +## The `__init__` Gotcha + +During `tree_unflatten`, JAX calls the constructor with potentially traced values +(not concrete Python numbers). If `__init__` performs type checks like +`isinstance(sigma, float)`, these will fail on JAX tracers. + +The recommended solution (from the +[JAX docs](https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization)): +separate validation from initialization, or use `has_tracers()` to skip checks +during tracing: + +```python +from jax_galsim.core.utils import has_tracers + +class MyProfile(GSObject): + def __init__(self, sigma, gsparams=None): + if not has_tracers(sigma): + # Only validate with concrete values + if sigma <= 0: + raise ValueError("sigma must be positive") + ... +``` + +## Practical Implications + +1. **`jax.grad` works**: Because profile parameters are traced children, you get + gradients for free. + +2. **`GSParams` changes recompile**: Changing `GSParams` between calls to a + `jit`-compiled function triggers recompilation, since it's static auxiliary data. + +3. **No mutable state**: PyTree flattening and unflattening means all state must + be reconstructable from children + aux_data. There's no hidden mutable state. diff --git a/docs/getting-started/index.md b/docs/getting-started/index.md new file mode 100644 index 00000000..a8733c82 --- /dev/null +++ b/docs/getting-started/index.md @@ -0,0 +1,7 @@ +# Getting Started + +New to JAX-GalSim? Start here. + +- [Installation](installation.md) — Install JAX-GalSim and set up GPU support +- [Quick Start](quickstart.md) — Simulate a galaxy image in a few lines of code +- [Key Concepts](key-concepts.md) — Understand GSObjects, Images, PyTrees, and how JAX changes things diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md new file mode 100644 index 00000000..88142dd2 --- /dev/null +++ b/docs/getting-started/installation.md @@ -0,0 +1,69 @@ +# Installation + +## Quick Install + +```bash +pip install jax-galsim +``` + +This installs JAX-GalSim and its dependencies (JAX, NumPy, GalSim, Astropy). + +## GPU Support + +JAX-GalSim inherits GPU support from JAX. To use NVIDIA GPUs, install the appropriate JAX variant: + +```bash +pip install -U "jax[cuda12]" +``` + +See the [JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html) for other accelerators and platform-specific instructions. + +## Development Install + +To contribute to JAX-GalSim or run the test suite: + +```bash +# Clone with submodules (required for GalSim reference tests) +git clone --recurse-submodules https://github.com/GalSim-developers/JAX-GalSim +cd JAX-GalSim + +# Create a virtual environment +python -m venv .venv && source .venv/bin/activate + +# Install in editable mode with dev dependencies +pip install -e ".[dev]" + +# Install pre-commit hooks +pre-commit install +``` + +### Running Tests + +```bash +# Run all tests +pytest + +# Run a specific test file +pytest tests/jax/test_api.py + +# Run a specific test +pytest tests/jax/test_api.py::test_api_same + +# Verbose output with timing +pytest -vv --durations=100 +``` + +### Linting and Formatting + +```bash +# Lint +ruff check . --fix + +# Format +ruff format . + +# Or run both via pre-commit +pre-commit run --all-files +``` + +See [CONTRIBUTING.md](https://github.com/GalSim-developers/JAX-GalSim/blob/main/CONTRIBUTING.md) for full contribution guidelines. diff --git a/docs/getting-started/key-concepts.md b/docs/getting-started/key-concepts.md new file mode 100644 index 00000000..9b7c154e --- /dev/null +++ b/docs/getting-started/key-concepts.md @@ -0,0 +1,96 @@ +# Key Concepts + +This page covers the core ideas you need to understand when using JAX-GalSim. + +## GSObject: The Building Block + +Every galaxy profile, PSF model, and optical component is a **GSObject** — the +base class for all surface brightness profiles. GSObjects support: + +- **Arithmetic**: `+` (sum), `*` (scalar multiply) +- **Transformations**: `.shear()`, `.shift()`, `.dilate()`, `.rotate()` +- **Convolution**: `jax_galsim.Convolve([obj1, obj2, ...])` +- **Drawing**: `.drawImage(scale=...)` renders the profile to a pixel grid + +```python +import jax_galsim + +# A Gaussian galaxy, sheared and shifted +gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) +gal = gal.shear(g1=0.2, g2=0.1) +gal = gal.shift(dx=0.5, dy=-0.3) +``` + +## Images Are Immutable + +Unlike GalSim (which uses mutable NumPy arrays), JAX-GalSim images are +**immutable** because JAX arrays cannot be modified in-place. Operations that +would mutate an image in GalSim return a new image instead: + +```python +# GalSim (mutable): +# image.addNoise(noise) # modifies image in-place +# +# JAX-GalSim (immutable): +image = image.addNoise(noise) # returns a new image +``` + +This is the most common source of differences when porting GalSim code. + +## PyTree Registration + +JAX transformations (`jit`, `grad`, `vmap`) need to understand how to +decompose objects into arrays and metadata. JAX-GalSim objects are registered +as **PyTrees** with two components: + +- **Children** (traced): Parameters that can be differentiated — stored in + `obj._params` (e.g., flux, sigma, half_light_radius) +- **Auxiliary data** (static): Configuration that doesn't change — stored in + `obj._gsparams` and similar fields + +This means you can pass JAX-GalSim objects directly to `jit`, `grad`, and `vmap` +without any special handling: + +```python +import jax + +@jax.jit +def render(gal): + return gal.drawImage(scale=0.2).array + +gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) +result = render(gal) # gal is automatically flattened/unflattened +``` + +See [PyTree Registration](../architecture/pytree.md) for implementation details. + +## GSParams: Numerical Configuration + +`GSParams` controls numerical accuracy and performance trade-offs for rendering: + +```python +# Use tighter tolerances for high-precision work +gsparams = jax_galsim.GSParams(maximum_fft_size=8192, kvalue_accuracy=1e-6) +gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0, gsparams=gsparams) +``` + +`GSParams` is treated as **static** auxiliary data in the PyTree, so changing it +triggers recompilation under `jit`. + +## Functional Random Numbers + +JAX uses a functional PRNG — random state is explicit and never mutated: + +```python +import jax + +# Create a random key +key = jax.random.PRNGKey(42) + +# Create a noise model (JAX-GalSim wraps this in GalSim-compatible API) +noise = jax_galsim.GaussianNoise(sigma=30.0) +image = image.addNoise(noise) +``` + +See [Notable Differences](../notable-differences.md) for more on how the RNG +differs from GalSim. diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md new file mode 100644 index 00000000..9f00e076 --- /dev/null +++ b/docs/getting-started/quickstart.md @@ -0,0 +1,103 @@ +# Quick Start + +This tutorial walks through a complete galaxy image simulation, then shows how +JAX transformations (`jit`, `grad`, `vmap`) apply to it. + +## A Simple Simulation + +This example creates a Gaussian galaxy, convolves it with a Gaussian PSF, draws +the image, and adds noise — equivalent to GalSim's `demo1.py`. + +```python +import jax_galsim + +# Galaxy parameters +gal_flux = 1e5 # total counts +gal_sigma = 2.0 # arcsec +psf_sigma = 1.0 # arcsec +pixel_scale = 0.2 # arcsec/pixel +noise_sigma = 30.0 # counts per pixel + +# Define profiles +gal = jax_galsim.Gaussian(flux=gal_flux, sigma=gal_sigma) +psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma) + +# Convolve galaxy with PSF +final = jax_galsim.Convolve([gal, psf]) + +# Draw the image +image = final.drawImage(scale=pixel_scale) + +# Add Gaussian noise +image = image.addNoise(jax_galsim.GaussianNoise(sigma=noise_sigma)) + +# Write to FITS +image.write("output/demo1.fits") +``` + +The API is intentionally close to GalSim — most GalSim tutorials translate +directly to JAX-GalSim by replacing `import galsim` with `import jax_galsim`. + +## JIT Compilation + +Wrap your simulation in `jax.jit` to compile it into an optimized XLA computation: + +```python +import jax + +@jax.jit +def simulate(flux, sigma): + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + final = jax_galsim.Convolve([gal, psf]) + return final.drawImage(scale=0.2) + +# First call compiles; subsequent calls are fast +image = simulate(1e5, 2.0) +``` + +## Automatic Differentiation + +Compute gradients of any scalar output with respect to galaxy parameters: + +```python +def total_flux(gal_sigma, psf_sigma): + gal = jax_galsim.Gaussian(flux=1e5, sigma=gal_sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma) + final = jax_galsim.Convolve([gal, psf]) + image = final.drawImage(scale=0.2) + return image.array.sum() + +# Gradient of total image flux with respect to both sigmas +grad_fn = jax.grad(total_flux, argnums=(0, 1)) +d_gal, d_psf = grad_fn(2.0, 1.0) +``` + +This is useful for fitting galaxy models to data, where gradients enable +efficient optimization. + +## Vectorization with vmap + +Simulate a batch of galaxies with different parameters without writing loops: + +```python +import jax.numpy as jnp + +sigmas = jnp.linspace(1.0, 4.0, 10) + +@jax.vmap +def batch_simulate(sigma): + gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + final = jax_galsim.Convolve([gal, psf]) + return final.drawImage(scale=0.2, nx=64, ny=64).array + +# Simulate all 10 galaxies in parallel +images = batch_simulate(sigmas) # shape: (10, 64, 64) +``` + +## Next Steps + +- [Key Concepts](key-concepts.md) — Understand the design behind JAX-GalSim +- [Notable Differences](../notable-differences.md) — What's different from GalSim +- [API Reference](../api/profiles/gaussian.md) — Full API documentation diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..ecb15ae5 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,80 @@ +# JAX-GalSim + +**JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.** + +[![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) +[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg)](https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main) + +!!! warning "Early Development" + + This project is still in an early development phase. Please use the + [reference GalSim implementation](https://github.com/GalSim-developers/GalSim) + for any scientific applications. + +--- + +## Why JAX-GalSim? + +JAX-GalSim reimplements [GalSim](https://github.com/GalSim-developers/GalSim) in pure JAX, unlocking three capabilities for galaxy image simulation: + +!!! tip "JIT Compilation" + + Compile simulation pipelines with `jax.jit` for significant speedups, especially on GPU. + +!!! tip "Automatic Differentiation" + + Compute gradients of simulation outputs with respect to galaxy parameters using `jax.grad`. + +!!! tip "Vectorization" + + Batch simulations over parameter grids with `jax.vmap` --- no explicit loops needed. + +--- + +## Quick Install + +```bash +pip install jax-galsim +``` + +## Minimal Example + +```python +import jax +import jax_galsim + +# Define a galaxy and PSF +gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) +psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + +# Convolve and draw +final = jax_galsim.Convolve([gal, psf]) +image = final.drawImage(scale=0.2) + +# Add noise +image = image.addNoise(jax_galsim.GaussianNoise(sigma=30.0)) +``` + +Because JAX-GalSim objects are JAX pytrees, you can JIT-compile and differentiate the entire pipeline: + +```python +@jax.jit +def simulate(flux, sigma): + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + return jax_galsim.Convolve([gal, psf]).drawImage(scale=0.2).array.sum() + +# Compute gradients with respect to galaxy parameters +grad_fn = jax.grad(simulate, argnums=(0, 1)) +dflux, dsigma = grad_fn(1e5, 2.0) +``` + +--- + +## Next Steps + +- [Installation](getting-started/installation.md) --- Set up JAX-GalSim with GPU support +- [Quick Start](getting-started/quickstart.md) --- Walk through a complete simulation +- [Key Concepts](getting-started/key-concepts.md) --- Understand JAX-GalSim's design +- [API Reference](api/profiles/gaussian.md) --- Browse the full API diff --git a/docs/javascripts/mathjax.js b/docs/javascripts/mathjax.js new file mode 100644 index 00000000..117b0460 --- /dev/null +++ b/docs/javascripts/mathjax.js @@ -0,0 +1,16 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true, + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex", + }, +}; + +document$.subscribe(() => { + MathJax.typesetPromise(); +}); diff --git a/docs/notable-differences.md b/docs/notable-differences.md new file mode 100644 index 00000000..9bfe759f --- /dev/null +++ b/docs/notable-differences.md @@ -0,0 +1,73 @@ +# Notable Differences from GalSim + +JAX-GalSim strives to be a drop-in replacement for GalSim, but JAX's design +imposes some fundamental differences. This page documents them. + +## Immutability + +JAX arrays are immutable — you cannot modify them in-place. This affects all +image operations: + +```python +# GalSim: mutates image in-place +image.addNoise(noise) +image.array[10, 10] = 0.0 + +# JAX-GalSim: returns a new image +image = image.addNoise(noise) +# Direct array mutation is not supported +``` + +Any GalSim code that relies on in-place modification of images needs to be +rewritten to use the return value. + +## Array Views + +JAX does not support all the array view semantics that NumPy provides. In +particular: + +- **Real views of complex images** are not available. In GalSim, you can get a + real-valued view of a complex image's real or imaginary part that shares + memory. JAX-GalSim returns copies instead. + +## Random Number Generation + +JAX uses a **functional PRNG** — random state is explicit and must be threaded +through computations. Key differences: + +- JAX's PRNG is deterministic and reproducible across platforms +- Random deviates cannot "fill" an existing array; they return new arrays +- The sequence of random numbers differs from GalSim's RNG, so results will + not be numerically identical even with the same seed +- Different RNG classes may have different stability properties for discarding + +## Profile Restrictions + +Some GalSim features are not yet implemented: + +- **Truncated Moffat profiles** are not supported (the `trunc` parameter) +- **ChromaticObject** and all chromatic functionality is not available +- **InterpolatedKImage** is not implemented +- See [API Coverage](api-coverage.md) for the full list of supported APIs + +## Control Flow and Tracing + +JAX's tracing system places restrictions on Python control flow: + +- **`if`/`else` on traced values**: You cannot branch on values that JAX is + tracing (e.g., profile parameters inside a `jit`-compiled function). Use + `jax.lax.cond` instead. +- **Variable-size operations**: Operations whose output shape depends on input + values (e.g., adaptive image sizing) may not work under `jit`. + +JAX-GalSim uses `has_tracers()` internally to detect when code is being traced +and avoid problematic control flow patterns. + +## Numerical Precision + +Some operations may produce slightly different numerical results due to: + +- Different order of floating-point operations (JAX may reorder for performance) +- Use of XLA-compiled math kernels instead of system math libraries +- Custom gradient-safe implementations (e.g., `safe_sqrt` in `core/math.py`) + that handle edge cases differently for differentiability diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..fce1a980 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,132 @@ +site_name: JAX-GalSim +site_description: JAX port of GalSim for GPU-accelerated, differentiable galaxy image simulations +site_url: https://galsim-developers.github.io/JAX-GalSim/ +repo_url: https://github.com/GalSim-developers/JAX-GalSim +repo_name: GalSim-developers/JAX-GalSim + +theme: + name: material + palette: + - media: "(prefers-color-scheme: light)" + scheme: default + primary: indigo + accent: indigo + toggle: + icon: material/brightness-7 + name: Switch to dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: indigo + accent: indigo + toggle: + icon: material/brightness-4 + name: Switch to light mode + features: + - content.code.copy + - navigation.sections + - navigation.expand + - navigation.top + - navigation.indexes + - search.highlight + - toc.follow + - toc.integrate + +plugins: + - search + - mkdocstrings: + handlers: + python: + options: + docstring_style: numpy + show_source: true + merge_init_into_class: true + members_order: source + show_root_heading: true + show_root_full_path: false + show_if_no_docstring: true + heading_level: 2 + show_symbol_type_heading: true + show_symbol_type_toc: true + +markdown_extensions: + - admonition + - pymdownx.details + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.highlight: + anchor_linenums: true + - pymdownx.inlinehilite + - pymdownx.tabbed: + alternate_style: true + - pymdownx.arithmatex: + generic: true + - toc: + permalink: true + +extra_javascript: + - javascripts/mathjax.js + - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js + +nav: + - Home: index.md + - Getting Started: + - getting-started/index.md + - Installation: getting-started/installation.md + - Quick Start: getting-started/quickstart.md + - Key Concepts: getting-started/key-concepts.md + - Architecture: + - architecture/index.md + - PyTree Registration: architecture/pytree.md + - The @implements Decorator: architecture/implements.md + - GSObject Hierarchy: architecture/gsobject.md + - Drawing Pipeline: architecture/drawing.md + - Notable Differences: notable-differences.md + - API Coverage: api-coverage.md + - API Reference: + - Profiles: + - api/profiles/gsobject.md + - api/profiles/gaussian.md + - api/profiles/moffat.md + - api/profiles/spergel.md + - api/profiles/exponential.md + - api/profiles/deltafunction.md + - api/profiles/box.md + - Composition: + - api/composition/convolve.md + - api/composition/sum.md + - api/composition/transform.md + - Image: api/image.md + - Coordinates: + - api/coordinates/position.md + - api/coordinates/bounds.md + - api/coordinates/angle.md + - api/coordinates/shear.md + - api/coordinates/celestial.md + - WCS: + - api/wcs/wcs.md + - api/wcs/fitswcs.md + - api/wcs/fits.md + - Noise & Random: + - api/noise/random.md + - api/noise/noise.md + - Interpolation: + - api/interpolation/interpolant.md + - api/interpolation/interpolatedimage.md + - Photon Shooting: + - api/photons/photon_array.md + - api/photons/sensor.md + - Configuration: + - api/config/gsparams.md + - api/config/errors.md + - api/config/utilities.md + - Math: + - api/math/bessel.md + - api/math/integ.md + - Core Internals: + - api/core/draw.md + - api/core/interpolate.md + - api/core/math.md + - api/core/utils.md diff --git a/pyproject.toml b/pyproject.toml index 3a726884..7cf120fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,14 @@ dependencies = [ [project.optional-dependencies] dev = ["pytest", "pytest-codspeed"] +docs = [ + "mkdocs-material>=9.0", + "mkdocstrings[python]>=0.24", +] [project.urls] home = "https://github.com/GalSim-developers/JAX-GalSim" +documentation = "https://galsim-developers.github.io/JAX-GalSim/" [tool.setuptools.packages.find] include = ["jax_galsim*"] From 36666f886eafa5656b36b191aa6912a9b67202e0 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Fri, 20 Feb 2026 18:05:10 +0100 Subject: [PATCH 02/10] docs: trim verbosity and redundancy across documentation pages Remove duplicated explanations, tighten prose, and condense sections that repeated information already covered on other pages. Net reduction of ~80 lines with no content loss -- cross-links replace duplication. Co-Authored-By: Claude Opus 4.6 --- docs/api-coverage.md | 5 +-- docs/architecture/drawing.md | 45 +++++++++---------------- docs/architecture/gsobject.md | 20 +++++------ docs/architecture/implements.md | 22 ++++-------- docs/architecture/index.md | 14 +++----- docs/architecture/pytree.md | 21 ++++-------- docs/getting-started/key-concepts.md | 50 +++++++++------------------- docs/getting-started/quickstart.md | 16 ++++----- docs/index.md | 4 +-- docs/notable-differences.md | 49 ++++++++++----------------- mkdocs.yml | 1 - 11 files changed, 84 insertions(+), 163 deletions(-) diff --git a/docs/api-coverage.md b/docs/api-coverage.md index 48ab74a2..702cb085 100644 --- a/docs/api-coverage.md +++ b/docs/api-coverage.md @@ -121,11 +121,8 @@ the most commonly used profiles and operations, with coverage expanding over tim ## Updating Coverage -The coverage list is generated automatically: - ```bash python scripts/update_api_coverage.py ``` -This script compares GalSim's public API against `jax_galsim`'s implementations -and updates the coverage percentage and list in `README.md`. +Compares GalSim's public API against `jax_galsim`'s implementations and updates the coverage percentage and list in `README.md`. diff --git a/docs/architecture/drawing.md b/docs/architecture/drawing.md index 02b43235..cce389e5 100644 --- a/docs/architecture/drawing.md +++ b/docs/architecture/drawing.md @@ -1,7 +1,6 @@ # Drawing Pipeline -When you call `obj.drawImage()`, JAX-GalSim converts a surface brightness profile -into a pixel array. This page describes the rendering pipeline. +`obj.drawImage()` converts a surface brightness profile into a pixel array. ## Overview @@ -31,49 +30,37 @@ flowchart TD ### FFT Drawing (`draw_by_kValue`) -The default method for most profiles. Steps: +The default for most profiles: -1. Determine image size from `_stepk` and `_maxk` (or user-specified `nx`, `ny`) -2. Build a grid of k-space positions -3. Evaluate `_kValue(kpos)` at each grid point -4. Inverse FFT to get the real-space image -5. Apply pixel convolution (if not already included) +1. Determine image size from `_stepk`/`_maxk` (or user-specified `nx`, `ny`) +2. Evaluate `_kValue(kpos)` on a k-space grid +3. Inverse FFT to real space +4. Apply pixel convolution (if not already included) -This is efficient for smooth profiles and convolutions, since convolution in -real space is multiplication in Fourier space. +Efficient for smooth profiles and convolutions (convolution in real space = multiplication in Fourier space). ### Real-Space Drawing (`draw_by_xValue`) -Evaluates the profile directly at each pixel center: - -1. Build a grid of real-space positions from image bounds and WCS -2. Evaluate `_xValue(pos)` at each grid point -3. Multiply by pixel area - -Implemented in `jax_galsim/core/draw.py`. This is simpler but slower for -large images. +Evaluates `_xValue(pos)` directly at each pixel center, multiplied by pixel area. Simpler but slower for large images. Implemented in `jax_galsim/core/draw.py`. ### Photon Shooting Generates random photon positions from the profile and accumulates them on an -image grid. Uses `PhotonArray` and `Sensor` for the photon-to-pixel mapping. +image grid via `PhotonArray` and `Sensor`. ## Convolution -`Convolve([gal, psf])` doesn't immediately compute anything — it creates a lazy -`Convolution` object. The actual convolution happens at draw time: +`Convolve([gal, psf])` creates a lazy `Convolution` object. The actual +convolution happens at draw time: -- **FFT method**: Multiply the k-space representations, then inverse FFT +- **FFT method**: Multiply k-space representations, then inverse FFT - **Real-space method**: Only supported for simple profile combinations -## The Role of `_maxk` and `_stepk` +## `_maxk` and `_stepk` These properties control automatic image sizing: -- **`_maxk`**: The maximum spatial frequency where the profile has significant - power. Determines the pixel scale (Nyquist sampling). -- **`_stepk`**: The spacing in k-space. Determines the image size (field of view). +- **`_maxk`**: Maximum significant spatial frequency. Determines pixel scale (Nyquist sampling). +- **`_stepk`**: Spacing in k-space. Determines image size (field of view). -For a `Convolution`, `_maxk` is the minimum of the components' `_maxk` values -(the most compact Fourier representation wins), and `_stepk` is the minimum -`_stepk` (the largest real-space extent sets the field of view). +For a `Convolution`, `_maxk` = min of components (most compact Fourier representation wins), `_stepk` = min of components (largest real-space extent sets the field of view). diff --git a/docs/architecture/gsobject.md b/docs/architecture/gsobject.md index c7e3aae6..1cb170b8 100644 --- a/docs/architecture/gsobject.md +++ b/docs/architecture/gsobject.md @@ -1,7 +1,7 @@ # GSObject Hierarchy -`GSObject` is the base class for all surface brightness profiles in JAX-GalSim. -Every galaxy model, PSF, and optical component inherits from it. +`GSObject` is the base class for all surface brightness profiles. Every galaxy +model, PSF, and optical component inherits from it. ## Class Hierarchy @@ -69,23 +69,19 @@ The base class provides the public API built on these primitives: ## The `_params` Dict -All traced parameters live in a `_params` dictionary. This is the canonical -storage for values that JAX can differentiate through: +Traced parameters live in a `_params` dictionary -- the canonical storage for +values that JAX differentiates through. Properties like `gal.sigma` are +accessors into `_params`: ```python gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) gal._params # {"flux": Array(1e5), "sigma": Array(2.0)} ``` -Properties like `gal.sigma` and `gal.flux` are thin accessors into `_params`. - ## Composition Objects -Profiles can be combined through three composition types: +Profiles combine through three composition types (themselves GSObjects, nestable arbitrarily): -- **`Sum`** (`Add`): Adds surface brightness profiles together +- **`Sum`** (`Add`): Adds surface brightness profiles - **`Convolution`** (`Convolve`): Convolves profiles (e.g., galaxy with PSF) -- **`Transformation`** (`Transform`): Applies affine transformations (shear, - shift, rotation, flux scaling) - -These are themselves GSObjects and can be nested arbitrarily. +- **`Transformation`** (`Transform`): Affine transforms (shear, shift, rotation, flux scaling) diff --git a/docs/architecture/implements.md b/docs/architecture/implements.md index f4244540..835a2b47 100644 --- a/docs/architecture/implements.md +++ b/docs/architecture/implements.md @@ -1,7 +1,6 @@ # The @implements Decorator -JAX-GalSim reuses GalSim's docstrings rather than duplicating them. The -`@implements` decorator handles this automatically. +JAX-GalSim reuses GalSim's docstrings via the `@implements` decorator. ## Usage @@ -25,7 +24,7 @@ This does three things: ## The `lax_description` Parameter -Use this to document any differences from the reference GalSim implementation: +Documents differences from the reference GalSim implementation: ```python @implements(_galsim.Add, @@ -34,21 +33,14 @@ def Add(*args, **kwargs): return Sum(*args, **kwargs) ``` -The description is inserted at the top of the docstring, after the summary line, -as a "LAX-backend implementation" note. - ## How It Works -The decorator (defined in `jax_galsim/core/utils.py`) does the following: +The decorator (in `jax_galsim/core/utils.py`): -1. Retrieves the original function's docstring -2. Parses the GalSim-style numpydoc format (summary, parameters, etc.) -3. Reconstructs the docstring with: - - The original summary line - - A "LAX-backend implementation of `:func:original.name`" note - - The `lax_description` text (if provided) - - The rest of the original docstring body -4. Assigns the combined docstring to the wrapped function +1. Copies the original GalSim docstring +2. Inserts a "LAX-backend implementation of `:func:original.name`" note after the summary line +3. Appends the `lax_description` text (if provided) +4. Assigns the combined docstring to the wrapped class/function ## When to Use It diff --git a/docs/architecture/index.md b/docs/architecture/index.md index eeb45141..864a5167 100644 --- a/docs/architecture/index.md +++ b/docs/architecture/index.md @@ -1,7 +1,6 @@ # Architecture Overview -JAX-GalSim mirrors GalSim's architecture while replacing NumPy/C++ internals -with pure JAX. This page provides a high-level map of the major components. +JAX-GalSim mirrors GalSim's architecture, replacing NumPy/C++ internals with pure JAX. ## Component Map @@ -96,13 +95,10 @@ jax_galsim/ ## Design Principles -1. **Drop-in replacement**: Match GalSim's public API so that `import jax_galsim as galsim` works for supported features. - -2. **PyTree everywhere**: Every object is a registered JAX PyTree, separating traced parameters (children) from static configuration (auxiliary data). - -3. **Docstring inheritance**: The `@implements` decorator copies docstrings from GalSim, appending JAX-specific notes via `lax_description`. - -4. **Pure functions**: All operations are pure (no side effects), enabling `jit`, `grad`, and `vmap` compatibility. +1. **Drop-in replacement**: `import jax_galsim as galsim` works for supported features. +2. **PyTree everywhere**: Every object separates traced parameters (children) from static configuration (auxiliary data). +3. **Docstring inheritance**: `@implements` copies GalSim docstrings, appending JAX-specific notes. +4. **Pure functions**: No side effects, enabling `jit`, `grad`, and `vmap`. ## Deep Dives diff --git a/docs/architecture/pytree.md b/docs/architecture/pytree.md index b9476c76..baeff6f3 100644 --- a/docs/architecture/pytree.md +++ b/docs/architecture/pytree.md @@ -1,12 +1,8 @@ # PyTree Registration -JAX transformations (`jit`, `grad`, `vmap`) decompose Python objects into a flat -list of arrays (leaves) and a static structure (treedef). JAX-GalSim objects must -be registered as PyTrees for this to work. - -## The Pattern - -Every JAX-GalSim class uses `@register_pytree_node_class` and implements two methods: +JAX transformations decompose objects into arrays (leaves) and static structure +(treedef). Every JAX-GalSim class uses `@register_pytree_node_class` and +implements two methods: ```python from jax.tree_util import register_pytree_node_class @@ -44,14 +40,9 @@ gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) ## The `__init__` Gotcha -During `tree_unflatten`, JAX calls the constructor with potentially traced values -(not concrete Python numbers). If `__init__` performs type checks like -`isinstance(sigma, float)`, these will fail on JAX tracers. - -The recommended solution (from the -[JAX docs](https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization)): -separate validation from initialization, or use `has_tracers()` to skip checks -during tracing: +During `tree_unflatten`, JAX calls the constructor with traced values, not +concrete Python numbers. Type checks like `isinstance(sigma, float)` will fail +on tracers. Use `has_tracers()` to skip validation during tracing: ```python from jax_galsim.core.utils import has_tracers diff --git a/docs/getting-started/key-concepts.md b/docs/getting-started/key-concepts.md index 9b7c154e..a9a239f1 100644 --- a/docs/getting-started/key-concepts.md +++ b/docs/getting-started/key-concepts.md @@ -1,11 +1,8 @@ # Key Concepts -This page covers the core ideas you need to understand when using JAX-GalSim. - ## GSObject: The Building Block -Every galaxy profile, PSF model, and optical component is a **GSObject** — the -base class for all surface brightness profiles. GSObjects support: +Every galaxy profile, PSF, and optical component is a **GSObject**. GSObjects support: - **Arithmetic**: `+` (sum), `*` (scalar multiply) - **Transformations**: `.shear()`, `.shift()`, `.dilate()`, `.rotate()` @@ -23,33 +20,21 @@ gal = gal.shift(dx=0.5, dy=-0.3) ## Images Are Immutable -Unlike GalSim (which uses mutable NumPy arrays), JAX-GalSim images are -**immutable** because JAX arrays cannot be modified in-place. Operations that -would mutate an image in GalSim return a new image instead: +JAX arrays cannot be modified in-place, so operations that would mutate an image +in GalSim return a new image instead: ```python -# GalSim (mutable): -# image.addNoise(noise) # modifies image in-place -# -# JAX-GalSim (immutable): -image = image.addNoise(noise) # returns a new image +# GalSim: image.addNoise(noise) # modifies in-place +# JAX-GalSim: image = image.addNoise(noise) # returns new image ``` -This is the most common source of differences when porting GalSim code. +This is the most common difference when porting GalSim code. +See [Notable Differences](../notable-differences.md) for the full list. ## PyTree Registration -JAX transformations (`jit`, `grad`, `vmap`) need to understand how to -decompose objects into arrays and metadata. JAX-GalSim objects are registered -as **PyTrees** with two components: - -- **Children** (traced): Parameters that can be differentiated — stored in - `obj._params` (e.g., flux, sigma, half_light_radius) -- **Auxiliary data** (static): Configuration that doesn't change — stored in - `obj._gsparams` and similar fields - -This means you can pass JAX-GalSim objects directly to `jit`, `grad`, and `vmap` -without any special handling: +All JAX-GalSim objects are registered as JAX **PyTrees**, so you can pass them +directly to `jit`, `grad`, and `vmap`: ```python import jax @@ -62,7 +47,10 @@ gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) result = render(gal) # gal is automatically flattened/unflattened ``` -See [PyTree Registration](../architecture/pytree.md) for implementation details. +Profile parameters (flux, sigma, etc.) are **traced children** that JAX can +differentiate through. Configuration like `GSParams` is **static auxiliary data** +that triggers recompilation when changed. See [PyTree Registration](../architecture/pytree.md) +for details. ## GSParams: Numerical Configuration @@ -79,18 +67,12 @@ triggers recompilation under `jit`. ## Functional Random Numbers -JAX uses a functional PRNG — random state is explicit and never mutated: +JAX uses a functional PRNG -- random state is explicit and never mutated. +JAX-GalSim wraps this in GalSim's familiar noise API: ```python -import jax - -# Create a random key -key = jax.random.PRNGKey(42) - -# Create a noise model (JAX-GalSim wraps this in GalSim-compatible API) noise = jax_galsim.GaussianNoise(sigma=30.0) image = image.addNoise(noise) ``` -See [Notable Differences](../notable-differences.md) for more on how the RNG -differs from GalSim. +See [Notable Differences](../notable-differences.md) for details on RNG behavior. diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 9f00e076..444aac5b 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -1,12 +1,10 @@ # Quick Start -This tutorial walks through a complete galaxy image simulation, then shows how -JAX transformations (`jit`, `grad`, `vmap`) apply to it. +A complete galaxy image simulation, then JAX transformations (`jit`, `grad`, `vmap`) on top. ## A Simple Simulation -This example creates a Gaussian galaxy, convolves it with a Gaussian PSF, draws -the image, and adds noise — equivalent to GalSim's `demo1.py`. +A Gaussian galaxy convolved with a Gaussian PSF, drawn and noised -- equivalent to GalSim's `demo1.py`. ```python import jax_galsim @@ -35,8 +33,7 @@ image = image.addNoise(jax_galsim.GaussianNoise(sigma=noise_sigma)) image.write("output/demo1.fits") ``` -The API is intentionally close to GalSim — most GalSim tutorials translate -directly to JAX-GalSim by replacing `import galsim` with `import jax_galsim`. +Most GalSim code translates directly by replacing `import galsim` with `import jax_galsim`. ## JIT Compilation @@ -58,7 +55,7 @@ image = simulate(1e5, 2.0) ## Automatic Differentiation -Compute gradients of any scalar output with respect to galaxy parameters: +Compute gradients of any scalar output with respect to parameters: ```python def total_flux(gal_sigma, psf_sigma): @@ -73,12 +70,11 @@ grad_fn = jax.grad(total_flux, argnums=(0, 1)) d_gal, d_psf = grad_fn(2.0, 1.0) ``` -This is useful for fitting galaxy models to data, where gradients enable -efficient optimization. +Useful for fitting galaxy models to data via gradient-based optimization. ## Vectorization with vmap -Simulate a batch of galaxies with different parameters without writing loops: +Batch-simulate galaxies with different parameters without explicit loops: ```python import jax.numpy as jnp diff --git a/docs/index.md b/docs/index.md index ecb15ae5..2612a75d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -16,7 +16,7 @@ ## Why JAX-GalSim? -JAX-GalSim reimplements [GalSim](https://github.com/GalSim-developers/GalSim) in pure JAX, unlocking three capabilities for galaxy image simulation: +JAX-GalSim reimplements [GalSim](https://github.com/GalSim-developers/GalSim) in pure JAX, unlocking: !!! tip "JIT Compilation" @@ -56,7 +56,7 @@ image = final.drawImage(scale=0.2) image = image.addNoise(jax_galsim.GaussianNoise(sigma=30.0)) ``` -Because JAX-GalSim objects are JAX pytrees, you can JIT-compile and differentiate the entire pipeline: +JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate the entire pipeline: ```python @jax.jit diff --git a/docs/notable-differences.md b/docs/notable-differences.md index 9bfe759f..a18d1c6a 100644 --- a/docs/notable-differences.md +++ b/docs/notable-differences.md @@ -1,12 +1,11 @@ # Notable Differences from GalSim -JAX-GalSim strives to be a drop-in replacement for GalSim, but JAX's design -imposes some fundamental differences. This page documents them. +JAX-GalSim is a drop-in replacement for most GalSim code, but JAX's design +imposes some fundamental differences. ## Immutability -JAX arrays are immutable — you cannot modify them in-place. This affects all -image operations: +JAX arrays are immutable. Operations that mutate in GalSim return new objects instead: ```python # GalSim: mutates image in-place @@ -18,27 +17,18 @@ image = image.addNoise(noise) # Direct array mutation is not supported ``` -Any GalSim code that relies on in-place modification of images needs to be -rewritten to use the return value. - ## Array Views -JAX does not support all the array view semantics that NumPy provides. In -particular: - -- **Real views of complex images** are not available. In GalSim, you can get a - real-valued view of a complex image's real or imaginary part that shares - memory. JAX-GalSim returns copies instead. +JAX does not support NumPy's array view semantics. Real-valued views of complex +images (sharing memory) are not available; JAX-GalSim returns copies instead. ## Random Number Generation -JAX uses a **functional PRNG** — random state is explicit and must be threaded -through computations. Key differences: +JAX uses a **functional PRNG** -- random state is explicit and must be threaded through computations: -- JAX's PRNG is deterministic and reproducible across platforms -- Random deviates cannot "fill" an existing array; they return new arrays -- The sequence of random numbers differs from GalSim's RNG, so results will - not be numerically identical even with the same seed +- Deterministic and reproducible across platforms +- Deviates return new arrays (cannot "fill" existing ones) +- Number sequences differ from GalSim's RNG, even with the same seed - Different RNG classes may have different stability properties for discarding ## Profile Restrictions @@ -52,22 +42,17 @@ Some GalSim features are not yet implemented: ## Control Flow and Tracing -JAX's tracing system places restrictions on Python control flow: +JAX's tracing system restricts Python control flow: -- **`if`/`else` on traced values**: You cannot branch on values that JAX is - tracing (e.g., profile parameters inside a `jit`-compiled function). Use - `jax.lax.cond` instead. -- **Variable-size operations**: Operations whose output shape depends on input - values (e.g., adaptive image sizing) may not work under `jit`. +- **`if`/`else` on traced values**: Cannot branch on values JAX is tracing (e.g., profile parameters inside `jit`). Use `jax.lax.cond` instead. +- **Variable-size operations**: Operations whose output shape depends on input values (e.g., adaptive image sizing) may not work under `jit`. -JAX-GalSim uses `has_tracers()` internally to detect when code is being traced -and avoid problematic control flow patterns. +JAX-GalSim uses `has_tracers()` internally to detect tracing and avoid problematic control flow. ## Numerical Precision -Some operations may produce slightly different numerical results due to: +Results may differ slightly from GalSim due to: -- Different order of floating-point operations (JAX may reorder for performance) -- Use of XLA-compiled math kernels instead of system math libraries -- Custom gradient-safe implementations (e.g., `safe_sqrt` in `core/math.py`) - that handle edge cases differently for differentiability +- Different floating-point operation ordering (JAX may reorder for performance) +- XLA-compiled math kernels instead of system math libraries +- Gradient-safe implementations (e.g., `safe_sqrt` in `core/math.py`) that handle edge cases for differentiability diff --git a/mkdocs.yml b/mkdocs.yml index fce1a980..ecddfea1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,7 +44,6 @@ plugins: show_root_heading: true show_root_full_path: false show_if_no_docstring: true - heading_level: 2 show_symbol_type_heading: true show_symbol_type_toc: true From 9d718f8850511d2405b65a4f44a3c39cd98388e1 Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Fri, 20 Feb 2026 18:20:08 +0100 Subject: [PATCH 03/10] docs: remove architecture section and expand notable differences Remove the architecture/ pages (pytree, implements, gsobject, drawing) and key-concepts page. Consolidate their most useful content into notable-differences.md, which now thoroughly covers immutability, array views, RNG, PyTree registration, control flow/tracing, profile restrictions, numerical precision, and the @implements decorator. Co-Authored-By: Claude Opus 4.6 --- docs/architecture/drawing.md | 66 -------- docs/architecture/gsobject.md | 87 ----------- docs/architecture/implements.md | 52 ------- docs/architecture/index.md | 108 ------------- docs/architecture/pytree.md | 68 -------- docs/getting-started/index.md | 2 +- docs/getting-started/key-concepts.md | 78 ---------- docs/getting-started/quickstart.md | 3 +- docs/index.md | 2 +- docs/notable-differences.md | 224 +++++++++++++++++++++++---- mkdocs.yml | 7 - 11 files changed, 200 insertions(+), 497 deletions(-) delete mode 100644 docs/architecture/drawing.md delete mode 100644 docs/architecture/gsobject.md delete mode 100644 docs/architecture/implements.md delete mode 100644 docs/architecture/index.md delete mode 100644 docs/architecture/pytree.md delete mode 100644 docs/getting-started/key-concepts.md diff --git a/docs/architecture/drawing.md b/docs/architecture/drawing.md deleted file mode 100644 index cce389e5..00000000 --- a/docs/architecture/drawing.md +++ /dev/null @@ -1,66 +0,0 @@ -# Drawing Pipeline - -`obj.drawImage()` converts a surface brightness profile into a pixel array. - -## Overview - -```mermaid -flowchart TD - A["obj.drawImage(scale, nx, ny)"] --> B{Method?} - B -->|"auto / fft"| C["draw_by_kValue"] - B -->|"real_space"| D["draw_by_xValue"] - B -->|"phot"| E["Photon shooting"] - - C --> F["Evaluate _kValue on k-grid"] - F --> G["Inverse FFT"] - G --> H["Apply pixel response"] - - D --> I["Evaluate _xValue on pixel grid"] - I --> J["Direct summation"] - - E --> K["Generate PhotonArray"] - K --> L["Accumulate on Image"] - - H --> M["Image"] - J --> M - L --> M -``` - -## Drawing Methods - -### FFT Drawing (`draw_by_kValue`) - -The default for most profiles: - -1. Determine image size from `_stepk`/`_maxk` (or user-specified `nx`, `ny`) -2. Evaluate `_kValue(kpos)` on a k-space grid -3. Inverse FFT to real space -4. Apply pixel convolution (if not already included) - -Efficient for smooth profiles and convolutions (convolution in real space = multiplication in Fourier space). - -### Real-Space Drawing (`draw_by_xValue`) - -Evaluates `_xValue(pos)` directly at each pixel center, multiplied by pixel area. Simpler but slower for large images. Implemented in `jax_galsim/core/draw.py`. - -### Photon Shooting - -Generates random photon positions from the profile and accumulates them on an -image grid via `PhotonArray` and `Sensor`. - -## Convolution - -`Convolve([gal, psf])` creates a lazy `Convolution` object. The actual -convolution happens at draw time: - -- **FFT method**: Multiply k-space representations, then inverse FFT -- **Real-space method**: Only supported for simple profile combinations - -## `_maxk` and `_stepk` - -These properties control automatic image sizing: - -- **`_maxk`**: Maximum significant spatial frequency. Determines pixel scale (Nyquist sampling). -- **`_stepk`**: Spacing in k-space. Determines image size (field of view). - -For a `Convolution`, `_maxk` = min of components (most compact Fourier representation wins), `_stepk` = min of components (largest real-space extent sets the field of view). diff --git a/docs/architecture/gsobject.md b/docs/architecture/gsobject.md deleted file mode 100644 index 1cb170b8..00000000 --- a/docs/architecture/gsobject.md +++ /dev/null @@ -1,87 +0,0 @@ -# GSObject Hierarchy - -`GSObject` is the base class for all surface brightness profiles. Every galaxy -model, PSF, and optical component inherits from it. - -## Class Hierarchy - -```mermaid -classDiagram - GSObject <|-- Gaussian - GSObject <|-- Moffat - GSObject <|-- Spergel - GSObject <|-- Exponential - GSObject <|-- DeltaFunction - GSObject <|-- Box - Box <|-- Pixel - GSObject <|-- InterpolatedImage - GSObject <|-- Sum - GSObject <|-- Convolution - GSObject <|-- Transformation - - class GSObject { - +_params: dict - +_gsparams: GSParams - +drawImage(scale, nx, ny) Image - +shear(shear) GSObject - +shift(dx, dy) GSObject - +_xValue(pos) float - +_kValue(kpos) complex - +_maxk: float - +_stepk: float - } - - class Gaussian { - +sigma: float - +flux: float - } - - class Moffat { - +beta: float - +scale_radius: float - +flux: float - } - - class Spergel { - +nu: float - +half_light_radius: float - +flux: float - } -``` - -## The Base Class Contract - -Subclasses must implement these methods and properties: - -| Member | Type | Purpose | -|--------|------|---------| -| `_xValue(pos)` | Method | Surface brightness at real-space position | -| `_kValue(kpos)` | Method | Fourier-space amplitude at frequency | -| `_maxk` | Property | Maximum k beyond which the profile is negligible | -| `_stepk` | Property | Sampling interval in k-space | - -The base class provides the public API built on these primitives: - -- `drawImage()` — Renders the profile to a pixel grid (delegates to `core.draw`) -- `shear()`, `shift()`, `rotate()`, `dilate()` — Return transformed copies via `Transformation` -- `__add__` — Returns a `Sum` of profiles -- `withFlux()`, `withGSParams()` — Return copies with modified parameters - -## The `_params` Dict - -Traced parameters live in a `_params` dictionary -- the canonical storage for -values that JAX differentiates through. Properties like `gal.sigma` are -accessors into `_params`: - -```python -gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) -gal._params # {"flux": Array(1e5), "sigma": Array(2.0)} -``` - -## Composition Objects - -Profiles combine through three composition types (themselves GSObjects, nestable arbitrarily): - -- **`Sum`** (`Add`): Adds surface brightness profiles -- **`Convolution`** (`Convolve`): Convolves profiles (e.g., galaxy with PSF) -- **`Transformation`** (`Transform`): Affine transforms (shear, shift, rotation, flux scaling) diff --git a/docs/architecture/implements.md b/docs/architecture/implements.md deleted file mode 100644 index 835a2b47..00000000 --- a/docs/architecture/implements.md +++ /dev/null @@ -1,52 +0,0 @@ -# The @implements Decorator - -JAX-GalSim reuses GalSim's docstrings via the `@implements` decorator. - -## Usage - -```python -import galsim as _galsim -from jax_galsim.core.utils import implements - -@implements(_galsim.Gaussian, - lax_description="LAX: Does not support ChromaticObject.") -class Gaussian(GSObject): - ... -``` - -This does three things: - -1. **Copies the docstring** from `_galsim.Gaussian` to `jax_galsim.Gaussian` -2. **Adds a LAX note** (the `lax_description`) documenting JAX-specific - differences or limitations -3. **Sets `__galsim_wrapped__`** on the class, linking back to the original - GalSim implementation - -## The `lax_description` Parameter - -Documents differences from the reference GalSim implementation: - -```python -@implements(_galsim.Add, - lax_description="Does not support ChromaticObject at this point.") -def Add(*args, **kwargs): - return Sum(*args, **kwargs) -``` - -## How It Works - -The decorator (in `jax_galsim/core/utils.py`): - -1. Copies the original GalSim docstring -2. Inserts a "LAX-backend implementation of `:func:original.name`" note after the summary line -3. Appends the `lax_description` text (if provided) -4. Assigns the combined docstring to the wrapped class/function - -## When to Use It - -- **Always** when implementing a GalSim class or function — this is the standard - pattern in JAX-GalSim -- The `lax_description` should note any restricted functionality, different - behavior, or missing parameters compared to GalSim -- If there is no corresponding GalSim function (e.g., JAX-specific utilities), - write a normal docstring instead diff --git a/docs/architecture/index.md b/docs/architecture/index.md deleted file mode 100644 index 864a5167..00000000 --- a/docs/architecture/index.md +++ /dev/null @@ -1,108 +0,0 @@ -# Architecture Overview - -JAX-GalSim mirrors GalSim's architecture, replacing NumPy/C++ internals with pure JAX. - -## Component Map - -```mermaid -graph TB - subgraph "User API" - Profiles["Profiles
(Gaussian, Moffat, ...)"] - Comp["Composition
(Convolve, Sum, Transform)"] - Img["Image"] - Noise["Noise & Random"] - end - - subgraph "Coordinate System" - Pos["Position"] - Bounds["Bounds"] - WCS["WCS"] - Shear["Shear"] - Angle["Angle"] - end - - subgraph "Rendering Pipeline" - Draw["core.draw"] - FFT["FFT Convolution"] - Photon["Photon Shooting"] - Interp["Interpolation"] - end - - subgraph "JAX Infrastructure" - PyTree["PyTree Registration"] - Impl["@implements Decorator"] - Utils["Type Casting & Tracers"] - end - - Profiles --> Comp - Comp --> Draw - Draw --> Img - Photon --> Img - Noise --> Img - WCS --> Img - Pos --> Bounds - Bounds --> Img - Shear --> Profiles - Interp --> Draw - PyTree --> Profiles - PyTree --> Img - Impl --> Profiles - Utils --> Draw -``` - -## Module Layout - -``` -jax_galsim/ -├── core/ # JAX internals -│ ├── draw.py # Real-space & k-space drawing -│ ├── integrate.py # Numerical integration -│ ├── interpolate.py # Spline interpolation -│ ├── math.py # Gradient-safe math (safe_sqrt, etc.) -│ ├── utils.py # @implements, has_tracers(), type casting -│ └── wrap_image.py # Image wrapping utilities -├── gaussian.py # Gaussian profile -├── moffat.py # Moffat profile -├── spergel.py # Spergel profile -├── exponential.py # Exponential profile -├── box.py # Box & Pixel profiles -├── deltafunction.py # Delta function profile -├── gsobject.py # Base GSObject class -├── sum.py # Sum (Add) composition -├── convolve.py # Convolution & Deconvolution -├── transform.py # Affine transformations -├── interpolatedimage.py # InterpolatedImage -├── image.py # Image class hierarchy -├── wcs.py # WCS base classes -├── fitswcs.py # FITS WCS implementations -├── fits.py # FITS I/O -├── noise.py # Noise models -├── random.py # Random deviates -├── interpolant.py # Interpolant functions -├── photon_array.py # Photon arrays -├── sensor.py # Sensor model -├── position.py # Position types -├── bounds.py # Bounds types -├── angle.py # Angle & AngleUnit -├── shear.py # Shear -├── celestial.py # CelestialCoord -├── gsparams.py # GSParams configuration -├── errors.py # Exception classes -├── bessel.py # Bessel functions -├── integ.py # Integration (int1d) -└── utilities.py # General utilities -``` - -## Design Principles - -1. **Drop-in replacement**: `import jax_galsim as galsim` works for supported features. -2. **PyTree everywhere**: Every object separates traced parameters (children) from static configuration (auxiliary data). -3. **Docstring inheritance**: `@implements` copies GalSim docstrings, appending JAX-specific notes. -4. **Pure functions**: No side effects, enabling `jit`, `grad`, and `vmap`. - -## Deep Dives - -- [PyTree Registration](pytree.md) — How objects become JAX-compatible -- [The @implements Decorator](implements.md) — Docstring inheritance from GalSim -- [GSObject Hierarchy](gsobject.md) — Base class contract and profile implementations -- [Drawing Pipeline](drawing.md) — How profiles become pixel images diff --git a/docs/architecture/pytree.md b/docs/architecture/pytree.md deleted file mode 100644 index baeff6f3..00000000 --- a/docs/architecture/pytree.md +++ /dev/null @@ -1,68 +0,0 @@ -# PyTree Registration - -JAX transformations decompose objects into arrays (leaves) and static structure -(treedef). Every JAX-GalSim class uses `@register_pytree_node_class` and -implements two methods: - -```python -from jax.tree_util import register_pytree_node_class - -@register_pytree_node_class -class Gaussian(GSObject): - def tree_flatten(self): - # Children: values that JAX can trace (differentiate through) - children = (self.params,) - # Aux data: static values that trigger recompilation if changed - aux_data = (self.gsparams,) - return children, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, children): - # Reconstruct the object from its parts - return cls(params=children[0], gsparams=aux_data[0]) -``` - -## Children vs Auxiliary Data - -| Component | Role | Examples | Effect of changing | -|-----------|------|----------|--------------------| -| **Children** (traced) | Values JAX differentiates through | `flux`, `sigma`, `half_light_radius` | Triggers re-evaluation, not recompilation | -| **Auxiliary data** (static) | Structure and configuration | `GSParams`, enum flags | Triggers full recompilation under `jit` | - -In practice, profile parameters live in a `_params` dict (children) and numerical -configuration lives in `_gsparams` (auxiliary): - -```python -gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) -# gal._params = {"flux": 1e5, "sigma": 2.0} — traced -# gal._gsparams = GSParams(...) — static -``` - -## The `__init__` Gotcha - -During `tree_unflatten`, JAX calls the constructor with traced values, not -concrete Python numbers. Type checks like `isinstance(sigma, float)` will fail -on tracers. Use `has_tracers()` to skip validation during tracing: - -```python -from jax_galsim.core.utils import has_tracers - -class MyProfile(GSObject): - def __init__(self, sigma, gsparams=None): - if not has_tracers(sigma): - # Only validate with concrete values - if sigma <= 0: - raise ValueError("sigma must be positive") - ... -``` - -## Practical Implications - -1. **`jax.grad` works**: Because profile parameters are traced children, you get - gradients for free. - -2. **`GSParams` changes recompile**: Changing `GSParams` between calls to a - `jit`-compiled function triggers recompilation, since it's static auxiliary data. - -3. **No mutable state**: PyTree flattening and unflattening means all state must - be reconstructable from children + aux_data. There's no hidden mutable state. diff --git a/docs/getting-started/index.md b/docs/getting-started/index.md index a8733c82..062d9480 100644 --- a/docs/getting-started/index.md +++ b/docs/getting-started/index.md @@ -4,4 +4,4 @@ New to JAX-GalSim? Start here. - [Installation](installation.md) — Install JAX-GalSim and set up GPU support - [Quick Start](quickstart.md) — Simulate a galaxy image in a few lines of code -- [Key Concepts](key-concepts.md) — Understand GSObjects, Images, PyTrees, and how JAX changes things +- [Notable Differences](../notable-differences.md) — What changes when GalSim runs on JAX diff --git a/docs/getting-started/key-concepts.md b/docs/getting-started/key-concepts.md deleted file mode 100644 index a9a239f1..00000000 --- a/docs/getting-started/key-concepts.md +++ /dev/null @@ -1,78 +0,0 @@ -# Key Concepts - -## GSObject: The Building Block - -Every galaxy profile, PSF, and optical component is a **GSObject**. GSObjects support: - -- **Arithmetic**: `+` (sum), `*` (scalar multiply) -- **Transformations**: `.shear()`, `.shift()`, `.dilate()`, `.rotate()` -- **Convolution**: `jax_galsim.Convolve([obj1, obj2, ...])` -- **Drawing**: `.drawImage(scale=...)` renders the profile to a pixel grid - -```python -import jax_galsim - -# A Gaussian galaxy, sheared and shifted -gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) -gal = gal.shear(g1=0.2, g2=0.1) -gal = gal.shift(dx=0.5, dy=-0.3) -``` - -## Images Are Immutable - -JAX arrays cannot be modified in-place, so operations that would mutate an image -in GalSim return a new image instead: - -```python -# GalSim: image.addNoise(noise) # modifies in-place -# JAX-GalSim: image = image.addNoise(noise) # returns new image -``` - -This is the most common difference when porting GalSim code. -See [Notable Differences](../notable-differences.md) for the full list. - -## PyTree Registration - -All JAX-GalSim objects are registered as JAX **PyTrees**, so you can pass them -directly to `jit`, `grad`, and `vmap`: - -```python -import jax - -@jax.jit -def render(gal): - return gal.drawImage(scale=0.2).array - -gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) -result = render(gal) # gal is automatically flattened/unflattened -``` - -Profile parameters (flux, sigma, etc.) are **traced children** that JAX can -differentiate through. Configuration like `GSParams` is **static auxiliary data** -that triggers recompilation when changed. See [PyTree Registration](../architecture/pytree.md) -for details. - -## GSParams: Numerical Configuration - -`GSParams` controls numerical accuracy and performance trade-offs for rendering: - -```python -# Use tighter tolerances for high-precision work -gsparams = jax_galsim.GSParams(maximum_fft_size=8192, kvalue_accuracy=1e-6) -gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0, gsparams=gsparams) -``` - -`GSParams` is treated as **static** auxiliary data in the PyTree, so changing it -triggers recompilation under `jit`. - -## Functional Random Numbers - -JAX uses a functional PRNG -- random state is explicit and never mutated. -JAX-GalSim wraps this in GalSim's familiar noise API: - -```python -noise = jax_galsim.GaussianNoise(sigma=30.0) -image = image.addNoise(noise) -``` - -See [Notable Differences](../notable-differences.md) for details on RNG behavior. diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 444aac5b..d51f3f60 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -94,6 +94,5 @@ images = batch_simulate(sigmas) # shape: (10, 64, 64) ## Next Steps -- [Key Concepts](key-concepts.md) — Understand the design behind JAX-GalSim -- [Notable Differences](../notable-differences.md) — What's different from GalSim +- [Notable Differences](../notable-differences.md) — What changes when GalSim runs on JAX - [API Reference](../api/profiles/gaussian.md) — Full API documentation diff --git a/docs/index.md b/docs/index.md index 2612a75d..6796a026 100644 --- a/docs/index.md +++ b/docs/index.md @@ -76,5 +76,5 @@ dflux, dsigma = grad_fn(1e5, 2.0) - [Installation](getting-started/installation.md) --- Set up JAX-GalSim with GPU support - [Quick Start](getting-started/quickstart.md) --- Walk through a complete simulation -- [Key Concepts](getting-started/key-concepts.md) --- Understand JAX-GalSim's design +- [Notable Differences](notable-differences.md) --- What changes when GalSim runs on JAX - [API Reference](api/profiles/gaussian.md) --- Browse the full API diff --git a/docs/notable-differences.md b/docs/notable-differences.md index a18d1c6a..8d8b38d6 100644 --- a/docs/notable-differences.md +++ b/docs/notable-differences.md @@ -1,58 +1,228 @@ # Notable Differences from GalSim -JAX-GalSim is a drop-in replacement for most GalSim code, but JAX's design -imposes some fundamental differences. +JAX-GalSim is designed as a drop-in replacement for GalSim --- replacing +`import galsim` with `import jax_galsim` works for all supported features. +However, JAX's execution model introduces several fundamental differences +that you should understand before porting code or writing new simulations. + +--- ## Immutability -JAX arrays are immutable. Operations that mutate in GalSim return new objects instead: +JAX arrays are **immutable**. Any GalSim operation that modifies data in-place +returns a new object in JAX-GalSim instead. ```python -# GalSim: mutates image in-place +# GalSim — mutates the image in-place image.addNoise(noise) image.array[10, 10] = 0.0 -# JAX-GalSim: returns a new image +# JAX-GalSim — returns a new image each time image = image.addNoise(noise) -# Direct array mutation is not supported + +# Direct array element mutation is not supported. +# Use jax.numpy operations to produce a new array: +new_array = image.array.at[10, 10].set(0.0) ``` +This is the most common change when porting GalSim code. Every call that +modifies an image, adds noise, or updates a value must capture the return value. +If you forget the assignment, the original object is unchanged and no error is +raised --- a subtle source of bugs. + +--- + ## Array Views -JAX does not support NumPy's array view semantics. Real-valued views of complex -images (sharing memory) are not available; JAX-GalSim returns copies instead. +NumPy supports **array views** --- slices that share memory with the original +array. JAX does not. In GalSim, you can obtain a real-valued view of a complex +image (e.g., the real part shares memory with the underlying complex buffer). +In JAX-GalSim, these operations return **copies** instead. Modifying the copy +does not affect the original. + +```python +# GalSim — real_part is a view, shares memory with complex_image +real_part = complex_image.real + +# JAX-GalSim — real_part is a copy +real_part = complex_image.real # independent array +``` + +--- ## Random Number Generation -JAX uses a **functional PRNG** -- random state is explicit and must be threaded through computations: +JAX uses a **functional PRNG** --- random state is explicit and must be passed +through computations. This has several consequences: -- Deterministic and reproducible across platforms -- Deviates return new arrays (cannot "fill" existing ones) -- Number sequences differ from GalSim's RNG, even with the same seed -- Different RNG classes may have different stability properties for discarding +**Determinism**: Given the same seed, JAX-GalSim produces identical results +across runs and platforms (CPU, GPU, TPU). GalSim's results may vary by platform. -## Profile Restrictions +**Explicit state**: Random deviates carry their state explicitly. Under the hood, +JAX-GalSim wraps JAX's key-based PRNG in GalSim's familiar noise API, so the +user-facing interface looks the same: + +```python +noise = jax_galsim.GaussianNoise(sigma=30.0) +image = image.addNoise(noise) # state is managed internally +``` + +**Different sequences**: Even with the same seed value, the actual random number +sequences differ from GalSim. Results will not match GalSim number-for-number. +This is expected --- the underlying PRNG algorithms are completely different. + +**No in-place fill**: GalSim deviates can "fill" existing arrays. JAX deviates +always return new arrays, consistent with JAX's immutability model. -Some GalSim features are not yet implemented: +--- + +## PyTree Registration + +All JAX-GalSim objects are registered as JAX **PyTrees**. This is what allows +you to pass them directly to `jax.jit`, `jax.grad`, and `jax.vmap`. + +A PyTree splits each object into two parts: + +| Part | What it contains | Examples | Effect of changing | +|------|-----------------|----------|--------------------| +| **Children** (traced) | Values JAX differentiates through | `flux`, `sigma`, `half_light_radius` | Re-evaluation, not recompilation | +| **Auxiliary data** (static) | Structure and configuration | `GSParams`, enum flags | Full recompilation under `jit` | + +In practice, profile parameters live in a `_params` dict (children) and +numerical configuration lives in `_gsparams` (auxiliary): + +```python +gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0) +# gal._params = {"flux": 1e5, "sigma": 2.0} — traced by JAX +# gal._gsparams = GSParams(...) — static, triggers recompile +``` + +Because `GSParams` is static auxiliary data, changing it between calls to a +`jit`-compiled function triggers recompilation. Keep `GSParams` constant across +calls when possible. + +```python +import jax + +gsparams = jax_galsim.GSParams(maximum_fft_size=8192) + +@jax.jit +def simulate(flux, sigma): + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma, gsparams=gsparams) + return gal.drawImage(scale=0.2).array.sum() + +# Changing gsparams here would cause recompilation on next call +``` -- **Truncated Moffat profiles** are not supported (the `trunc` parameter) -- **ChromaticObject** and all chromatic functionality is not available -- **InterpolatedKImage** is not implemented -- See [API Coverage](api-coverage.md) for the full list of supported APIs +--- ## Control Flow and Tracing -JAX's tracing system restricts Python control flow: +JAX's JIT compiler works by **tracing** --- it records operations on abstract +values to build a computation graph. This restricts what Python code can do +inside `jit`-compiled functions. -- **`if`/`else` on traced values**: Cannot branch on values JAX is tracing (e.g., profile parameters inside `jit`). Use `jax.lax.cond` instead. -- **Variable-size operations**: Operations whose output shape depends on input values (e.g., adaptive image sizing) may not work under `jit`. +### No branching on traced values -JAX-GalSim uses `has_tracers()` internally to detect tracing and avoid problematic control flow. +You cannot use Python `if`/`else` on values that JAX is tracing (e.g., profile +parameters passed into a `jit`-compiled function): + +```python +@jax.jit +def bad(sigma): + if sigma > 1.0: # ERROR: sigma is a tracer, not a concrete value + return sigma * 2 + return sigma + +@jax.jit +def good(sigma): + return jax.lax.cond(sigma > 1.0, lambda s: s * 2, lambda s: s, sigma) +``` + +JAX-GalSim uses an internal `has_tracers()` utility to detect tracing and +avoid problematic control flow in its own implementations. + +### Fixed output shapes + +Under `jit`, the **shape** of every array must be determinable at compile time. +Operations whose output size depends on input values (e.g., adaptive image +sizing based on a traced parameter) may not work. When using `jax.vmap`, you +must specify fixed image dimensions: + +```python +@jax.vmap +def batch(sigma): + gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma) + # Must specify nx, ny so all images have the same shape + return gal.drawImage(scale=0.2, nx=64, ny=64).array +``` + +### The `__init__` gotcha + +During `jit` tracing, JAX calls constructors with **tracer objects** rather than +concrete Python numbers. Type checks like `isinstance(sigma, float)` will fail +on tracers. JAX-GalSim handles this internally, but if you subclass any +JAX-GalSim object, be aware that `__init__` may receive tracers: + +```python +from jax_galsim.core.utils import has_tracers + +class MyProfile(jax_galsim.GSObject): + def __init__(self, sigma, gsparams=None): + if not has_tracers(sigma): + # Only validate with concrete values + if sigma <= 0: + raise ValueError("sigma must be positive") + ... +``` + +--- + +## Profile Restrictions + +Some GalSim features are not yet implemented in JAX-GalSim: + +- **Truncated Moffat profiles**: The `trunc` parameter is not supported. +- **ChromaticObject**: All chromatic functionality (wavelength-dependent profiles) is not available. +- **InterpolatedKImage**: Not implemented. +- **Airy, Kolmogorov, OpticalPSF, RealGalaxy**: And other profiles --- see [API Coverage](api-coverage.md) for the full list. + +The project currently implements **22.5%** of the GalSim public API, focused on +the most commonly used profiles and operations. Coverage is expanding. + +--- ## Numerical Precision -Results may differ slightly from GalSim due to: +Simulation results may differ slightly from GalSim at the floating-point level: + +- **Operation reordering**: JAX (via XLA) may reorder floating-point operations for performance. Floating-point addition is not associative, so different orderings produce slightly different results. +- **Different math kernels**: XLA-compiled math kernels may differ from system math libraries (e.g., `libm`) that GalSim uses via NumPy/C++. +- **Gradient-safe functions**: JAX-GalSim uses special implementations (e.g., `safe_sqrt` that avoids `NaN` gradients at zero) where GalSim uses standard library functions. These may produce slightly different values at edge cases. +- **Default precision**: JAX defaults to 32-bit floats. Enable 64-bit with `jax.config.update("jax_enable_x64", True)` for higher precision matching GalSim's default behavior. + +These differences are typically at the level of floating-point round-off +($\sim 10^{-7}$ for float32, $\sim 10^{-15}$ for float64) and should not +affect scientific conclusions. + +--- + +## The `@implements` Decorator + +JAX-GalSim reuses GalSim's docstrings rather than duplicating them. Every public +class and function uses an `@implements` decorator that copies the docstring from +the corresponding GalSim object and appends a note about JAX-specific differences: + +```python +from jax_galsim.core.utils import implements +import galsim as _galsim + +@implements(_galsim.Gaussian, + lax_description="LAX: Does not support ChromaticObject.") +class Gaussian(GSObject): + ... +``` -- Different floating-point operation ordering (JAX may reorder for performance) -- XLA-compiled math kernels instead of system math libraries -- Gradient-safe implementations (e.g., `safe_sqrt` in `core/math.py`) that handle edge cases for differentiability +This means the [API Reference](api/profiles/gaussian.md) shows GalSim's +documentation with an added "LAX-backend" note. If you see RST-formatted cross-references +like `:func:` or `:class:` in the docs, they come from GalSim's original docstrings. diff --git a/mkdocs.yml b/mkdocs.yml index ecddfea1..c6cbdd79 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,13 +75,6 @@ nav: - getting-started/index.md - Installation: getting-started/installation.md - Quick Start: getting-started/quickstart.md - - Key Concepts: getting-started/key-concepts.md - - Architecture: - - architecture/index.md - - PyTree Registration: architecture/pytree.md - - The @implements Decorator: architecture/implements.md - - GSObject Hierarchy: architecture/gsobject.md - - Drawing Pipeline: architecture/drawing.md - Notable Differences: notable-differences.md - API Coverage: api-coverage.md - API Reference: From 2fbeb8aba3964d6ff03764f0c5fd8780bd6bd586 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Thu, 2 Apr 2026 10:18:12 -0400 Subject: [PATCH 04/10] small correction when adding noise --- docs/index.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/index.md b/docs/index.md index 6796a026..bdfdb1dd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -52,8 +52,8 @@ psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) final = jax_galsim.Convolve([gal, psf]) image = final.drawImage(scale=0.2) -# Add noise -image = image.addNoise(jax_galsim.GaussianNoise(sigma=30.0)) +# Add noise (changes underlying image array) +image.addNoise(jax_galsim.GaussianNoise(sigma=30.0)) ``` JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate the entire pipeline: From 75f976710705fc2190ad91a9580003ebd594c6c7 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Fri, 3 Apr 2026 10:54:44 -0400 Subject: [PATCH 05/10] small corrections regarding jitting and modifying images --- docs/index.md | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/index.md b/docs/index.md index bdfdb1dd..73f352b3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -52,7 +52,7 @@ psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) final = jax_galsim.Convolve([gal, psf]) image = final.drawImage(scale=0.2) -# Add noise (changes underlying image array) +# Add noise (overwrites underlying image array with new array) image.addNoise(jax_galsim.GaussianNoise(sigma=30.0)) ``` @@ -60,16 +60,24 @@ JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate the ```python @jax.jit -def simulate(flux, sigma): +def simulate(flux, sigma, *, slen=21, fft_size=128): + gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) - return jax_galsim.Convolve([gal, psf]).drawImage(scale=0.2).array.sum() + gal_convolved = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams) + image = gal_convolved.drawImage(nx=slen, ny=slen, scale=0.2) + return image.array.sum() # Compute gradients with respect to galaxy parameters grad_fn = jax.grad(simulate, argnums=(0, 1)) dflux, dsigma = grad_fn(1e5, 2.0) ``` +Note that the size of the image in real space (`slen`) and fourier space +(`minimum_fft_size = maximum_fft_size`) need to be specified in advance for jitting. See the rest +of the documentation for more details and examples. + --- ## Next Steps From b5e5e2fd80efce0a4f541810b63bdc5328607a77 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Fri, 3 Apr 2026 14:09:25 -0400 Subject: [PATCH 06/10] some corrections in notable differences --- docs/notable-differences.md | 67 ++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/docs/notable-differences.md b/docs/notable-differences.md index 8d8b38d6..f1425cdd 100644 --- a/docs/notable-differences.md +++ b/docs/notable-differences.md @@ -9,26 +9,40 @@ that you should understand before porting code or writing new simulations. ## Immutability -JAX arrays are **immutable**. Any GalSim operation that modifies data in-place -returns a new object in JAX-GalSim instead. +JAX arrays are **immutable**. Any GalSim operation that originally modified data in-place, now +instead creates a new array that overwrites the original one. Let's look at `__iadd__` as an example. ```python # GalSim — mutates the image in-place -image.addNoise(noise) -image.array[10, 10] = 0.0 +# i.e. no new numpy array is created +image += 1.0 +# under the hood, some version of: `self.array[:,:] += a` does not create a new numpy array. -# JAX-GalSim — returns a new image each time -image = image.addNoise(noise) +# JAX-GalSim — creates a new array and overwrites original one +image += 1.0 +# under the hood: `image._array = image._array + 1.0`. The RHS is a new JAX array. +``` + +This could become a subtle source of bugs if you are used to numpy in place mutability. Here +is another example with `__iadd__` that illustrates this: + +```python +# galsim +image = galsim.ImageD(11, 11) +arr1 = image.array -# Direct array element mutation is not supported. -# Use jax.numpy operations to produce a new array: -new_array = image.array.at[10, 10].set(0.0) +image += 1.0 +arr1.sum(), image.array.sum() # -> 121.0, 121.0 + +# jax-galsim +image = jax_galsim.ImageD(11, 11) +arr1 = image.array + +image += 1.0 +arr1.sum(), image.array.sum() # -> 0.0, 121.0, original image array was unmodified! ``` -This is the most common change when porting GalSim code. Every call that -modifies an image, adds noise, or updates a value must capture the return value. -If you forget the assignment, the original object is unchanged and no error is -raised --- a subtle source of bugs. +For more details on JAX immutability please see the [Sharp Bits page](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates) of the JAX documentation. --- @@ -64,7 +78,7 @@ user-facing interface looks the same: ```python noise = jax_galsim.GaussianNoise(sigma=30.0) -image = image.addNoise(noise) # state is managed internally +image.addNoise(noise) # state is managed internally ``` **Different sequences**: Even with the same seed value, the actual random number @@ -88,7 +102,7 @@ A PyTree splits each object into two parts: | **Children** (traced) | Values JAX differentiates through | `flux`, `sigma`, `half_light_radius` | Re-evaluation, not recompilation | | **Auxiliary data** (static) | Structure and configuration | `GSParams`, enum flags | Full recompilation under `jit` | -In practice, profile parameters live in a `_params` dict (children) and +For `GSObject`, profile parameters live in a `_params` dict (children) and numerical configuration lives in `_gsparams` (auxiliary): ```python @@ -104,12 +118,13 @@ calls when possible. ```python import jax -gsparams = jax_galsim.GSParams(maximum_fft_size=8192) +gsparams = jax_galsim.GSParams(minimum_fft_size=8192, maximum_fft_size=8192) +slen = 21 # image size should also be constant for jit to work (see below for more details). @jax.jit def simulate(flux, sigma): gal = jax_galsim.Gaussian(flux=flux, sigma=sigma, gsparams=gsparams) - return gal.drawImage(scale=0.2).array.sum() + return gal.drawImage(nx=slen, ny=slen, scale=0.2).array.sum() # Changing gsparams here would cause recompilation on next call ``` @@ -146,17 +161,31 @@ avoid problematic control flow in its own implementations. Under `jit`, the **shape** of every array must be determinable at compile time. Operations whose output size depends on input values (e.g., adaptive image -sizing based on a traced parameter) may not work. When using `jax.vmap`, you +sizing based on a traced parameter) may not work. When using `jax.jit` or `jax.vmap`, you must specify fixed image dimensions: ```python +@jax.jit @jax.vmap def batch(sigma): - gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma) + gsparams = GSParams(minimum_fft_size=256, maximum_fft_size=256) + gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma).withGSParams(gsparams) # Must specify nx, ny so all images have the same shape return gal.drawImage(scale=0.2, nx=64, ny=64).array ``` +Importantly, the default (and most commonly used) drawing procedure in GalSim (and JAX-GalSim) +transforms image to k-space via an FFT. The size of the "images" in Fourier space usually depends +on traced galaxy profile paramers e.g. size, which makes this incompatible with `jit`. Thus, in JAX-GalSim +we allow for this k-space image size to be fixed explicitly via `GSParams` as done above: + +```python + gsparams = GSParams(minimum_fft_size=256, maximum_fft_size=256) +``` + +where both `minimum_fft_size` and `maximum_fft_size` need to be set to the same value. + + ### The `__init__` gotcha During `jit` tracing, JAX calls constructors with **tracer objects** rather than From 784b9f13c126c3659e2ee2e777eaffcf7c20893c Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Fri, 3 Apr 2026 14:25:12 -0400 Subject: [PATCH 07/10] organize --- docs/index.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/index.md b/docs/index.md index 73f352b3..74e773da 100644 --- a/docs/index.md +++ b/docs/index.md @@ -61,10 +61,10 @@ JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate the ```python @jax.jit def simulate(flux, sigma, *, slen=21, fft_size=128): - gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) - gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + gal_convolved = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams) image = gal_convolved.drawImage(nx=slen, ny=slen, scale=0.2) return image.array.sum() From 336aac91d28d7f825bd8fbbf7342520df16bb2ca Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Fri, 3 Apr 2026 14:25:29 -0400 Subject: [PATCH 08/10] please --- docs/index.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/index.md b/docs/index.md index 74e773da..ca0cd84c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -75,8 +75,8 @@ dflux, dsigma = grad_fn(1e5, 2.0) ``` Note that the size of the image in real space (`slen`) and fourier space -(`minimum_fft_size = maximum_fft_size`) need to be specified in advance for jitting. See the rest -of the documentation for more details and examples. +(`minimum_fft_size = maximum_fft_size`) need to be specified in advance for jitting. +Please see the rest of the documentation for more details and examples. --- From da87c73a6d6f05684c27f3e3d5e717e0de8b5b22 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Fri, 3 Apr 2026 14:26:35 -0400 Subject: [PATCH 09/10] a few corrections and further example how to jit --- docs/getting-started/quickstart.md | 34 ++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index d51f3f60..7f33ef5c 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -27,7 +27,7 @@ final = jax_galsim.Convolve([gal, psf]) image = final.drawImage(scale=pixel_scale) # Add Gaussian noise -image = image.addNoise(jax_galsim.GaussianNoise(sigma=noise_sigma)) +image.addNoise(jax_galsim.GaussianNoise(sigma=noise_sigma)) # Write to FITS image.write("output/demo1.fits") @@ -42,15 +42,33 @@ Wrap your simulation in `jax.jit` to compile it into an optimized XLA computatio ```python import jax -@jax.jit -def simulate(flux, sigma): +@jax.jit(static_argnames=['slen', 'fft_size']) +def simulate(flux, sigma, *, slen, fft_size): + gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) + psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) + final = jax_galsim.Convolve([gal, psf]) + return final.drawImage(nx=slen, ny=slen, scale=0.2) + +# First call compiles; subsequent calls are fast (as long as slen, fft_size stay the same) +image = simulate(1e5, 2.0, slen=21, fft_size=128) +``` + +Here is another option for jitting using the `partial` utility from `functools`: + +```python +from jax import jit +from functools import partial + +def simulate(flux, sigma, *, slen, fft_size): + gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) gal = jax_galsim.Gaussian(flux=flux, sigma=sigma) psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) final = jax_galsim.Convolve([gal, psf]) - return final.drawImage(scale=0.2) + return final.drawImage(nx=slen, ny=slen, scale=0.2) -# First call compiles; subsequent calls are fast -image = simulate(1e5, 2.0) +simulated_jitted = jit(partial(simulate, slen=21, fft_size=128)) +image = simulated_jitted(1e5, 2.0) ``` ## Automatic Differentiation @@ -81,11 +99,13 @@ import jax.numpy as jnp sigmas = jnp.linspace(1.0, 4.0, 10) +@jax.jit @jax.vmap def batch_simulate(sigma): + gsparams = GSParams(minimum_fft_size=128, maximum_fft_size=128) gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma) psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0) - final = jax_galsim.Convolve([gal, psf]) + final = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams) return final.drawImage(scale=0.2, nx=64, ny=64).array # Simulate all 10 galaxies in parallel From 3660b3f8e126e933d32f8661ab5db572757f9ae5 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Fri, 3 Apr 2026 14:34:23 -0400 Subject: [PATCH 10/10] some notes --- docs/getting-started/quickstart.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 7f33ef5c..6d4d6f82 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -54,6 +54,8 @@ def simulate(flux, sigma, *, slen, fft_size): image = simulate(1e5, 2.0, slen=21, fft_size=128) ``` +**Remember**, any arguments that affect control flow (like image size) must be marked as `static_argnames` for JIT to work. + Here is another option for jitting using the `partial` utility from `functools`: ```python @@ -71,6 +73,9 @@ simulated_jitted = jit(partial(simulate, slen=21, fft_size=128)) image = simulated_jitted(1e5, 2.0) ``` +In this case `partial` is used to fix the values of `slen` and `fft_size`, allowing `simulate` to +be jitted without needing to specify those arguments each time. + ## Automatic Differentiation Compute gradients of any scalar output with respect to parameters: