From 4c7798fc83c9f3a95daccc57edf28a17a5882e51 Mon Sep 17 00:00:00 2001 From: beckermr Date: Fri, 15 May 2026 05:24:19 -0500 Subject: [PATCH] test: add test for bounds isDefined --- jax_galsim/bounds.py | 5 ++--- tests/jax/test_bounds_jax.py | 21 +++++++++++++++++++++ tests/jax/test_render_scene.py | 1 + 3 files changed, 24 insertions(+), 3 deletions(-) create mode 100644 tests/jax/test_bounds_jax.py diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 21b6ba8a..75badb84 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -6,7 +6,6 @@ from jax_galsim.core.utils import ( CONST_TYPES, cast_to_float, - cast_to_int, cast_to_python_float, check_is_int_then_cast, ensure_hashable, @@ -519,8 +518,8 @@ def __init__(self, *args, **kwargs): self.deltay = cast_to_python_float(self.deltay) if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): raise TypeError("BoundsI must be initialized with integer values") - self.deltax = int(cast_to_int(self.deltax)) - self.deltay = int(cast_to_int(self.deltay)) + self.deltax = int(self.deltax) + self.deltay = int(self.deltay) if has_tracers(self._xmin) or has_tracers(self._ymin): self._isstatic = False diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py new file mode 100644 index 00000000..1bcb8a89 --- /dev/null +++ b/tests/jax/test_bounds_jax.py @@ -0,0 +1,21 @@ +import jax +import jax.numpy as jnp +import numpy as np + +import jax_galsim + + +@jax.vmap +@jax.jit +def _make_bounds(xmin, ymin): + bds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, deltax=10, deltay=10) + return bds, bds.isDefined() + + +def test_bounds_jax_vmap_isdefined(): + xmin = jnp.array([9, 10, 11]) + ymin = jnp.array([9, 10, 11]) + + bds, isdef = _make_bounds(xmin, ymin) + print(isdef, bds.isDefined()) + np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True) diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py index 064c9aef..077117ee 100644 --- a/tests/jax/test_render_scene.py +++ b/tests/jax/test_render_scene.py @@ -365,6 +365,7 @@ def test_render_scene_stamps(slen): # present in GalSim when drawing stamps that odd or even sized. abs_eps = np.max(np.abs(gs_image_mo.array - gs_image.array)) rel_eps = 0.0 + assert abs_eps < 5e-5 if False: import pdb