From 0848bc705bf3345ad4e0f790835738a86ed4e63c Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Mar 2026 16:14:09 -0500 Subject: [PATCH 01/16] fix: ensure bounds comparisons can be vmapped --- jax_galsim/bounds.py | 22 +++++++++++++--------- tests/jax/test_vmapping.py | 23 +++++++++++++++++++++++ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index ed5942af..a92e4716 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -110,18 +110,20 @@ def includes(self, *args): b = args[0] return ( self.isDefined() - and b.isDefined() - and self.xmin <= b.xmin - and self.xmax >= b.xmax - and self.ymin <= b.ymin - and self.ymax >= b.ymax + & b.isDefined() + & (self.xmin <= b.xmin) + & (self.xmax >= b.xmax) + & (self.ymin <= b.ymin) + & (self.ymax >= b.ymax) ) elif isinstance(args[0], Position): p = args[0] return ( self.isDefined() - and self.xmin <= p.x <= self.xmax - and self.ymin <= p.y <= self.ymax + & (self.xmin <= p.x) + & (p.y <= self.ymax) + & (self.xmin <= p.x) + & (p.y <= self.ymax) ) else: raise TypeError("Invalid argument %s" % args[0]) @@ -129,8 +131,10 @@ def includes(self, *args): x, y = args return ( self.isDefined() - and self.xmin <= float(x) <= self.xmax - and self.ymin <= float(y) <= self.ymax + & (self.xmin <= x) + & (y <= self.ymax) + & (self.xmin <= x) + & (y <= self.ymax) ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") diff --git a/tests/jax/test_vmapping.py b/tests/jax/test_vmapping.py index 6b8d7c40..ea092f0d 100644 --- a/tests/jax/test_vmapping.py +++ b/tests/jax/test_vmapping.py @@ -1,5 +1,6 @@ import jax import jax.numpy as jnp +import numpy as np import jax_galsim as galsim @@ -226,3 +227,25 @@ def drawGalaxy(flux): assert arr.shape[0] == 2 assert arr.shape[1] == arr.shape[2] == 128 assert arr[0].sum() < arr[1].sum() + + +def test_bounds_includes_vmapping(): + # See https://github.com/GalSim-developers/JAX-GalSim/issues/190#issuecomment-4031602051 + # for the source of the test code + b0 = galsim.BoundsI(1, 128, 1, 128) + b1 = galsim.BoundsI(32, 98, 32, 98) + b2 = galsim.BoundsI(-1, 10, 5, 200) + + # bounds array + bnd_list = [b1, b2] + bnd_array = jax.tree.map(lambda *vals: jnp.array(vals), *bnd_list) + res = jax.vmap(lambda x: b0.includes(x))(bnd_array) + res_jit = jax.jit(jax.vmap(lambda x: b0.includes(x)))(bnd_array) + np.testing.assert_array_equal(res, res_jit) + + # position array + pos_list = [galsim.PositionD(4, 10), galsim.PositionD(-4, -20)] + pos_array = jax.tree.map(lambda *vals: jnp.array(vals), *pos_list) + res = jax.vmap(lambda x: b0.includes(x))(pos_array) + res_jit = jax.jit(jax.vmap(lambda x: b0.includes(x)))(pos_array) + np.testing.assert_array_equal(res, res_jit) From f6d113b407de0d10ab760863fda47286892bfd25 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Mar 2026 16:20:36 -0500 Subject: [PATCH 02/16] test: add test for position arrays --- tests/jax/test_vmapping.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/jax/test_vmapping.py b/tests/jax/test_vmapping.py index ea092f0d..f71022b5 100644 --- a/tests/jax/test_vmapping.py +++ b/tests/jax/test_vmapping.py @@ -242,10 +242,19 @@ def test_bounds_includes_vmapping(): res = jax.vmap(lambda x: b0.includes(x))(bnd_array) res_jit = jax.jit(jax.vmap(lambda x: b0.includes(x)))(bnd_array) np.testing.assert_array_equal(res, res_jit) + np.testing.assert_array_equal(res, np.array([True, False])) - # position array + # position objects pos_list = [galsim.PositionD(4, 10), galsim.PositionD(-4, -20)] pos_array = jax.tree.map(lambda *vals: jnp.array(vals), *pos_list) res = jax.vmap(lambda x: b0.includes(x))(pos_array) res_jit = jax.jit(jax.vmap(lambda x: b0.includes(x)))(pos_array) np.testing.assert_array_equal(res, res_jit) + np.testing.assert_array_equal(res, np.array([True, False])) + + # position arrays - shape is (n_points, 2) + pos_array = jnp.array([[4.0, -4.0, 7.0], [10.0, -20.0, 8.0]]).T + res = jax.vmap(lambda x: b0.includes(*x))(pos_array) + res_jit = jax.jit(jax.vmap(lambda x: b0.includes(*x)))(pos_array) + np.testing.assert_array_equal(res, res_jit) + np.testing.assert_array_equal(res, np.array([True, False, True])) From 03628c16f490866634c289736ab088eec48f14dd Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Mar 2026 16:33:02 -0500 Subject: [PATCH 03/16] fix: wrong comparisons --- jax_galsim/bounds.py | 8 ++++---- tests/jax/test_render_scene.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) create mode 100644 tests/jax/test_render_scene.py diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index a92e4716..a209e6c1 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -121,8 +121,8 @@ def includes(self, *args): return ( self.isDefined() & (self.xmin <= p.x) - & (p.y <= self.ymax) - & (self.xmin <= p.x) + & (p.x <= self.xmax) + & (self.ymin <= p.y) & (p.y <= self.ymax) ) else: @@ -132,8 +132,8 @@ def includes(self, *args): return ( self.isDefined() & (self.xmin <= x) - & (y <= self.ymax) - & (self.xmin <= x) + & (x <= self.xmax) + & (self.ymin <= y) & (y <= self.ymax) ) elif len(args) == 0: diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py new file mode 100644 index 00000000..776812e0 --- /dev/null +++ b/tests/jax/test_render_scene.py @@ -0,0 +1,10 @@ +from functools import partial + +import jax + +import jax_galsim as jgs + + +@partial(jax.jit, static_argnames=("n_obj")) +def _generate_sim(rng_key, n_obj): + pass From 7607130058056f945e1aa83b1cd5fc712710ad12 Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Mar 2026 16:54:34 -0500 Subject: [PATCH 04/16] test: add minimal test of vectorized drawing --- jax_galsim/exponential.py | 4 +-- tests/jax/test_render_scene.py | 47 ++++++++++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/jax_galsim/exponential.py b/jax_galsim/exponential.py index cbb716ad..1d472544 100644 --- a/jax_galsim/exponential.py +++ b/jax_galsim/exponential.py @@ -92,8 +92,8 @@ def __repr__(self): ) def __str__(self): - s = "galsim.Exponential(scale_radius=%s" % ensure_hashable(self.scale_radius) - s += ", flux=%s" % ensure_hashable(self.flux) + s = "galsim.Exponential(scale_radius=%s" % (ensure_hashable(self.scale_radius),) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py index 776812e0..94ed6473 100644 --- a/tests/jax/test_render_scene.py +++ b/tests/jax/test_render_scene.py @@ -1,10 +1,53 @@ from functools import partial import jax +import jax.random as jrng import jax_galsim as jgs +def _generate_stamp(rng_key, psf): + rng_key, use_key = jrng.split(rng_key) + flux = jrng.uniform(use_key, minval=1.5, maxval=2.5) + rng_key, use_key = jrng.split(rng_key) + hlr = jrng.uniform(use_key, minval=0.5, maxval=2.5) + rng_key, use_key = jrng.split(rng_key) + g1 = jrng.uniform(use_key, minval=-0.1, maxval=0.1) + rng_key, use_key = jrng.split(rng_key) + g2 = jrng.uniform(use_key, minval=-0.1, maxval=0.1) + + rng_key, use_key = jrng.split(rng_key) + dx = jrng.uniform(use_key, minval=-10, maxval=10) + rng_key, use_key = jrng.split(rng_key) + dy = jrng.uniform(use_key, minval=-10, maxval=10) + + return ( + jgs.Convolve( + [ + jgs.Exponential(half_light_radius=hlr) + .shear(g1=g1, g2=g2) + .shift(dx, dy) + .withFlux(flux), + psf, + ] + ) + .withGSParams(minimum_fft_size=1024, maximum_fft_size=1024) + .drawImage(nx=200, ny=200, scale=0.2) + ) + + @partial(jax.jit, static_argnames=("n_obj")) -def _generate_sim(rng_key, n_obj): - pass +def _generate_stamps(rng_key, psf, n_obj): + use_keys = jrng.split(rng_key, num=n_obj + 1) + rng_key = use_keys[0] + use_keys = use_keys[1:] + + return jax.vmap(_generate_stamp, in_axes=(0, None))(use_keys, psf) + + +def test_render_scene_smoke(): + psf = jgs.Gaussian(fwhm=0.9) + img = _generate_stamps(jrng.key(42), psf, 5) + + assert img.array.shape == (5, 200, 200) + assert img.array.sum() > 5.0 From a4012865e2e09835d7f3e5045ce2dddcfe9e592c Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 11 Mar 2026 16:55:37 -0500 Subject: [PATCH 05/16] test: rename to be more accurate --- tests/jax/test_render_scene.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py index 94ed6473..0c32d5a7 100644 --- a/tests/jax/test_render_scene.py +++ b/tests/jax/test_render_scene.py @@ -45,7 +45,7 @@ def _generate_stamps(rng_key, psf, n_obj): return jax.vmap(_generate_stamp, in_axes=(0, None))(use_keys, psf) -def test_render_scene_smoke(): +def test_render_scene_draw_many_ffts_full_img(): psf = jgs.Gaussian(fwhm=0.9) img = _generate_stamps(jrng.key(42), psf, 5) From cafea4930f22d2e7d7bbf7acad39b96890967d11 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 13 Mar 2026 12:28:19 -0500 Subject: [PATCH 06/16] fix: lots of changes for scenes --- jax_galsim/angle.py | 4 +- jax_galsim/bounds.py | 14 +-- jax_galsim/box.py | 6 +- jax_galsim/convolve.py | 7 ++ jax_galsim/core/utils.py | 38 ++++--- jax_galsim/gaussian.py | 4 +- jax_galsim/image.py | 14 +-- jax_galsim/moffat.py | 4 +- jax_galsim/spergel.py | 2 +- jax_galsim/sum.py | 7 ++ jax_galsim/transform.py | 12 +- jax_galsim/wcs.py | 2 +- tests/jax/test_render_scene.py | 194 ++++++++++++++++++++++++++++++++- 13 files changed, 260 insertions(+), 48 deletions(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index ed513315..fad56976 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -89,7 +89,7 @@ def __repr__(self): elif self == arcsec: return "galsim.arcsec" else: - return "galsim.AngleUnit(%r)" % ensure_hashable(self.value) + return "galsim.AngleUnit(%r)" % (ensure_hashable(self.value),) def __eq__(self, other): return isinstance(other, AngleUnit) and jnp.array_equal(self.value, other.value) @@ -222,7 +222,7 @@ def __str__(self): return str(ensure_hashable(self._rad)) + " radians" def __repr__(self): - return "galsim.Angle(%r, galsim.radians)" % ensure_hashable(self.rad) + return "galsim.Angle(%r, galsim.radians)" % (ensure_hashable(self.rad),) def __eq__(self, other): return isinstance(other, Angle) and jnp.array_equal(self.rad, other.rad) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index a209e6c1..39c74535 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -162,10 +162,11 @@ def __and__(self, other): xmax = jnp.minimum(self.xmax, other.xmax) ymin = jnp.maximum(self.ymin, other.ymin) ymax = jnp.minimum(self.ymax, other.ymax) - if xmin > xmax or ymin > ymax: - return self.__class__() - else: - return self.__class__(xmin, xmax, ymin, ymax) + return jax.lax.cond( + (xmin > xmax) | (ymin > ymax), + lambda : self.__class__(), + lambda : self.__class__(xmin, xmax, ymin, ymax), + ) def __add__(self, other): if isinstance(other, self.__class__): @@ -233,10 +234,7 @@ def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - if self.isDefined(): - children = (self.xmin, self.xmax, self.ymin, self.ymax) - else: - children = tuple() + children = (self.xmin, self.xmax, self.ymin, self.ymax) # Define auxiliary static data that doesn’t need to be traced aux_data = None return (children, aux_data) diff --git a/jax_galsim/box.py b/jax_galsim/box.py index 95b3d373..6a6ccde9 100644 --- a/jax_galsim/box.py +++ b/jax_galsim/box.py @@ -62,7 +62,7 @@ def __str__(self): ensure_hashable(self.height), ) if self.flux != 1.0: - s += ", flux=%s" % ensure_hashable(self.flux) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s @@ -146,9 +146,9 @@ def __repr__(self): ) def __str__(self): - s = "galsim.Pixel(scale=%s" % ensure_hashable(self.scale) + s = "galsim.Pixel(scale=%s" % (ensure_hashable(self.scale),) if self.flux != 1.0: - s += ", flux=%s" % ensure_hashable(self.flux) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 6961ad39..8f3de7c4 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -337,6 +337,13 @@ def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" return cls(children[0]["obj_list"], **aux_data) + def to_galsim(self): + return _galsim.Convolution( + [arg.to_galsim() for arg in self.obj_list], + gsparams=self._gsparams.to_galsim(), + propagate_gsparams=self._propagate_gsparams, + ) + @implements( _galsim.convolve.Deconvolve, diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index c4cc5413..c2f943e1 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -115,24 +115,28 @@ def cast_to_int(x): return int(x) except Exception: try: - if not jnp.any(jnp.isnan(x)): - return jnp.asarray(x, dtype=int) - else: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - if type(x) is object: - return x - else: - return 1 * x + return jnp.asarray(x, dtype=int) except Exception: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - if type(x) is object: - return x - else: - return 1 * x + return x + # try: + # if not jnp.any(jnp.isnan(x)): + # return jnp.asarray(x, dtype=int) + # else: + # # this will return the same value for anything int-like that + # # cannot be cast to int + # # however, it will raise an error if something is not int-like + # if type(x) is object: + # return x + # else: + # return 1 * x + # except Exception: + # # this will return the same value for anything int-like that + # # cannot be cast to int + # # however, it will raise an error if something is not int-like + # if type(x) is object: + # return x + # else: + # return 1 * x def is_equal_with_arrays(x, y): diff --git a/jax_galsim/gaussian.py b/jax_galsim/gaussian.py index 3b937a11..f9c1616d 100644 --- a/jax_galsim/gaussian.py +++ b/jax_galsim/gaussian.py @@ -107,8 +107,8 @@ def __repr__(self): ) def __str__(self): - s = "galsim.Gaussian(sigma=%s" % ensure_hashable(self.sigma) - s += ", flux=%s" % ensure_hashable(self.flux) + s = "galsim.Gaussian(sigma=%s" % (ensure_hashable(self.sigma),) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/jax_galsim/image.py b/jax_galsim/image.py index f6f0f518..ab1e04b1 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1061,8 +1061,8 @@ def rot_180(self): def tree_flatten(self): """Flatten the image into a list of values.""" # Define the children nodes of the PyTree that need tracing - children = (self.array, self.wcs) - aux_data = {"dtype": self.dtype, "bounds": self.bounds, "isconst": self.isconst} + children = (self.array, self.wcs, self.bounds) + aux_data = {"dtype": self.dtype, "isconst": self.isconst} # other routines may add these attributes to images on the fly # we have to include them here so that JAX knows how to handle them in jitting etc. if hasattr(self, "added_flux"): @@ -1080,15 +1080,15 @@ def tree_unflatten(cls, aux_data, children): obj = object.__new__(cls) obj._array = children[0] obj.wcs = children[1] - obj._bounds = aux_data["bounds"] + obj._bounds = children[2] obj._dtype = aux_data["dtype"] obj._is_const = aux_data["isconst"] - if len(children) > 2: - obj.added_flux = children[2] + if len(children) > 3: + obj.added_flux = children[3] if "header" in aux_data: obj.header = aux_data["header"] - if len(children) > 3: - obj.photons = children[3] + if len(children) > 4: + obj.photons = children[4] return obj @classmethod diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 22d4e16e..2a9b312b 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -232,9 +232,9 @@ def __str__(self): ensure_hashable(self.scale_radius), ) if self.trunc != 0.0: - s += ", trunc=%s" % ensure_hashable(self.trunc) + s += ", trunc=%s" % (ensure_hashable(self.trunc),) if self.flux != 1.0: - s += ", flux=%s" % ensure_hashable(self.flux) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index b4b17904..fbc756db 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -387,7 +387,7 @@ def __str__(self): ensure_hashable(self.half_light_radius), ) if self.flux != 1.0: - s += ", flux=%s" % ensure_hashable(self.flux) + s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 958e6bfa..ec434f6f 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -255,3 +255,10 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" return cls(children[0]["obj_list"], **aux_data) + + def to_galsim(self): + return _galsim.Sum( + [arg.to_galsim() for arg in self.obj_list], + gsparams=self._gsparams.to_galsim(), + propagate_gsparams=self._propagate_gsparams, + ) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index d7d23838..21ad71d7 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -232,7 +232,7 @@ def __str__(self): ensure_hashable(self._offset.y), ) if self._flux_ratio != 1.0: - s += " * %s" % ensure_hashable(self._flux_ratio) + s += " * %s" % (ensure_hashable(self._flux_ratio),) return s @property @@ -411,6 +411,16 @@ def tree_unflatten(cls, aux_data, children): obj._original, obj._params = children return obj + def to_galsim(self): + return _galsim.Transformation( + self.original.to_galsim(), + jac=self.jac, + offset=self.offset.to_galsim(), + flux_ratio=self.flux_ratio, + gsparams=self.gsparams.to_galsim(), + propagate_gsparams=self.propagate_gsparams, + ) + def _Transform( obj, diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index fcddf18e..ee18ef33 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -919,7 +919,7 @@ def __eq__(self, other): ) def __repr__(self): - return "galsim.PixelScale(%r)" % ensure_hashable(self.scale) + return "galsim.PixelScale(%r)" % (ensure_hashable(self.scale),) def __hash__(self): return hash(repr(self)) diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py index 0c32d5a7..09062795 100644 --- a/tests/jax/test_render_scene.py +++ b/tests/jax/test_render_scene.py @@ -1,12 +1,15 @@ from functools import partial import jax +import jax.numpy as jnp import jax.random as jrng +import numpy as np +import galsim import jax_galsim as jgs -def _generate_stamp(rng_key, psf): +def _generate_image_one(rng_key, psf): rng_key, use_key = jrng.split(rng_key) flux = jrng.uniform(use_key, minval=1.5, maxval=2.5) rng_key, use_key = jrng.split(rng_key) @@ -37,17 +40,200 @@ def _generate_stamp(rng_key, psf): @partial(jax.jit, static_argnames=("n_obj")) -def _generate_stamps(rng_key, psf, n_obj): +def _generate_image(rng_key, psf, n_obj): use_keys = jrng.split(rng_key, num=n_obj + 1) rng_key = use_keys[0] use_keys = use_keys[1:] - return jax.vmap(_generate_stamp, in_axes=(0, None))(use_keys, psf) + return jax.vmap(_generate_image_one, in_axes=(0, None))(use_keys, psf) def test_render_scene_draw_many_ffts_full_img(): psf = jgs.Gaussian(fwhm=0.9) - img = _generate_stamps(jrng.key(42), psf, 5) + img = _generate_image(jrng.key(10), psf, 50) + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(img.array.sum(axis=0)) + pdb.set_trace() assert img.array.shape == (5, 200, 200) assert img.array.sum() > 5.0 + + +def _get_bd_jgs( + flux_d, + flux_b, + hlr_b, + hlr_d, + q_b, + q_d, + beta, + *, + psf_hlr=0.7, +): + components = [] + + # disk + disk = jgs.Exponential(flux=flux_d, half_light_radius=hlr_d).shear( + q=q_d, beta=beta * jgs.degrees + ) + components.append(disk) + + # bulge + bulge = jgs.Spergel(nu=-0.6, flux=flux_b, half_light_radius=hlr_b).shear( + q=q_b, beta=beta * jgs.degrees + ) + components.append(bulge) + + galaxy = jgs.Add(components) + + # psf + psf = jgs.Moffat(2, flux=1.0, half_light_radius=0.7) + + gal_conv = jgs.Convolve([galaxy, psf]) + return gal_conv + + +@partial(jax.jit, static_argnames=("fft_size", "slen")) +def _draw_stamp_jgs( + galaxy_params: dict, + image_pos: jgs.PositionD, + local_wcs: jgs.PixelScale, + fft_size: int, + slen: int, +) -> jax.Array: + gsparams = jgs.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + + convolved_object = _get_bd_jgs(**galaxy_params).withGSParams(gsparams) + + # you have to render just with on offset in order to keep the bounds + # static during rendering + dx = image_pos.x - jnp.floor(image_pos.x + 0.5) + dy = image_pos.y - jnp.floor(image_pos.y + 0.5) + stamp = convolved_object.drawImage( + nx=slen, ny=slen, offset=(dx, dy), wcs=local_wcs, dtype=jnp.float64 + ) + # then we apply a shift to get the correct final bounds + shift = jgs.PositionI( + jnp.int32(jnp.floor(image_pos.x + 0.5 - stamp.bounds.true_center.x)), + jnp.int32(jnp.floor(image_pos.y + 0.5 - stamp.bounds.true_center.y)), + ) + stamp.shift(shift) + + return stamp + + +@partial(jax.jit, static_argnames=("slen",)) +def _add_to_image(carry, x, slen): + image = carry[0] + stamp = x + + b = stamp.bounds & image.bounds + if b.isDefined(): + i1 = b.ymin - image.ymin + j1 = b.xmin - image.xmin + start_inds = (i1, j1) + subim = jax.lax.dynamic_slice( + image.array, start_inds, (slen, slen) + ) + subim = subim + stamp.array + + image._array = jax.lax.dynamic_update_slice( + image.array, subim, start_inds, + ) + + return (image,), None + + +def _render_scene_stamps_galsim( + galaxy_params: dict, + image_pos: list[galsim.PositionD], + local_wcs: list[galsim.PixelScale], + fft_size: int, + slen: int, + image: galsim.ImageD, + ng: int, +): + gsparams = jgs.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + + for i in range(ng): + gpars = { + k: v[i] for k, v in galaxy_params.items() + } + convolved_object = _get_bd_jgs(**gpars).withGSParams(gsparams).to_galsim() + + stamp = convolved_object.drawImage( + nx=slen, ny=slen, center=image_pos[i], wcs=local_wcs[i], dtype=np.float64 + ) + + b = stamp.bounds & image.bounds + if b.isDefined(): + image[b] += stamp[b] + + return image + + +def test_render_scene_stamps(): + image = jgs.Image(ncol=200, nrow=200, scale=0.2, dtype=jnp.float64) + wcs = image.wcs + + rng = np.random.default_rng(seed=10) + ng = 5 + + galaxy_params = { + "flux_d": rng.uniform(low=0, high=1.0, size=ng), + "flux_b": rng.uniform(low=0, high=1.0, size=ng), + "hlr_b": rng.uniform(low=0.3, high=0.5, size=ng), + "hlr_d": rng.uniform(low=0.5, high=0.7, size=ng), + "q_b": rng.uniform(low=0.1, high=0.9, size=ng), + "q_d": rng.uniform(low=0.1, high=0.9, size=ng), + "beta": rng.uniform(low=0, high=360, size=ng), + "x": rng.uniform(low=10, high=190, size=ng), + "y": rng.uniform(low=10, high=190, size=ng), + } + + x = galaxy_params.pop("x") + y = galaxy_params.pop("y") + image_positions = jax.vmap(lambda x, y: jgs.PositionD(x=x, y=y))(x, y) + local_wcss = jax.vmap(lambda x: wcs.local(image_pos=x))(image_positions) + + stamps = jax.jit(jax.vmap(partial(_draw_stamp_jgs, slen=52, fft_size=256)))( + galaxy_params, image_positions, local_wcss + ) + assert stamps.array.shape == (ng, 52, 52) + assert stamps.array.sum() > 0 + + pad_image = jgs.ImageD( + jnp.pad(image.array, 52), wcs=image.wcs, bounds=image.bounds.withBorder(52) + ) + + final_pad_image = jax.lax.scan(partial(_add_to_image, slen=52), (pad_image,), xs=stamps, length=ng)[0][0] + np.testing.assert_allclose(final_pad_image.array.sum(), stamps.array.sum()) + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(final_pad_image.array) + pdb.set_trace() + + gs_image = jgs.Image(ncol=200, nrow=200, scale=0.2, dtype=np.float64) + wcs = image.wcs + + gs_image_positions = map(lambda tup: galsim.PositionD(x=tup[0], y=tup[1]), zip(x, y)) + gs_local_wcss = map(lambda x: wcs.local(image_pos=x), gs_image_positions) + + _render_scene_stamps_galsim( + galaxy_params, + gs_image_positions, + gs_local_wcss, + 256, + 52, + gs_image, + ng, + ) From ea8cf0fc964c9062da554a6baa8dcdcb67739100 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 13 Mar 2026 13:02:38 -0500 Subject: [PATCH 07/16] fix: tests pass --- jax_galsim/bounds.py | 6 +- jax_galsim/convolve.py | 7 -- jax_galsim/sum.py | 7 -- jax_galsim/transform.py | 10 --- tests/jax/test_render_scene.py | 126 ++++++++++++++++++++++++++------- 5 files changed, 103 insertions(+), 53 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 39c74535..d49baa6f 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -163,9 +163,9 @@ def __and__(self, other): ymin = jnp.maximum(self.ymin, other.ymin) ymax = jnp.minimum(self.ymax, other.ymax) return jax.lax.cond( - (xmin > xmax) | (ymin > ymax), - lambda : self.__class__(), - lambda : self.__class__(xmin, xmax, ymin, ymax), + (xmin > xmax) | (ymin > ymax), + lambda: self.__class__(), + lambda: self.__class__(xmin, xmax, ymin, ymax), ) def __add__(self, other): diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 8f3de7c4..6961ad39 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -337,13 +337,6 @@ def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" return cls(children[0]["obj_list"], **aux_data) - def to_galsim(self): - return _galsim.Convolution( - [arg.to_galsim() for arg in self.obj_list], - gsparams=self._gsparams.to_galsim(), - propagate_gsparams=self._propagate_gsparams, - ) - @implements( _galsim.convolve.Deconvolve, diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index ec434f6f..958e6bfa 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -255,10 +255,3 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" return cls(children[0]["obj_list"], **aux_data) - - def to_galsim(self): - return _galsim.Sum( - [arg.to_galsim() for arg in self.obj_list], - gsparams=self._gsparams.to_galsim(), - propagate_gsparams=self._propagate_gsparams, - ) diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 21ad71d7..bf9f4f6d 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -411,16 +411,6 @@ def tree_unflatten(cls, aux_data, children): obj._original, obj._params = children return obj - def to_galsim(self): - return _galsim.Transformation( - self.original.to_galsim(), - jac=self.jac, - offset=self.offset.to_galsim(), - flux_ratio=self.flux_ratio, - gsparams=self.gsparams.to_galsim(), - propagate_gsparams=self.propagate_gsparams, - ) - def _Transform( obj, diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py index 09062795..7ae77219 100644 --- a/tests/jax/test_render_scene.py +++ b/tests/jax/test_render_scene.py @@ -1,11 +1,11 @@ from functools import partial +import galsim as _galsim import jax import jax.numpy as jnp import jax.random as jrng import numpy as np -import galsim import jax_galsim as jgs @@ -112,8 +112,10 @@ def _draw_stamp_jgs( # you have to render just with on offset in order to keep the bounds # static during rendering - dx = image_pos.x - jnp.floor(image_pos.x + 0.5) - dy = image_pos.y - jnp.floor(image_pos.y + 0.5) + dx = image_pos.x - jnp.ceil(image_pos.x) + dy = image_pos.y - jnp.ceil(image_pos.y) + dx = dx + 0.5 * ((slen + 1) % 2) + dy = dy + 0.5 * ((slen + 1) % 2) stamp = convolved_object.drawImage( nx=slen, ny=slen, offset=(dx, dy), wcs=local_wcs, dtype=jnp.float64 ) @@ -137,37 +139,73 @@ def _add_to_image(carry, x, slen): i1 = b.ymin - image.ymin j1 = b.xmin - image.xmin start_inds = (i1, j1) - subim = jax.lax.dynamic_slice( - image.array, start_inds, (slen, slen) - ) + subim = jax.lax.dynamic_slice(image.array, start_inds, (slen, slen)) subim = subim + stamp.array image._array = jax.lax.dynamic_update_slice( - image.array, subim, start_inds, + image.array, + subim, + start_inds, ) return (image,), None +def _get_bd_gs( + flux_d, + flux_b, + hlr_b, + hlr_d, + q_b, + q_d, + beta, + *, + psf_hlr=0.7, +): + components = [] + + # disk + disk = _galsim.Exponential(flux=flux_d, half_light_radius=hlr_d).shear( + q=q_d, beta=beta * _galsim.degrees + ) + components.append(disk) + + # bulge + bulge = _galsim.Spergel(nu=-0.6, flux=flux_b, half_light_radius=hlr_b).shear( + q=q_b, beta=beta * _galsim.degrees + ) + components.append(bulge) + + galaxy = _galsim.Add(components) + + # psf + psf = _galsim.Moffat(2, flux=1.0, half_light_radius=0.7) + + gal_conv = _galsim.Convolve([galaxy, psf]) + return gal_conv + + def _render_scene_stamps_galsim( galaxy_params: dict, - image_pos: list[galsim.PositionD], - local_wcs: list[galsim.PixelScale], + image_pos: list[_galsim.PositionD], + local_wcs: list[_galsim.PixelScale], fft_size: int, slen: int, - image: galsim.ImageD, + image: _galsim.ImageD, ng: int, ): - gsparams = jgs.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + gsparams = _galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) for i in range(ng): - gpars = { - k: v[i] for k, v in galaxy_params.items() - } - convolved_object = _get_bd_jgs(**gpars).withGSParams(gsparams).to_galsim() + gpars = {k: v[i] for k, v in galaxy_params.items()} + convolved_object = _get_bd_gs(**gpars).withGSParams(gsparams) stamp = convolved_object.drawImage( - nx=slen, ny=slen, center=image_pos[i], wcs=local_wcs[i], dtype=np.float64 + nx=slen, + ny=slen, + center=(image_pos[i].x, image_pos[i].y), + wcs=local_wcs[i], + dtype=np.float64, ) b = stamp.bounds & image.bounds @@ -183,6 +221,8 @@ def test_render_scene_stamps(): rng = np.random.default_rng(seed=10) ng = 5 + slen = 52 + fft_size = 2048 galaxy_params = { "flux_d": rng.uniform(low=0, high=1.0, size=ng), @@ -201,17 +241,19 @@ def test_render_scene_stamps(): image_positions = jax.vmap(lambda x, y: jgs.PositionD(x=x, y=y))(x, y) local_wcss = jax.vmap(lambda x: wcs.local(image_pos=x))(image_positions) - stamps = jax.jit(jax.vmap(partial(_draw_stamp_jgs, slen=52, fft_size=256)))( + stamps = jax.jit(jax.vmap(partial(_draw_stamp_jgs, slen=slen, fft_size=fft_size)))( galaxy_params, image_positions, local_wcss ) - assert stamps.array.shape == (ng, 52, 52) + assert stamps.array.shape == (ng, slen, slen) assert stamps.array.sum() > 0 pad_image = jgs.ImageD( - jnp.pad(image.array, 52), wcs=image.wcs, bounds=image.bounds.withBorder(52) + jnp.pad(image.array, slen), wcs=image.wcs, bounds=image.bounds.withBorder(slen) ) - final_pad_image = jax.lax.scan(partial(_add_to_image, slen=52), (pad_image,), xs=stamps, length=ng)[0][0] + final_pad_image = jax.lax.scan( + partial(_add_to_image, slen=slen), (pad_image,), xs=stamps, length=ng + )[0][0] np.testing.assert_allclose(final_pad_image.array.sum(), stamps.array.sum()) if False: @@ -222,18 +264,50 @@ def test_render_scene_stamps(): plt.imshow(final_pad_image.array) pdb.set_trace() - gs_image = jgs.Image(ncol=200, nrow=200, scale=0.2, dtype=np.float64) - wcs = image.wcs + gs_image = _galsim.Image(ncol=200, nrow=200, scale=0.2, dtype=np.float64) + wcs = gs_image.wcs - gs_image_positions = map(lambda tup: galsim.PositionD(x=tup[0], y=tup[1]), zip(x, y)) - gs_local_wcss = map(lambda x: wcs.local(image_pos=x), gs_image_positions) + gs_image_positions = list( + map(lambda tup: _galsim.PositionD(x=tup[0], y=tup[1]), zip(x, y)) + ) + gs_local_wcss = list(map(lambda x: wcs.local(image_pos=x), gs_image_positions)) _render_scene_stamps_galsim( galaxy_params, gs_image_positions, gs_local_wcss, - 256, - 52, + fft_size, + slen, gs_image, ng, ) + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(gs_image.array) + pdb.set_trace() + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(final_pad_image.array[slen:-slen, slen:-slen] - gs_image.array) + pdb.set_trace() + + np.testing.assert_allclose( + gs_image.array.sum(), + final_pad_image.array[slen:-slen, slen:-slen].sum(), + atol=1e-4, + rtol=1e-5, + ) + + np.testing.assert_allclose( + gs_image.array, + final_pad_image.array[slen:-slen, slen:-slen], + atol=1e-6, + rtol=1e-6, + ) From 21618abac2017d699ee4f8bad76faa68e8bb56be Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 13 Mar 2026 13:03:18 -0500 Subject: [PATCH 08/16] doc: comment on magoic --- tests/jax/test_render_scene.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py index 7ae77219..b1ebcf1e 100644 --- a/tests/jax/test_render_scene.py +++ b/tests/jax/test_render_scene.py @@ -112,6 +112,8 @@ def _draw_stamp_jgs( # you have to render just with on offset in order to keep the bounds # static during rendering + # the exact pixel computation here is MAGIC right now + # we'll need a way to make this easier dx = image_pos.x - jnp.ceil(image_pos.x) dy = image_pos.y - jnp.ceil(image_pos.y) dx = dx + 0.5 * ((slen + 1) % 2) From 79675dc22308440f9f07f89ad530273265a74ef4 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 14 Mar 2026 20:57:48 -0500 Subject: [PATCH 09/16] fix: make it pass with static bounds --- jax_galsim/bounds.py | 129 +++++++++++++++++++-------------- tests/jax/test_api.py | 40 +--------- tests/jax/test_render_scene.py | 49 +++++++------ tests/jax/test_vmapping.py | 51 ------------- 4 files changed, 100 insertions(+), 169 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index d49baa6f..3bd18861 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,11 +1,8 @@ import galsim as _galsim -import jax -import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( - cast_to_float, - cast_to_int, ensure_hashable, implements, ) @@ -84,13 +81,31 @@ def _parse_args(self, *args, **kwargs): if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) - # for simple inputs, we can check if the bounds are valid - if ( - isinstance(self.xmin, (float, int)) - and isinstance(self.xmax, (float, int)) - and isinstance(self.ymin, (float, int)) - and isinstance(self.ymax, (float, int)) - and ((self.xmin > self.xmax) or (self.ymin > self.ymax)) + if not ( + isinstance( + self.xmin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.xmax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + ): + raise ValueError( + "BoundsI/D classes must use python ints/floats or numpy values in JAX-GalSim!" + ) + + if not ( + float(self.xmin) <= float(self.xmax) + and float(self.ymin) <= float(self.ymax) ): self._isdefined = False @@ -110,20 +125,20 @@ def includes(self, *args): b = args[0] return ( self.isDefined() - & b.isDefined() - & (self.xmin <= b.xmin) - & (self.xmax >= b.xmax) - & (self.ymin <= b.ymin) - & (self.ymax >= b.ymax) + and b.isDefined() + and (self.xmin <= b.xmin) + and (self.xmax >= b.xmax) + and (self.ymin <= b.ymin) + and (self.ymax >= b.ymax) ) elif isinstance(args[0], Position): p = args[0] return ( self.isDefined() - & (self.xmin <= p.x) - & (p.x <= self.xmax) - & (self.ymin <= p.y) - & (p.y <= self.ymax) + and (self.xmin <= p.x) + and (p.x <= self.xmax) + and (self.ymin <= p.y) + and (p.y <= self.ymax) ) else: raise TypeError("Invalid argument %s" % args[0]) @@ -131,10 +146,10 @@ def includes(self, *args): x, y = args return ( self.isDefined() - & (self.xmin <= x) - & (x <= self.xmax) - & (self.ymin <= y) - & (y <= self.ymax) + and (self.xmin <= x) + and (x <= self.xmax) + and (self.ymin <= y) + and (y <= self.ymax) ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") @@ -148,8 +163,8 @@ def expand(self, factor_x, factor_y=None): dx = (self.xmax - self.xmin) * 0.5 * (factor_x - 1.0) dy = (self.ymax - self.ymin) * 0.5 * (factor_y - 1.0) if isinstance(self, BoundsI): - dx = jnp.ceil(dx) - dy = jnp.ceil(dy) + dx = np.ceil(dx) + dy = np.ceil(dy) return self.withBorder(dx, dy) def __and__(self, other): @@ -158,34 +173,33 @@ def __and__(self, other): if not self.isDefined() or not other.isDefined(): return self.__class__() else: - xmin = jnp.maximum(self.xmin, other.xmin) - xmax = jnp.minimum(self.xmax, other.xmax) - ymin = jnp.maximum(self.ymin, other.ymin) - ymax = jnp.minimum(self.ymax, other.ymax) - return jax.lax.cond( - (xmin > xmax) | (ymin > ymax), - lambda: self.__class__(), - lambda: self.__class__(xmin, xmax, ymin, ymax), - ) + xmin = np.maximum(self.xmin, other.xmin) + xmax = np.minimum(self.xmax, other.xmax) + ymin = np.maximum(self.ymin, other.ymin) + ymax = np.minimum(self.ymax, other.ymax) + if (xmin > xmax) or (ymin > ymax): + return self.__class__() + else: + return self.__class__(xmin, xmax, ymin, ymax) def __add__(self, other): if isinstance(other, self.__class__): if not other.isDefined(): return self elif self.isDefined(): - xmin = jnp.minimum(self.xmin, other.xmin) - xmax = jnp.maximum(self.xmax, other.xmax) - ymin = jnp.minimum(self.ymin, other.ymin) - ymax = jnp.maximum(self.ymax, other.ymax) + xmin = np.minimum(self.xmin, other.xmin) + xmax = np.maximum(self.xmax, other.xmax) + ymin = np.minimum(self.ymin, other.ymin) + ymax = np.maximum(self.ymax, other.ymax) return self.__class__(xmin, xmax, ymin, ymax) else: return other elif isinstance(other, self._pos_class): if self.isDefined(): - xmin = jnp.minimum(self.xmin, other.x) - xmax = jnp.maximum(self.xmax, other.x) - ymin = jnp.minimum(self.ymin, other.y) - ymax = jnp.maximum(self.ymax, other.y) + xmin = np.minimum(self.xmin, other.x) + xmax = np.maximum(self.xmax, other.x) + ymin = np.minimum(self.ymin, other.y) + ymax = np.maximum(self.ymax, other.y) return self.__class__(xmin, xmax, ymin, ymax) else: return self.__class__(other) @@ -234,15 +248,18 @@ def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - children = (self.xmin, self.xmax, self.ymin, self.ymax) + children = () # Define auxiliary static data that doesn’t need to be traced - aux_data = None + if self.isDefined(): + aux_data = (self.xmin, self.xmax, self.ymin, self.ymax) + else: + aux_data = () return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - return cls(*children) + return cls(*aux_data) @classmethod def from_galsim(cls, galsim_bounds): @@ -293,15 +310,15 @@ class BoundsD(Bounds): def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - self.xmin = cast_to_float(self.xmin) - self.xmax = cast_to_float(self.xmax) - self.ymin = cast_to_float(self.ymin) - self.ymax = cast_to_float(self.ymax) + self.xmin = float(self.xmin) + self.xmax = float(self.xmax) + self.ymin = float(self.ymin) + self.ymax = float(self.ymax) def _check_scalar(self, x, name): try: if ( - isinstance(x, jax.Array) + isinstance(x, np.ndarray) and x.shape == () and x.dtype.name in ["float32", "float64", "float"] ): @@ -342,15 +359,15 @@ def __init__(self, *args, **kwargs): ): raise TypeError("BoundsI must be initialized with integer values") - self.xmin = cast_to_int(self.xmin) - self.xmax = cast_to_int(self.xmax) - self.ymin = cast_to_int(self.ymin) - self.ymax = cast_to_int(self.ymax) + self.xmin = int(self.xmin) + self.xmax = int(self.xmax) + self.ymin = int(self.ymin) + self.ymax = int(self.ymax) def _check_scalar(self, x, name): try: if ( - isinstance(x, jax.Array) + isinstance(x, np.ndarray) and x.shape == () and x.dtype.name in ["int32", "int64", "int"] ): diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index e79f320c..622e3225 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -494,10 +494,8 @@ def _reg_sfun(g1): [ jax_galsim.BoundsD(), jax_galsim.BoundsI(), - jax_galsim.BoundsD( - jnp.array(0.2), jnp.array(4.0), jnp.array(-0.5), jnp.array(4.7) - ), - jax_galsim.BoundsI(jnp.array(-10), jnp.array(5), jnp.array(0), jnp.array(7)), + jax_galsim.BoundsD(0.2, 4.0, -0.5, 4.7), + jax_galsim.BoundsI(-10, 5, 0, 7), ], ) def test_api_bounds(obj): @@ -508,40 +506,6 @@ def test_api_bounds(obj): # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj - if isinstance(obj, jax_galsim.BoundsD): - - def _reg_sfun(g1): - return ( - ( - obj.__class__(g1, g1 + 0.5, 2 * g1, 2 * g1 + 0.5).expand(0.5) - + obj.__class__(-g1, -g1 + 0.5, -2 * g1, -2 * g1 + 0.5) - ) - .expand(4) - .area() - ) - - _sfun = jax.jit(_reg_sfun) - - _sgradfun = jax.jit(jax.grad(_sfun)) - _sfun_vmap = jax.jit(jax.vmap(_sfun)) - _sgradfun_vmap = jax.jit(jax.vmap(_sgradfun)) - - # we can jit the object - np.testing.assert_allclose(_sfun(0.3), _reg_sfun(0.3)) - - # check derivs - eps = 1e-6 - grad = _sgradfun(0.3) - finite_diff = (_reg_sfun(0.3 + eps) - _reg_sfun(0.3 - eps)) / (2 * eps) - np.testing.assert_allclose(grad, finite_diff) - - # check vmap - x = jnp.linspace(-0.9, 0.9, 10) - np.testing.assert_allclose(_sfun_vmap(x), [_reg_sfun(_x) for _x in x]) - - # check vmap grad - np.testing.assert_allclose(_sgradfun_vmap(x), [_sgradfun(_x) for _x in x]) - @pytest.mark.parametrize( "obj", diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py index b1ebcf1e..09dbf61a 100644 --- a/tests/jax/test_render_scene.py +++ b/tests/jax/test_render_scene.py @@ -60,8 +60,8 @@ def test_render_scene_draw_many_ffts_full_img(): plt.imshow(img.array.sum(axis=0)) pdb.set_trace() - assert img.array.shape == (5, 200, 200) - assert img.array.sum() > 5.0 + assert img.array.shape == (50, 200, 200) + assert img.array.sum() > 50.0 def _get_bd_jgs( @@ -121,12 +121,6 @@ def _draw_stamp_jgs( stamp = convolved_object.drawImage( nx=slen, ny=slen, offset=(dx, dy), wcs=local_wcs, dtype=jnp.float64 ) - # then we apply a shift to get the correct final bounds - shift = jgs.PositionI( - jnp.int32(jnp.floor(image_pos.x + 0.5 - stamp.bounds.true_center.x)), - jnp.int32(jnp.floor(image_pos.y + 0.5 - stamp.bounds.true_center.y)), - ) - stamp.shift(shift) return stamp @@ -134,21 +128,25 @@ def _draw_stamp_jgs( @partial(jax.jit, static_argnames=("slen",)) def _add_to_image(carry, x, slen): image = carry[0] - stamp = x - - b = stamp.bounds & image.bounds - if b.isDefined(): - i1 = b.ymin - image.ymin - j1 = b.xmin - image.xmin - start_inds = (i1, j1) - subim = jax.lax.dynamic_slice(image.array, start_inds, (slen, slen)) - subim = subim + stamp.array - - image._array = jax.lax.dynamic_update_slice( - image.array, - subim, - start_inds, - ) + stamp, image_pos = x + + # then we apply a shift to get the correct final bounds + shift = jgs.PositionI( + jnp.int32(jnp.floor(image_pos.x + 0.5 - stamp.bounds.true_center.x)), + jnp.int32(jnp.floor(image_pos.y + 0.5 - stamp.bounds.true_center.y)), + ) + + i1 = stamp.bounds.ymin + shift.y - image.ymin + j1 = stamp.bounds.xmin + shift.x - image.xmin + start_inds = (i1, j1) + subim = jax.lax.dynamic_slice(image.array, start_inds, (slen, slen)) + subim = subim + stamp.array + + image._array = jax.lax.dynamic_update_slice( + image.array, + subim, + start_inds, + ) return (image,), None @@ -254,7 +252,10 @@ def test_render_scene_stamps(): ) final_pad_image = jax.lax.scan( - partial(_add_to_image, slen=slen), (pad_image,), xs=stamps, length=ng + partial(_add_to_image, slen=slen), + (pad_image,), + xs=(stamps, image_positions), + length=ng, )[0][0] np.testing.assert_allclose(final_pad_image.array.sum(), stamps.array.sum()) diff --git a/tests/jax/test_vmapping.py b/tests/jax/test_vmapping.py index f71022b5..4a3e6b31 100644 --- a/tests/jax/test_vmapping.py +++ b/tests/jax/test_vmapping.py @@ -1,6 +1,5 @@ import jax import jax.numpy as jnp -import numpy as np import jax_galsim as galsim @@ -142,25 +141,6 @@ def test_eq(self, other): assert test_eq(obj_duplicated, obj) -def test_bounds_vmapping(): - obj = galsim.BoundsD(0.0, 1.0, 0.0, 1.0) - obj_d = jax.vmap(galsim.BoundsD)(0.0 * e, 1.0 * e, 0.0 * e, 1.0 * e) - - objI = galsim.BoundsI(0.0, 1.0, 0.0, 1.0) - objI_d = jax.vmap(galsim.BoundsI)(0.0 * e, 1.0 * e, 0.0 * e, 1.0 * e) - - def test_eq(self, other): - return ( - (self.xmin == jnp.array([other.xmin, other.xmin])).all() - and (self.xmax == jnp.array([other.xmax, other.xmax])).all() - and (self.ymin == jnp.array([other.ymin, other.ymin])).all() - and (self.ymax == jnp.array([other.ymax, other.ymax])).all() - ) - - assert test_eq(obj_d, obj) - assert test_eq(objI_d, objI) - - def test_drawing_vmapping_and_jitting_gaussian_psf(): gsparams = galsim.GSParams(minimum_fft_size=512, maximum_fft_size=512) @@ -227,34 +207,3 @@ def drawGalaxy(flux): assert arr.shape[0] == 2 assert arr.shape[1] == arr.shape[2] == 128 assert arr[0].sum() < arr[1].sum() - - -def test_bounds_includes_vmapping(): - # See https://github.com/GalSim-developers/JAX-GalSim/issues/190#issuecomment-4031602051 - # for the source of the test code - b0 = galsim.BoundsI(1, 128, 1, 128) - b1 = galsim.BoundsI(32, 98, 32, 98) - b2 = galsim.BoundsI(-1, 10, 5, 200) - - # bounds array - bnd_list = [b1, b2] - bnd_array = jax.tree.map(lambda *vals: jnp.array(vals), *bnd_list) - res = jax.vmap(lambda x: b0.includes(x))(bnd_array) - res_jit = jax.jit(jax.vmap(lambda x: b0.includes(x)))(bnd_array) - np.testing.assert_array_equal(res, res_jit) - np.testing.assert_array_equal(res, np.array([True, False])) - - # position objects - pos_list = [galsim.PositionD(4, 10), galsim.PositionD(-4, -20)] - pos_array = jax.tree.map(lambda *vals: jnp.array(vals), *pos_list) - res = jax.vmap(lambda x: b0.includes(x))(pos_array) - res_jit = jax.jit(jax.vmap(lambda x: b0.includes(x)))(pos_array) - np.testing.assert_array_equal(res, res_jit) - np.testing.assert_array_equal(res, np.array([True, False])) - - # position arrays - shape is (n_points, 2) - pos_array = jnp.array([[4.0, -4.0, 7.0], [10.0, -20.0, 8.0]]).T - res = jax.vmap(lambda x: b0.includes(*x))(pos_array) - res_jit = jax.jit(jax.vmap(lambda x: b0.includes(*x)))(pos_array) - np.testing.assert_array_equal(res, res_jit) - np.testing.assert_array_equal(res, np.array([True, False, True])) From 7c5d5b8f6382d5a01d6b0b60ce881ba4d8707b64 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 14 Mar 2026 21:00:41 -0500 Subject: [PATCH 10/16] fix: put back int casts --- jax_galsim/core/utils.py | 38 +++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index c2f943e1..c4cc5413 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -115,28 +115,24 @@ def cast_to_int(x): return int(x) except Exception: try: - return jnp.asarray(x, dtype=int) + if not jnp.any(jnp.isnan(x)): + return jnp.asarray(x, dtype=int) + else: + # this will return the same value for anything int-like that + # cannot be cast to int + # however, it will raise an error if something is not int-like + if type(x) is object: + return x + else: + return 1 * x except Exception: - return x - # try: - # if not jnp.any(jnp.isnan(x)): - # return jnp.asarray(x, dtype=int) - # else: - # # this will return the same value for anything int-like that - # # cannot be cast to int - # # however, it will raise an error if something is not int-like - # if type(x) is object: - # return x - # else: - # return 1 * x - # except Exception: - # # this will return the same value for anything int-like that - # # cannot be cast to int - # # however, it will raise an error if something is not int-like - # if type(x) is object: - # return x - # else: - # return 1 * x + # this will return the same value for anything int-like that + # cannot be cast to int + # however, it will raise an error if something is not int-like + if type(x) is object: + return x + else: + return 1 * x def is_equal_with_arrays(x, y): From 290a7d746fd346e564c38fd7505910dac3b7d35e Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 14 Mar 2026 21:02:29 -0500 Subject: [PATCH 11/16] fix: try simpler casting --- jax_galsim/core/utils.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index c4cc5413..d818889b 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -97,7 +97,7 @@ def cast_to_float(x): return float(x) except Exception: try: - return jnp.asarray(x, dtype=float) + return jnp.astype(x, dtype=float) except Exception: # this will return the same value for anything float-like that # cannot be cast to float @@ -115,20 +115,12 @@ def cast_to_int(x): return int(x) except Exception: try: - if not jnp.any(jnp.isnan(x)): - return jnp.asarray(x, dtype=int) - else: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - if type(x) is object: - return x - else: - return 1 * x + return jnp.astype(x, dtype=int) except Exception: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like + # this will return the same value for anything float-like that + # cannot be cast to float + # however, it will raise an error if something is not float-like + # we exclude object types since they are used in JAX tracing if type(x) is object: return x else: From c342d0718da3920a7102b3204947a86c7c50f53a Mon Sep 17 00:00:00 2001 From: beckermr Date: Sat, 14 Mar 2026 21:06:39 -0500 Subject: [PATCH 12/16] fix: more fixes for casts --- jax_galsim/bounds.py | 66 ++++++++++++++++++++++++++-------------- jax_galsim/core/utils.py | 6 ++-- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 3bd18861..54b18665 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -81,28 +81,6 @@ def _parse_args(self, *args, **kwargs): if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) - if not ( - isinstance( - self.xmin, - (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), - ) - and isinstance( - self.xmax, - (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), - ) - and isinstance( - self.ymin, - (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), - ) - and isinstance( - self.ymax, - (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), - ) - ): - raise ValueError( - "BoundsI/D classes must use python ints/floats or numpy values in JAX-GalSim!" - ) - if not ( float(self.xmin) <= float(self.xmax) and float(self.ymin) <= float(self.ymax) @@ -315,6 +293,28 @@ def __init__(self, *args, **kwargs): self.ymin = float(self.ymin) self.ymax = float(self.ymax) + if not ( + isinstance( + self.xmin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.xmax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + ): + raise ValueError( + "BoundsI/D classes must use python ints/floats or numpy values in JAX-GalSim!" + ) + def _check_scalar(self, x, name): try: if ( @@ -364,6 +364,28 @@ def __init__(self, *args, **kwargs): self.ymin = int(self.ymin) self.ymax = int(self.ymax) + if not ( + isinstance( + self.xmin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.xmax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymin, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + and isinstance( + self.ymax, + (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64), + ) + ): + raise ValueError( + "BoundsI/D classes must use python ints/floats or numpy values in JAX-GalSim!" + ) + def _check_scalar(self, x, name): try: if ( diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index d818889b..4c9a2091 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -117,9 +117,9 @@ def cast_to_int(x): try: return jnp.astype(x, dtype=int) except Exception: - # this will return the same value for anything float-like that - # cannot be cast to float - # however, it will raise an error if something is not float-like + # this will return the same value for anything int-like that + # cannot be cast to int + # however, it will raise an error if something is not int-like # we exclude object types since they are used in JAX tracing if type(x) is object: return x From 92faa0b19c2ed50a9d42bb46c9931e60776722fc Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 15 Mar 2026 05:39:33 -0500 Subject: [PATCH 13/16] fix: put back some changes not needed --- jax_galsim/bounds.py | 36 +++++++++++++----------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 54b18665..e4103493 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -104,19 +104,17 @@ def includes(self, *args): return ( self.isDefined() and b.isDefined() - and (self.xmin <= b.xmin) - and (self.xmax >= b.xmax) - and (self.ymin <= b.ymin) - and (self.ymax >= b.ymax) + and self.xmin <= b.xmin + and self.xmax >= b.xmax + and self.ymin <= b.ymin + and self.ymax >= b.ymax ) elif isinstance(args[0], Position): p = args[0] return ( self.isDefined() - and (self.xmin <= p.x) - and (p.x <= self.xmax) - and (self.ymin <= p.y) - and (p.y <= self.ymax) + and self.xmin <= p.x <= self.xmax + and self.ymin <= p.y <= self.ymax ) else: raise TypeError("Invalid argument %s" % args[0]) @@ -124,10 +122,8 @@ def includes(self, *args): x, y = args return ( self.isDefined() - and (self.xmin <= x) - and (x <= self.xmax) - and (self.ymin <= y) - and (y <= self.ymax) + and self.xmin <= float(x) <= self.xmax + and self.ymin <= float(y) <= self.ymax ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") @@ -344,18 +340,12 @@ class BoundsI(Bounds): def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - # for simple inputs, we can check if the bounds are valid ints + if ( - isinstance(self.xmin, (float, int)) - and isinstance(self.xmax, (float, int)) - and isinstance(self.ymin, (float, int)) - and isinstance(self.ymax, (float, int)) - and ( - self.xmin != int(self.xmin) - or self.xmax != int(self.xmax) - or self.ymin != int(self.ymin) - or self.ymax != int(self.ymax) - ) + self.xmin != int(self.xmin) + or self.xmax != int(self.xmax) + or self.ymin != int(self.ymin) + or self.ymax != int(self.ymax) ): raise TypeError("BoundsI must be initialized with integer values") From ab4c05221d18ff08ce7a6ad029d056f21825ef37 Mon Sep 17 00:00:00 2001 From: beckermr Date: Sun, 15 Mar 2026 10:18:12 -0500 Subject: [PATCH 14/16] test: update test suite --- tests/jax/test_api.py | 6 ++++++ tests/jax/test_render_scene.py | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 622e3225..e8d4b0d1 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -494,6 +494,10 @@ def _reg_sfun(g1): [ jax_galsim.BoundsD(), jax_galsim.BoundsI(), + jax_galsim.BoundsD( + jnp.array(0.2), jnp.array(4.0), jnp.array(-0.5), jnp.array(4.7) + ), + jax_galsim.BoundsI(jnp.array(-10), jnp.array(5), jnp.array(0), jnp.array(7)), jax_galsim.BoundsD(0.2, 4.0, -0.5, 4.7), jax_galsim.BoundsI(-10, 5, 0, 7), ], @@ -503,6 +507,8 @@ def test_api_bounds(obj): _run_object_checks(obj, obj.__class__, "pickle-eval-repr") _run_object_checks(obj, obj.__class__, "to-from-galsim") + assert isinstance(obj.xmin, (float, int)) + # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py index 09dbf61a..793dfb74 100644 --- a/tests/jax/test_render_scene.py +++ b/tests/jax/test_render_scene.py @@ -50,7 +50,7 @@ def _generate_image(rng_key, psf, n_obj): def test_render_scene_draw_many_ffts_full_img(): psf = jgs.Gaussian(fwhm=0.9) - img = _generate_image(jrng.key(10), psf, 50) + img = _generate_image(jrng.key(10), psf, 5) if False: import pdb @@ -60,8 +60,8 @@ def test_render_scene_draw_many_ffts_full_img(): plt.imshow(img.array.sum(axis=0)) pdb.set_trace() - assert img.array.shape == (50, 200, 200) - assert img.array.sum() > 50.0 + assert img.array.shape == (5, 200, 200) + assert img.array.sum() > 5.0 def _get_bd_jgs( From 4977ef1db62e310ec5100f222d3bda16864a64a0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Tue, 17 Mar 2026 09:08:04 -0500 Subject: [PATCH 15/16] feat: better notes on rendering offsets --- tests/jax/test_render_scene.py | 100 ++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 21 deletions(-) diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py index 793dfb74..a3c832a0 100644 --- a/tests/jax/test_render_scene.py +++ b/tests/jax/test_render_scene.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import jax.random as jrng import numpy as np +import pytest import jax_galsim as jgs @@ -112,14 +113,21 @@ def _draw_stamp_jgs( # you have to render just with on offset in order to keep the bounds # static during rendering - # the exact pixel computation here is MAGIC right now - # we'll need a way to make this easier - dx = image_pos.x - jnp.ceil(image_pos.x) - dy = image_pos.y - jnp.ceil(image_pos.y) - dx = dx + 0.5 * ((slen + 1) % 2) - dy = dy + 0.5 * ((slen + 1) % 2) + # here dx,dy is the offset to the nearest pixel + # we then render with use_true_center = False to ensure the offset is + # applied relative to a pixel center for all image dimensions, including + # even ones. + # this means the object is offset by (dx,dy) from stamp.bounds.center + dx = image_pos.x - jnp.floor(image_pos.x + 0.5) + dy = image_pos.y - jnp.floor(image_pos.y + 0.5) + stamp = convolved_object.drawImage( - nx=slen, ny=slen, offset=(dx, dy), wcs=local_wcs, dtype=jnp.float64 + nx=slen, + ny=slen, + offset=(dx, dy), + wcs=local_wcs, + dtype=jnp.float64, + use_true_center=False, ) return stamp @@ -130,15 +138,36 @@ def _add_to_image(carry, x, slen): image = carry[0] stamp, image_pos = x - # then we apply a shift to get the correct final bounds + # then we apply a shift to the stamp get the correct final location + # above we rendered at the location xs, ys = (dx,dy) + stamp.bounds.center + # in the image.bounds coordinates, the location (xs,ys) should be + # + # (xs - stamp.bounds.xmin) + shift.x = image_pos.x - image.bounds.xmin + # + # the logic here is that the offset of the object in array indices in the final + # image should be equal to the shift in array indices of the stamo plus the offset + # in array indicies of the stamp. + # we then get for x + # shift.x = image_pos.x - image.bounds.xmin - xs + stamp.bounds.xmin + # = image_pos.x - dx - stamp.bounds.center.x + stamp.bounds.xmin - image.bounds.xmin + # = image_pos.x - (image_pos.x - jnp.floor(image_pos.x + 0.5)) - stamp.bounds.center.x + stamp.bounds.xmin - image.bounds.xmin + # = jnp.floor(image_pos.x + 0.5) - stamp.bounds.center.x + stamp.bounds.xmin - image.bounds.xmin shift = jgs.PositionI( - jnp.int32(jnp.floor(image_pos.x + 0.5 - stamp.bounds.true_center.x)), - jnp.int32(jnp.floor(image_pos.y + 0.5 - stamp.bounds.true_center.y)), + jnp.int32( + jnp.floor(image_pos.x + 0.5) + - stamp.bounds.center.x + + stamp.bounds.xmin + - image.bounds.xmin + ), + jnp.int32( + jnp.floor(image_pos.y + 0.5) + - stamp.bounds.center.y + + stamp.bounds.ymin + - image.bounds.ymin + ), ) - i1 = stamp.bounds.ymin + shift.y - image.ymin - j1 = stamp.bounds.xmin + shift.x - image.xmin - start_inds = (i1, j1) + start_inds = (shift.y, shift.x) subim = jax.lax.dynamic_slice(image.array, start_inds, (slen, slen)) subim = subim + stamp.array @@ -215,13 +244,13 @@ def _render_scene_stamps_galsim( return image -def test_render_scene_stamps(): +@pytest.mark.parametrize("slen", [51, 52]) +def test_render_scene_stamps(slen): image = jgs.Image(ncol=200, nrow=200, scale=0.2, dtype=jnp.float64) wcs = image.wcs rng = np.random.default_rng(seed=10) ng = 5 - slen = 52 fft_size = 2048 galaxy_params = { @@ -285,6 +314,35 @@ def test_render_scene_stamps(): ng, ) + gs_image_mo = _galsim.Image(ncol=200, nrow=200, scale=0.2, dtype=np.float64) + wcs = gs_image.wcs + + gs_image_positions = list( + map(lambda tup: _galsim.PositionD(x=tup[0], y=tup[1]), zip(x, y)) + ) + gs_local_wcss = list(map(lambda x: wcs.local(image_pos=x), gs_image_positions)) + + _render_scene_stamps_galsim( + galaxy_params, + gs_image_positions, + gs_local_wcss, + fft_size, + slen + 1, + gs_image_mo, + ng, + ) + + abs_eps = 4.0 * np.max(np.abs(gs_image_mo.array - gs_image.array)) + rel_eps = 0.0 + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(gs_image_mo.array - gs_image.array) + pdb.set_trace() + if False: import pdb @@ -302,15 +360,15 @@ def test_render_scene_stamps(): pdb.set_trace() np.testing.assert_allclose( - gs_image.array.sum(), final_pad_image.array[slen:-slen, slen:-slen].sum(), - atol=1e-4, - rtol=1e-5, + gs_image.array.sum(), + atol=abs_eps, + rtol=rel_eps, ) np.testing.assert_allclose( - gs_image.array, final_pad_image.array[slen:-slen, slen:-slen], - atol=1e-6, - rtol=1e-6, + gs_image.array, + atol=abs_eps, + rtol=rel_eps, ) From 5b4648e5df8f1f996896516b83ef6228c2b42cfb Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 18 Mar 2026 17:36:18 -0600 Subject: [PATCH 16/16] fix: do not trace bounds --- jax_galsim/image.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index ab1e04b1..f6f0f518 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1061,8 +1061,8 @@ def rot_180(self): def tree_flatten(self): """Flatten the image into a list of values.""" # Define the children nodes of the PyTree that need tracing - children = (self.array, self.wcs, self.bounds) - aux_data = {"dtype": self.dtype, "isconst": self.isconst} + children = (self.array, self.wcs) + aux_data = {"dtype": self.dtype, "bounds": self.bounds, "isconst": self.isconst} # other routines may add these attributes to images on the fly # we have to include them here so that JAX knows how to handle them in jitting etc. if hasattr(self, "added_flux"): @@ -1080,15 +1080,15 @@ def tree_unflatten(cls, aux_data, children): obj = object.__new__(cls) obj._array = children[0] obj.wcs = children[1] - obj._bounds = children[2] + obj._bounds = aux_data["bounds"] obj._dtype = aux_data["dtype"] obj._is_const = aux_data["isconst"] - if len(children) > 3: - obj.added_flux = children[3] + if len(children) > 2: + obj.added_flux = children[2] if "header" in aux_data: obj.header = aux_data["header"] - if len(children) > 4: - obj.photons = children[4] + if len(children) > 3: + obj.photons = children[3] return obj @classmethod