From 4077a18c9477614eb60558812c911bf56dabbfbb Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 05:45:55 -0500 Subject: [PATCH 01/37] fix: clean up type handling --- docs/sharp-bits.rst | 29 ++------- jax_galsim/core/utils.py | 124 ++++++++++----------------------------- 2 files changed, 37 insertions(+), 116 deletions(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 7901176e..031fa843 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -163,9 +163,6 @@ profile parameters passed into a ``jit``-compiled function): def good(sigma): return jax.lax.cond(sigma > 1.0, lambda s: s * 2, lambda s: s, sigma) -JAX-GalSim uses an internal ``has_tracers()`` utility to detect tracing and -avoid problematic control flow in its own implementations. - Fixed output shapes ^^^^^^^^^^^^^^^^^^^ @@ -197,20 +194,9 @@ The ``__init__`` gotcha During ``jit`` tracing, JAX calls constructors with **tracer objects** rather than concrete Python numbers. Type checks like ``isinstance(sigma, float)`` will -fail on tracers. JAX-GalSim handles this internally, but if you subclass any -JAX-GalSim object, be aware that ``__init__`` may receive tracers: - -.. code-block:: python - - from jax_galsim.core.utils import has_tracers - - class MyProfile(jax_galsim.GSObject): - def __init__(self, sigma, gsparams=None): - if not has_tracers(sigma): - # Only validate with concrete values - if sigma <= 0: - raise ValueError("sigma must be positive") - ... +return ``False`` on tracers, and you cannot check correctness of values (e.g., +``if sigma > 0: ...```). JAX-GalSim handles this internally, but if you subclass any +JAX-GalSim object, be aware that ``__init__`` may receive tracers. Profile Restrictions -------------------- @@ -221,12 +207,9 @@ Some GalSim features are not yet implemented in JAX-GalSim: - **ChromaticObject**: All chromatic functionality (wavelength-dependent profiles) is not available. - **InterpolatedKImage**: Not implemented. -- **Airy, Kolmogorov, OpticalPSF, RealGalaxy**: See :doc:`api-coverage` for +- **Airy, Kolmogorov, OpticalPSF, RealGalaxy, etc.**: See :doc:`api-coverage` for the full list. -The project currently implements **22.5 %** of the GalSim public API, focused -on the most commonly used profiles and operations. - Numerical Precision ------------------- @@ -249,11 +232,11 @@ These differences are typically at the level of floating-point round-off should not affect scientific conclusions. ⚠️ Additional Sharp Bits --------------------------- +------------------------ In the :doc:`api/index` you will find **🔪 JAX-GalSim - The Sharp Bits 🔪** blocks highlighting additional important caveats for specific classes and or methods. These could include things like: -- Many classes do not perform some of Galsim's test for correctness during initialization (e.g., :meth:`~jax_galsim.GSObject.drawImage`). +- Some classes do not perform some of Galsim's test for correctness during initialization (e.g., :meth:`~jax_galsim.GSObject.drawImage`). - Certain profiles might not be auto-differentiable with respect to some of their parameters (e.g., :class:`~jax_galsim.Spergel`, :class:`~jax_galsim.Moffat`) - Limitations regarding what types of inputes are handled (e.g., :meth:`~jax_galsim.Image.calculate_fft` does not accept complex dtypes.) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index f6785ceb..79f9119a 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -86,113 +86,51 @@ def compute_major_minor_from_jacobian(jac): return major, minor -def _cast_to_array_scalar(x, dtype=None): - """Cast the input to an array scalar. Works on python scalars, iterables and jax arrays. - For iterables it always takes the first element after a call to .ravel()""" - if dtype is None: - if hasattr(x, "dtype"): - dtype = x.dtype - else: - dtype = float - - if isinstance(x, jax.Array): - return jnp.atleast_1d(x).astype(dtype).ravel()[0] - elif hasattr(x, "astype"): - return x.astype(dtype).ravel()[0] - else: - return jnp.atleast_1d(jnp.array(x, dtype=dtype)).ravel()[0] - - def cast_to_python_float(x): - """Cast the input to a python float. Works on python floats and jax arrays. - For jax arrays it always takes the first element after a call to .ravel()""" - if isinstance(x, jax.Array): - return _cast_to_array_scalar(x, dtype=float).item() + """Cast the input to a python float. Works on python int/floats + and jax/numpy arrays. Will raise an error for arrays with more than one value. + """ + if isinstance(x, (int, float, np.integer, np.floating)): + return float(x) + elif isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)) and ( + x.ndim == 0 or (x.ndim == 1 and x.shape[0] == 1) + ): + return float(x.item()) else: - try: - return float(x) - except TypeError: - # this will return the same value for anything float-like that - # cannot be cast to float - # however, it will raise an error if something is not float-like - return 1.0 * x - except ValueError as e: - # we let NaNs through - if " NaN " in str(e): - # this will return the same value for anything float-like that - # cannot be cast to float - # however, it will raise an error if something is not float-like - return 1.0 * x - else: - raise e + raise ValueError(f"Cannot convert object {x!r} to a python float!") def cast_to_python_int(x): - """Cast the input to a python int. Works on python ints and jax arrays. - For jax arrays it always takes the first element after a call to .ravel()""" - if isinstance(x, jax.Array): - return _cast_to_array_scalar(x, dtype=int).item() + """Cast the input to a python int. Works on python int/floats + and jax/numpy arrays. Will raise an error for arrays with more than one value, + or if it encounters NaNs. + """ + if isinstance(x, (int, float, np.integer, np.floating)): + return int(x) + elif isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)) and ( + x.ndim == 0 or (x.ndim == 1 and x.shape[0] == 1) + ): + return int(x.item()) else: - try: - return int(x) - except TypeError: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - return 1 * x - except ValueError as e: - # we let NaNs through - if " NaN " in str(e): - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - return 1 * x - else: - raise e + raise ValueError(f"Cannot convert object {x!r} to a python int!") def cast_to_float(x): - """Cast the input to a float. Works on python floats and jax arrays.""" - try: + """Cast the input to a float. Works on python floats, numpy scalars, and jax/numpy arrays.""" + if isinstance(x, (int, float, np.integer, np.floating)): return float(x) - except Exception: - try: - return jnp.asarray(x, dtype=float) - except Exception: - # this will return the same value for anything float-like that - # cannot be cast to float - # however, it will raise an error if something is not float-like - # we exclude object types since they are used in JAX tracing - if type(x) is object: - return x - else: - return 1.0 * x + else: + # use the python `float` const/func here to promote to the highest + # precision available without emitting a warning in JAX + return jnp.astype(x, float) def cast_to_int(x): - """Cast the input to an int. Works on python floats/ints and jax arrays.""" - try: + """Cast the input to an int. Works on python floats, numpy scalars, and jax/numpy arrays.""" + if isinstance(x, (int, float, np.integer, np.floating)): return int(x) - except Exception: - try: - if not jnp.any(jnp.isnan(x)): - return jnp.asarray(x, dtype=int) - else: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - if type(x) is object: - return x - else: - return 1 * x - except Exception: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - if type(x) is object: - return x - else: - return 1 * x + else: + return jnp.astype(x, int) def is_equal_with_arrays(x, y): From 5f17c95a761334b0b8e573dd082945999c3b3c68 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 06:22:09 -0500 Subject: [PATCH 02/37] fix: allow strings optionally for floating casts --- jax_galsim/angle.py | 2 +- jax_galsim/core/utils.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index aec36d90..9569ebd9 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -55,7 +55,7 @@ class AngleUnit(object): def __init__(self, value): if isinstance(value, AngleUnit): raise TypeError("Cannot construct AngleUnit from another AngleUnit") - self._value = cast_to_float(value) + self._value = cast_to_float(value, accept_strings=True) @property @implements(_galsim.AngleUnit.value) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 79f9119a..6d10c8f5 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -115,9 +115,18 @@ def cast_to_python_int(x): raise ValueError(f"Cannot convert object {x!r} to a python int!") -def cast_to_float(x): - """Cast the input to a float. Works on python floats, numpy scalars, and jax/numpy arrays.""" - if isinstance(x, (int, float, np.integer, np.floating)): +def cast_to_float(x, accept_strings=False): + """Cast the input to a float. Works on python floats, numpy scalars, and jax/numpy arrays. + + Parameters: + accept_strings: If True, allow string to ``float`` conversion. [default: False] + + Returns: + Input value ``x`` casted to a ``float``. + """ + if isinstance(x, (int, float, np.integer, np.floating)) or ( + accept_strings and isinstance(x, str) + ): return float(x) else: # use the python `float` const/func here to promote to the highest From bb482c34de8e83dc660019ac7729054ca2cc0fd8 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 07:42:35 -0500 Subject: [PATCH 03/37] fix: more type cleanups --- jax_galsim/core/utils.py | 45 ++++++++++++++++---------------------- jax_galsim/photon_array.py | 29 +++++++++++------------- 2 files changed, 32 insertions(+), 42 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 6d10c8f5..6e261266 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -100,23 +100,17 @@ def cast_to_python_float(x): raise ValueError(f"Cannot convert object {x!r} to a python float!") -def cast_to_python_int(x): - """Cast the input to a python int. Works on python int/floats - and jax/numpy arrays. Will raise an error for arrays with more than one value, - or if it encounters NaNs. - """ - if isinstance(x, (int, float, np.integer, np.floating)): - return int(x) - elif isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)) and ( - x.ndim == 0 or (x.ndim == 1 and x.shape[0] == 1) +def _cast_to_type(x, typ, accept_strings=False): + if isinstance(x, (int, float, np.integer, np.floating)) or ( + accept_strings and isinstance(x, str) ): - return int(x.item()) + return typ(x) else: - raise ValueError(f"Cannot convert object {x!r} to a python int!") + return jnp.astype(x, typ) def cast_to_float(x, accept_strings=False): - """Cast the input to a float. Works on python floats, numpy scalars, and jax/numpy arrays. + """Cast the input to a float. Works on python floats/ints, numpy scalars, and jax/numpy arrays. Parameters: accept_strings: If True, allow string to ``float`` conversion. [default: False] @@ -124,22 +118,21 @@ def cast_to_float(x, accept_strings=False): Returns: Input value ``x`` casted to a ``float``. """ - if isinstance(x, (int, float, np.integer, np.floating)) or ( - accept_strings and isinstance(x, str) - ): - return float(x) - else: - # use the python `float` const/func here to promote to the highest - # precision available without emitting a warning in JAX - return jnp.astype(x, float) + # use the python `float` const/func here to promote to the highest + # precision available without emitting a warning in JAX + return _cast_to_type(x, float, accept_strings=accept_strings) -def cast_to_int(x): - """Cast the input to an int. Works on python floats, numpy scalars, and jax/numpy arrays.""" - if isinstance(x, (int, float, np.integer, np.floating)): - return int(x) - else: - return jnp.astype(x, int) +def cast_to_int(x, accept_strings=False): + """Cast the input to an int. Works on python floats/ints, numpy scalars, and jax/numpy arrays. + + Parameters: + accept_strings: If True, allow string to ``int`` conversion. [default: False] + + Returns: + Input value ``x`` casted to an ``int``. + """ + return _cast_to_type(x, int, accept_strings=accept_strings) def is_equal_with_arrays(x, y): diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 967f55c3..196af69a 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -8,7 +8,7 @@ from jax_galsim.core.utils import ( cast_numpy_array_to_native_byte_order, - cast_to_python_int, + cast_to_int, implements, ) from jax_galsim.errors import ( @@ -91,19 +91,16 @@ def __init__( _nokeep=None, ): self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N - if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None: - try: - # this will raise a boolean conversion error in JAX - # which we swallow - err_cond = (N > _JAX_GALSIM_PHOTON_ARRAY_SIZE) or False - except Exception: - err_cond = False - - if err_cond: - raise GalSimValueError( - f"The given photon array size {N} is larger than " - f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." - ) + self._Ntot = cast_to_int(self._Ntot) + + if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None and ( + self._Ntot > _JAX_GALSIM_PHOTON_ARRAY_SIZE + ): + raise GalSimValueError( + f"The given photon array size {self._Ntot} is larger than " + f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." + ) + if _nokeep is not None: self._nokeep = _nokeep else: @@ -820,7 +817,7 @@ def __repr__(self): import numpy as np s = "galsim.PhotonArray(%r, x=array(%r), y=array(%r), flux=array(%r)" % ( - cast_to_python_int(self.size()), + self.size(), np.array(self.x).tolist(), np.array(self.y).tolist(), np.array(self.flux).tolist(), @@ -844,7 +841,7 @@ def __repr__(self): return s def __str__(self): - return "galsim.PhotonArray(%r)" % cast_to_python_int(self.size()) + return "galsim.PhotonArray(%r)" % self.size() __hash__ = None From 434f80020c9c5c58832160784eadba6b45ff5c70 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 18 May 2026 07:43:26 -0500 Subject: [PATCH 04/37] Apply suggestion from @beckermr --- jax_galsim/photon_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 196af69a..8ecd8ae9 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -96,7 +96,7 @@ def __init__( if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None and ( self._Ntot > _JAX_GALSIM_PHOTON_ARRAY_SIZE ): - raise GalSimValueError( + raise ValueError( f"The given photon array size {self._Ntot} is larger than " f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." ) From 76c5097cd1982221f75b5100c0b5d6329d550d19 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 08:22:41 -0500 Subject: [PATCH 05/37] fix: remove more weird casts and be more struct on types --- jax_galsim/bounds.py | 22 ++++++----- jax_galsim/core/utils.py | 57 ++------------------------- jax_galsim/fitswcs.py | 34 ++++++++-------- jax_galsim/transform.py | 15 ++++++-- jax_galsim/wcs.py | 83 ++++++++++++++++++++++++++++------------ 5 files changed, 105 insertions(+), 106 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 21b6ba8a..a02358b3 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,13 +1,12 @@ import galsim as _galsim import jax import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( - CONST_TYPES, cast_to_float, cast_to_int, - cast_to_python_float, check_is_int_then_cast, ensure_hashable, has_tracers, @@ -139,8 +138,8 @@ def _parse_args(self, *args, **kwargs): else: max_delta = 1 if ( - isinstance(self.deltax, CONST_TYPES) - and isinstance(self.deltay, CONST_TYPES) + isinstance(self.deltax, (int, float, np.integer, np.floating)) + and isinstance(self.deltay, (int, float, np.integer, np.floating)) and (self.deltax < max_delta or self.deltay < max_delta) ): self._isdefined = False @@ -509,18 +508,21 @@ def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - if has_tracers(self.deltax) or has_tracers(self.deltay): + if not ( + isinstance(self.deltax, (int, float, np.integer, np.floating)) + and isinstance(self.deltay, (int, float, np.integer, np.floating)) + ): raise RuntimeError( - "Jax-GalSim BoundsI instances must have a fixed width! " + "Jax-GalSim BoundsI instances must have a fixed, static width! " f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." ) - self.deltax = cast_to_python_float(self.deltax) - self.deltay = cast_to_python_float(self.deltay) + self.deltax = cast_to_float(self.deltax) + self.deltay = cast_to_float(self.deltay) if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): raise TypeError("BoundsI must be initialized with integer values") - self.deltax = int(cast_to_int(self.deltax)) - self.deltay = int(cast_to_int(self.deltay)) + self.deltax = cast_to_int(self.deltax) + self.deltay = cast_to_int(self.deltay) if has_tracers(self._xmin) or has_tracers(self._ymin): self._isstatic = False diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 6e261266..7fd99ccc 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -9,39 +9,13 @@ import numpy as np from jax.tree_util import tree_flatten -CONST_TYPES = ( - float, - int, - np.ndarray, - np.int8, - np.int16, - np.int32, - np.int64, - np.float16, - np.float32, - np.float64, - np.complex64, - np.complex128, -) -CONST_TYPES_WITH_JAX = CONST_TYPES + ( - jax.Array, - jnp.ndarray, - jnp.int8, - jnp.int16, - jnp.int32, - jnp.int64, - jnp.float32, - jnp.float64, - jnp.complex64, - jnp.complex128, -) - def check_is_int_then_cast(val, msg): """Check if `val` is an integer, raise if not, otherwise cast to int.""" - # for simple inputs, we can check direct in python - if isinstance(val, CONST_TYPES) and not has_tracers(val): - val = cast_to_python_float(val) + val = cast_to_float(val) + + if isinstance(val, (int, float, np.integer, np.floating)): + # for simple inputs, we can check direct in python if val != int(val): raise TypeError(msg) val = int(val) @@ -77,29 +51,6 @@ def has_tracers(x): return False -@jax.jit -def compute_major_minor_from_jacobian(jac): - h1 = jnp.hypot(jac[0, 0] + jac[1, 1], jac[0, 1] - jac[1, 0]) - h2 = jnp.hypot(jac[0, 0] - jac[1, 1], jac[0, 1] + jac[1, 0]) - major = 0.5 * jnp.abs(h1 + h2) - minor = 0.5 * jnp.abs(h1 - h2) - return major, minor - - -def cast_to_python_float(x): - """Cast the input to a python float. Works on python int/floats - and jax/numpy arrays. Will raise an error for arrays with more than one value. - """ - if isinstance(x, (int, float, np.integer, np.floating)): - return float(x) - elif isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)) and ( - x.ndim == 0 or (x.ndim == 1 and x.shape[0] == 1) - ): - return float(x.item()) - else: - raise ValueError(f"Cannot convert object {x!r} to a python float!") - - def _cast_to_type(x, typ, accept_strings=False): if isinstance(x, (int, float, np.integer, np.floating)) or ( accept_strings and isinstance(x, str) diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index d64564ab..01654860 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -13,7 +13,6 @@ from jax_galsim.celestial import CelestialCoord from jax_galsim.core.utils import ( cast_to_float, - cast_to_python_float, ensure_hashable, implements, ) @@ -32,6 +31,7 @@ JacobianWCS, OffsetWCS, PixelScale, + _cast_to_python_float, ) ######################################################################################### @@ -754,16 +754,16 @@ def _writeHeader(self, header, bounds): header["GS_WCS"] = ("GSFitsWCS", "GalSim WCS name") header["CTYPE1"] = "RA---" + self.wcs_type header["CTYPE2"] = "DEC--" + self.wcs_type - header["CRPIX1"] = cast_to_python_float(self.crpix[0]) - header["CRPIX2"] = cast_to_python_float(self.crpix[1]) - header["CD1_1"] = cast_to_python_float(self.cd[0][0]) - header["CD1_2"] = cast_to_python_float(self.cd[0][1]) - header["CD2_1"] = cast_to_python_float(self.cd[1][0]) - header["CD2_2"] = cast_to_python_float(self.cd[1][1]) + header["CRPIX1"] = _cast_to_python_float(self.crpix[0]) + header["CRPIX2"] = _cast_to_python_float(self.crpix[1]) + header["CD1_1"] = _cast_to_python_float(self.cd[0][0]) + header["CD1_2"] = _cast_to_python_float(self.cd[0][1]) + header["CD2_1"] = _cast_to_python_float(self.cd[1][0]) + header["CD2_2"] = _cast_to_python_float(self.cd[1][1]) header["CUNIT1"] = "deg" header["CUNIT2"] = "deg" - header["CRVAL1"] = cast_to_python_float(self.center.ra / degrees) - header["CRVAL2"] = cast_to_python_float(self.center.dec / degrees) + header["CRVAL1"] = _cast_to_python_float(self.center.ra / degrees) + header["CRVAL2"] = _cast_to_python_float(self.center.dec / degrees) if self.pv is not None: order = len(self.pv[0]) - 1 k = 0 @@ -771,8 +771,8 @@ def _writeHeader(self, header, bounds): for n in range(order + 1): for j in range(n + 1): i = n - j - header["PV1_" + str(k)] = cast_to_python_float(self.pv[0, i, j]) - header["PV2_" + str(k)] = cast_to_python_float(self.pv[1, j, i]) + header["PV1_" + str(k)] = _cast_to_python_float(self.pv[0, i, j]) + header["PV2_" + str(k)] = _cast_to_python_float(self.pv[1, j, i]) k = k + 1 if k in odd_indices: k = k + 1 @@ -785,7 +785,9 @@ def _writeHeader(self, header, bounds): if i == 1 and j == 0: aij -= 1 # Turn back into standard form. if aij != 0.0: - header["A_" + str(i) + "_" + str(j)] = cast_to_python_float(aij) + header["A_" + str(i) + "_" + str(j)] = _cast_to_python_float( + aij + ) header["B_ORDER"] = order for i in range(order + 1): for j in range(order + 1): @@ -793,7 +795,9 @@ def _writeHeader(self, header, bounds): if i == 0 and j == 1: bij -= 1 if bij != 0.0: - header["B_" + str(i) + "_" + str(j)] = cast_to_python_float(bij) + header["B_" + str(i) + "_" + str(j)] = _cast_to_python_float( + bij + ) if self.abp is not None: order = len(self.abp[0]) - 1 header["AP_ORDER"] = order @@ -803,7 +807,7 @@ def _writeHeader(self, header, bounds): if i == 1 and j == 0: apij -= 1 if apij != 0.0: - header["AP_" + str(i) + "_" + str(j)] = cast_to_python_float( + header["AP_" + str(i) + "_" + str(j)] = _cast_to_python_float( apij ) header["BP_ORDER"] = order @@ -813,7 +817,7 @@ def _writeHeader(self, header, bounds): if i == 0 and j == 1: bpij -= 1 if bpij != 0.0: - header["BP_" + str(i) + "_" + str(j)] = cast_to_python_float( + header["BP_" + str(i) + "_" + str(j)] = _cast_to_python_float( bpij ) return header diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index bf9f4f6d..e78a67e3 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -1,9 +1,9 @@ import galsim as _galsim +import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( - compute_major_minor_from_jacobian, ensure_hashable, implements, ) @@ -12,6 +12,15 @@ from jax_galsim.position import PositionD +@jax.jit +def _compute_major_minor_from_jacobian(jac): + h1 = jnp.hypot(jac[0, 0] + jac[1, 1], jac[0, 1] - jac[1, 0]) + h2 = jnp.hypot(jac[0, 0] - jac[1, 1], jac[0, 1] + jac[1, 0]) + major = 0.5 * jnp.abs(h1 + h2) + minor = 0.5 * jnp.abs(h1 - h2) + return major, minor + + @implements( _galsim.Transform, lax_description="Does not support Chromatic Objects or Convolutions.", @@ -277,12 +286,12 @@ def _kfactor(self, kx, ky): @property def _maxk(self): - _, minor = compute_major_minor_from_jacobian(self._jac) + _, minor = _compute_major_minor_from_jacobian(self._jac) return self._original.maxk / minor @property def _stepk(self): - major, _ = compute_major_minor_from_jacobian(self._jac) + major, _ = _compute_major_minor_from_jacobian(self._jac) stepk = self._original.stepk / major # If we have a shift, we need to further modify stepk # stepk = Pi/R diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 20e31eb3..8208da99 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -1,11 +1,12 @@ import galsim as _galsim +import jax import jax.numpy as jnp import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.angle import AngleUnit, arcsec, radians from jax_galsim.celestial import CelestialCoord -from jax_galsim.core.utils import cast_to_python_float, ensure_hashable, implements +from jax_galsim.core.utils import ensure_hashable, implements from jax_galsim.errors import GalSimValueError from jax_galsim.gsobject import GSObject from jax_galsim.position import Position, PositionD, PositionI @@ -13,6 +14,20 @@ from jax_galsim.transform import _Transform +def _cast_to_python_float(x): + """Cast the input to a python float. Works on python int/floats + and jax/numpy arrays. Will raise an error for arrays with more than one value. + """ + if isinstance(x, (int, float, np.integer, np.floating)): + return float(x) + elif isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)) and ( + x.ndim == 0 or (x.ndim == 1 and x.shape[0] == 1) + ): + return float(x.item()) + else: + raise ValueError(f"Cannot convert object {x!r} to a python float!") + + # We inherit from the reference BaseWCS and only redefine the methods that # make references to jax_galsim objects. @implements(_galsim.BaseWCS) @@ -910,7 +925,7 @@ def _toJacobian(self): def _writeHeader(self, header, bounds): header["GS_WCS"] = ("PixelScale", "GalSim WCS name") - header["GS_SCALE"] = (cast_to_python_float(self.scale), "GalSim image scale") + header["GS_SCALE"] = (_cast_to_python_float(self.scale), "GalSim image scale") return self.affine()._writeLinearWCS(header, bounds) @staticmethod @@ -1032,9 +1047,15 @@ def _newOrigin(self, origin, world_origin): def _writeHeader(self, header, bounds): header["GS_WCS"] = ("ShearWCS", "GalSim WCS name") - header["GS_SCALE"] = (cast_to_python_float(self.scale), "GalSim image scale") - header["GS_G1"] = (cast_to_python_float(self.shear.g1), "GalSim image shear g1") - header["GS_G2"] = (cast_to_python_float(self.shear.g2), "GalSim image shear g2") + header["GS_SCALE"] = (_cast_to_python_float(self.scale), "GalSim image scale") + header["GS_G1"] = ( + _cast_to_python_float(self.shear.g1), + "GalSim image shear g1", + ) + header["GS_G2"] = ( + _cast_to_python_float(self.shear.g2), + "GalSim image shear g2", + ) return self.affine()._writeLinearWCS(header, bounds) @implements(_galsim.wcs.ShearWCS.copy) @@ -1326,15 +1347,21 @@ def world_origin(self): def _writeHeader(self, header, bounds): header["GS_WCS"] = ("OffsetWCS", "GalSim WCS name") - header["GS_SCALE"] = (cast_to_python_float(self.scale), "GalSim image scale") - header["GS_X0"] = (cast_to_python_float(self.origin.x), "GalSim image origin x") - header["GS_Y0"] = (cast_to_python_float(self.origin.y), "GalSim image origin y") + header["GS_SCALE"] = (_cast_to_python_float(self.scale), "GalSim image scale") + header["GS_X0"] = ( + _cast_to_python_float(self.origin.x), + "GalSim image origin x", + ) + header["GS_Y0"] = ( + _cast_to_python_float(self.origin.y), + "GalSim image origin y", + ) header["GS_U0"] = ( - cast_to_python_float(self.world_origin.x), + _cast_to_python_float(self.world_origin.x), "GalSim world origin u", ) header["GS_V0"] = ( - cast_to_python_float(self.world_origin.y), + _cast_to_python_float(self.world_origin.y), "GalSim world origin v", ) return self.affine()._writeLinearWCS(header, bounds) @@ -1404,23 +1431,29 @@ def _newOrigin(self, origin, world_origin): def _writeHeader(self, header, bounds): header["GS_WCS"] = ("OffsetShearWCS", "GalSim WCS name") - header["GS_SCALE"] = (cast_to_python_float(self.scale), "GalSim image scale") - header["GS_G1"] = (cast_to_python_float(self.shear.g1), "GalSim image shear g1") - header["GS_G2"] = (cast_to_python_float(self.shear.g2), "GalSim image shear g2") + header["GS_SCALE"] = (_cast_to_python_float(self.scale), "GalSim image scale") + header["GS_G1"] = ( + _cast_to_python_float(self.shear.g1), + "GalSim image shear g1", + ) + header["GS_G2"] = ( + _cast_to_python_float(self.shear.g2), + "GalSim image shear g2", + ) header["GS_X0"] = ( - cast_to_python_float(self.origin.x), + _cast_to_python_float(self.origin.x), "GalSim image origin x coordinate", ) header["GS_Y0"] = ( - cast_to_python_float(self.origin.y), + _cast_to_python_float(self.origin.y), "GalSim image origin y coordinate", ) header["GS_U0"] = ( - cast_to_python_float(self.world_origin.x), + _cast_to_python_float(self.world_origin.x), "GalSim world origin u coordinate", ) header["GS_V0"] = ( - cast_to_python_float(self.world_origin.y), + _cast_to_python_float(self.world_origin.y), "GalSim world origin v coordinate", ) return self.affine()._writeLinearWCS(header, bounds) @@ -1504,25 +1537,25 @@ def _writeLinearWCS(self, header, bounds): header["CTYPE1"] = ("LINEAR", "name of the world coordinate axis") header["CTYPE2"] = ("LINEAR", "name of the world coordinate axis") header["CRVAL1"] = ( - cast_to_python_float(self.u0), + _cast_to_python_float(self.u0), "world coordinate at reference pixel = u0", ) header["CRVAL2"] = ( - cast_to_python_float(self.v0), + _cast_to_python_float(self.v0), "world coordinate at reference pixel = v0", ) header["CRPIX1"] = ( - cast_to_python_float(self.x0), + _cast_to_python_float(self.x0), "image coordinate of reference pixel = x0", ) header["CRPIX2"] = ( - cast_to_python_float(self.y0), + _cast_to_python_float(self.y0), "image coordinate of reference pixel = y0", ) - header["CD1_1"] = (cast_to_python_float(self.dudx), "CD1_1 = dudx") - header["CD1_2"] = (cast_to_python_float(self.dudy), "CD1_2 = dudy") - header["CD2_1"] = (cast_to_python_float(self.dvdx), "CD2_1 = dvdx") - header["CD2_2"] = (cast_to_python_float(self.dvdy), "CD2_2 = dvdy") + header["CD1_1"] = (_cast_to_python_float(self.dudx), "CD1_1 = dudx") + header["CD1_2"] = (_cast_to_python_float(self.dudy), "CD1_2 = dudy") + header["CD2_1"] = (_cast_to_python_float(self.dvdx), "CD2_1 = dvdx") + header["CD2_2"] = (_cast_to_python_float(self.dvdy), "CD2_2 = dvdy") return header @staticmethod From 3a6e4e098b64e539122addab03be01cfdcfbcde3 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 08:39:43 -0500 Subject: [PATCH 06/37] fix: ensure bounds deltax is numeric scalar --- jax_galsim/bounds.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index a02358b3..635d60d8 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -495,6 +495,20 @@ def __hash__(self): ) +def _cast_to_static_numeric_scalar(x, msg): + if isinstance(x, (int, float, np.integer, np.floating)): + return x + + if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)): + if x.ndim == 0: + return x.item() + + if x.ndim == 1 and x.shape[0] == 1: + return x[0].item() + + raise RuntimeError(msg) + + @implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class BoundsI(Bounds): @@ -508,14 +522,12 @@ def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - if not ( - isinstance(self.deltax, (int, float, np.integer, np.floating)) - and isinstance(self.deltay, (int, float, np.integer, np.floating)) - ): - raise RuntimeError( - "Jax-GalSim BoundsI instances must have a fixed, static width! " - f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." - ) + msg = ( + "Jax-GalSim BoundsI instances must have a fixed, static width! " + f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." + ) + self.deltax = _cast_to_static_numeric_scalar(self.deltax, msg) + self.deltay = _cast_to_static_numeric_scalar(self.deltay, msg) self.deltax = cast_to_float(self.deltax) self.deltay = cast_to_float(self.deltay) From cb1387f0bbd4a2d8967f8904ad01348792dac998 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 08:48:16 -0500 Subject: [PATCH 07/37] fix: handle photon array sizes as well --- jax_galsim/bounds.py | 22 +++------------------- jax_galsim/core/utils.py | 15 +++++++++++++++ jax_galsim/photon_array.py | 10 +++++++++- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 635d60d8..e1e9f866 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -7,6 +7,7 @@ from jax_galsim.core.utils import ( cast_to_float, cast_to_int, + cast_to_static_numeric_scalar, check_is_int_then_cast, ensure_hashable, has_tracers, @@ -495,20 +496,6 @@ def __hash__(self): ) -def _cast_to_static_numeric_scalar(x, msg): - if isinstance(x, (int, float, np.integer, np.floating)): - return x - - if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)): - if x.ndim == 0: - return x.item() - - if x.ndim == 1 and x.shape[0] == 1: - return x[0].item() - - raise RuntimeError(msg) - - @implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class BoundsI(Bounds): @@ -526,11 +513,8 @@ def __init__(self, *args, **kwargs): "Jax-GalSim BoundsI instances must have a fixed, static width! " f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." ) - self.deltax = _cast_to_static_numeric_scalar(self.deltax, msg) - self.deltay = _cast_to_static_numeric_scalar(self.deltay, msg) - - self.deltax = cast_to_float(self.deltax) - self.deltay = cast_to_float(self.deltay) + self.deltax = cast_to_float(cast_to_static_numeric_scalar(self.deltax, msg=msg)) + self.deltay = cast_to_float(cast_to_static_numeric_scalar(self.deltay, msg=msg)) if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): raise TypeError("BoundsI must be initialized with integer values") self.deltax = cast_to_int(self.deltax) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 7fd99ccc..71c8873b 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -51,6 +51,21 @@ def has_tracers(x): return False +def cast_to_static_numeric_scalar(x, msg=None): + if isinstance(x, (int, float, np.integer, np.floating)): + return x + + if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)): + if x.ndim == 0: + return x.item() + + if x.ndim == 1 and x.shape[0] == 1: + return x[0].item() + + msg = msg or f"Cannot convert input {x!r} to a static, numeric scalar." + raise RuntimeError(msg) + + def _cast_to_type(x, typ, accept_strings=False): if isinstance(x, (int, float, np.integer, np.floating)) or ( accept_strings and isinstance(x, str) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 8ecd8ae9..ed11b566 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -9,6 +9,7 @@ from jax_galsim.core.utils import ( cast_numpy_array_to_native_byte_order, cast_to_int, + cast_to_static_numeric_scalar, implements, ) from jax_galsim.errors import ( @@ -91,7 +92,14 @@ def __init__( _nokeep=None, ): self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N - self._Ntot = cast_to_int(self._Ntot) + self._Ntot = cast_to_int( + cast_to_static_numeric_scalar( + self._Ntot, + msg=( + f"JAX_GalSim photon arrays must have static sizes., Got {self._Ntot!r}." + ), + ) + ) if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None and ( self._Ntot > _JAX_GALSIM_PHOTON_ARRAY_SIZE From 2ea2e88cf7ae6308120775698d69f851c8f902de Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 08:50:54 -0500 Subject: [PATCH 08/37] refactor: simpler code --- jax_galsim/wcs.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 8208da99..24a14b1d 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -1,12 +1,16 @@ import galsim as _galsim -import jax import jax.numpy as jnp import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.angle import AngleUnit, arcsec, radians from jax_galsim.celestial import CelestialCoord -from jax_galsim.core.utils import ensure_hashable, implements +from jax_galsim.core.utils import ( + cast_to_float, + cast_to_static_numeric_scalar, + ensure_hashable, + implements, +) from jax_galsim.errors import GalSimValueError from jax_galsim.gsobject import GSObject from jax_galsim.position import Position, PositionD, PositionI @@ -15,17 +19,7 @@ def _cast_to_python_float(x): - """Cast the input to a python float. Works on python int/floats - and jax/numpy arrays. Will raise an error for arrays with more than one value. - """ - if isinstance(x, (int, float, np.integer, np.floating)): - return float(x) - elif isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)) and ( - x.ndim == 0 or (x.ndim == 1 and x.shape[0] == 1) - ): - return float(x.item()) - else: - raise ValueError(f"Cannot convert object {x!r} to a python float!") + return cast_to_float(cast_to_static_numeric_scalar(x)) # We inherit from the reference BaseWCS and only redefine the methods that From 19fa5273151af2354c5065dc731726910f999135 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 08:57:56 -0500 Subject: [PATCH 09/37] fix: accept any array with one element --- jax_galsim/core/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 71c8873b..5cf52204 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -59,8 +59,8 @@ def cast_to_static_numeric_scalar(x, msg=None): if x.ndim == 0: return x.item() - if x.ndim == 1 and x.shape[0] == 1: - return x[0].item() + if all(sv for sv in x.shape == 1): + return x.ravel()[0].item() msg = msg or f"Cannot convert input {x!r} to a static, numeric scalar." raise RuntimeError(msg) From edad337a5535a84958931126ca08144beb5c3ed9 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 09:05:23 -0500 Subject: [PATCH 10/37] fix: simpler --- jax_galsim/bounds.py | 9 ++------- jax_galsim/core/utils.py | 15 --------------- jax_galsim/photon_array.py | 10 +--------- jax_galsim/wcs.py | 19 +++++++++++++++++-- 4 files changed, 20 insertions(+), 33 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index e1e9f866..7de6b8d2 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -7,7 +7,6 @@ from jax_galsim.core.utils import ( cast_to_float, cast_to_int, - cast_to_static_numeric_scalar, check_is_int_then_cast, ensure_hashable, has_tracers, @@ -509,12 +508,8 @@ def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) - msg = ( - "Jax-GalSim BoundsI instances must have a fixed, static width! " - f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." - ) - self.deltax = cast_to_float(cast_to_static_numeric_scalar(self.deltax, msg=msg)) - self.deltay = cast_to_float(cast_to_static_numeric_scalar(self.deltay, msg=msg)) + self.deltax = cast_to_float(self.deltax) + self.deltay = cast_to_float(self.deltay) if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): raise TypeError("BoundsI must be initialized with integer values") self.deltax = cast_to_int(self.deltax) diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 5cf52204..7fd99ccc 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -51,21 +51,6 @@ def has_tracers(x): return False -def cast_to_static_numeric_scalar(x, msg=None): - if isinstance(x, (int, float, np.integer, np.floating)): - return x - - if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)): - if x.ndim == 0: - return x.item() - - if all(sv for sv in x.shape == 1): - return x.ravel()[0].item() - - msg = msg or f"Cannot convert input {x!r} to a static, numeric scalar." - raise RuntimeError(msg) - - def _cast_to_type(x, typ, accept_strings=False): if isinstance(x, (int, float, np.integer, np.floating)) or ( accept_strings and isinstance(x, str) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index ed11b566..8ecd8ae9 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -9,7 +9,6 @@ from jax_galsim.core.utils import ( cast_numpy_array_to_native_byte_order, cast_to_int, - cast_to_static_numeric_scalar, implements, ) from jax_galsim.errors import ( @@ -92,14 +91,7 @@ def __init__( _nokeep=None, ): self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N - self._Ntot = cast_to_int( - cast_to_static_numeric_scalar( - self._Ntot, - msg=( - f"JAX_GalSim photon arrays must have static sizes., Got {self._Ntot!r}." - ), - ) - ) + self._Ntot = cast_to_int(self._Ntot) if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None and ( self._Ntot > _JAX_GALSIM_PHOTON_ARRAY_SIZE diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 24a14b1d..d6ea6ea2 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -1,4 +1,5 @@ import galsim as _galsim +import jax import jax.numpy as jnp import numpy as np from jax.tree_util import register_pytree_node_class @@ -7,7 +8,6 @@ from jax_galsim.celestial import CelestialCoord from jax_galsim.core.utils import ( cast_to_float, - cast_to_static_numeric_scalar, ensure_hashable, implements, ) @@ -18,8 +18,23 @@ from jax_galsim.transform import _Transform +def _cast_to_static_numeric_scalar(x, msg=None): + if isinstance(x, (int, float, np.integer, np.floating)): + return x + + if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)): + if x.ndim == 0: + return x.item() + + if all(sv for sv in x.shape == 1): + return x.ravel()[0].item() + + msg = msg or f"Cannot convert input {x!r} to a static, numeric scalar." + raise RuntimeError(msg) + + def _cast_to_python_float(x): - return cast_to_float(cast_to_static_numeric_scalar(x)) + return cast_to_float(_cast_to_static_numeric_scalar(x)) # We inherit from the reference BaseWCS and only redefine the methods that From f7aa2a68d5fd6724136c10227ad68e8c2626f396 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 09:08:05 -0500 Subject: [PATCH 11/37] doc: add comment --- jax_galsim/wcs.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index d6ea6ea2..6dcec13b 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -18,6 +18,9 @@ from jax_galsim.transform import _Transform +# this function casts input values to python numeric values +# this kind of casting is only done for writing FITS headers +# and should never be done anywhere else in the code base def _cast_to_static_numeric_scalar(x, msg=None): if isinstance(x, (int, float, np.integer, np.floating)): return x From caecf66643786949ae470eedea214f6c86ad8c8b Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 09:18:30 -0500 Subject: [PATCH 12/37] fix: start to remove has_tracers func --- jax_galsim/bounds.py | 6 ++++-- jax_galsim/moffat.py | 8 ++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 7de6b8d2..aa1f9285 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -9,7 +9,6 @@ cast_to_int, check_is_int_then_cast, ensure_hashable, - has_tracers, implements, ) from jax_galsim.position import Position, PositionD, PositionI @@ -515,7 +514,10 @@ def __init__(self, *args, **kwargs): self.deltax = cast_to_int(self.deltax) self.deltay = cast_to_int(self.deltay) - if has_tracers(self._xmin) or has_tracers(self._ymin): + if not ( + isinstance(self._xmin, (int, float, np.floating, np.integer)) + and isinstance(self._ymin, (int, float, np.floating, np.integer)) + ): self._isstatic = False # validate inputs are ints diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 932c48b8..13ea2a37 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -13,7 +13,6 @@ from jax_galsim.core.math import safe_sqrt from jax_galsim.core.utils import ( ensure_hashable, - has_tracers, implements, ) from jax_galsim.gsobject import GSObject @@ -59,13 +58,10 @@ def __init__( # let define beta_thr a threshold to trigger the truncature self._beta_thr = 1.1 - if has_tracers(trunc) or ( - isinstance(trunc, (np.ndarray, float, jnp.ndarray, int)) - and np.any(trunc != 0) - ): + if (not isinstance(trunc, (float, int))) or trunc != 0: raise ValueError( "JAX-GalSim does not support truncated Moffat profiles " - f"(got trunc={repr(trunc)}, always pass the constant 0.0)!" + f"(got trunc={trunc!r}, always pass the constant 0.0)!" ) if isinstance(beta, (float, int)): From d5ee29205a69137e03b44a3c25b2c13391a8ded0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 09:26:03 -0500 Subject: [PATCH 13/37] fix: accept array inputs for boundsI static --- jax_galsim/bounds.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index aa1f9285..bc13e9d8 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -515,8 +515,14 @@ def __init__(self, *args, **kwargs): self.deltay = cast_to_int(self.deltay) if not ( - isinstance(self._xmin, (int, float, np.floating, np.integer)) - and isinstance(self._ymin, (int, float, np.floating, np.integer)) + isinstance( + self._xmin, + (int, float, np.floating, np.integer, np.ndarray, jnp.ndarray), + ) + and isinstance( + self._ymin, + (int, float, np.floating, np.integer, np.ndarray, jnp.ndarray), + ) ): self._isstatic = False @@ -534,7 +540,7 @@ def __init__(self, *args, **kwargs): if force_static and not self._isstatic: raise RuntimeError( "BoundsI initialized with non-static " - f"data (xmin,ymin = {self._xmin},{self._yminb}) " + f"data (xmin,ymin = {self._xmin!r},{self._ymin!r}) " "when static data was explicitly requested." ) From bb313a4b4138280f21f4e472c125f3be1dbab4bf Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 18 May 2026 09:28:33 -0500 Subject: [PATCH 14/37] Apply suggestion from @beckermr --- jax_galsim/photon_array.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 8ecd8ae9..2c8ff23f 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -93,14 +93,6 @@ def __init__( self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N self._Ntot = cast_to_int(self._Ntot) - if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None and ( - self._Ntot > _JAX_GALSIM_PHOTON_ARRAY_SIZE - ): - raise ValueError( - f"The given photon array size {self._Ntot} is larger than " - f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." - ) - if _nokeep is not None: self._nokeep = _nokeep else: From da4db037afdebac2ecd9ede2d985b2e9ac73e787 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 10:53:41 -0500 Subject: [PATCH 15/37] fix: remove more has_tracers --- jax_galsim/bounds.py | 23 +++++++++++++++-------- jax_galsim/image.py | 18 +++++------------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index bc13e9d8..bb4d9071 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -499,6 +499,20 @@ def __hash__(self): class BoundsI(Bounds): _pos_class = PositionI + @staticmethod + def value_is_static(val): + """Return ``True`` if ``value`` is a static constant, ``False`` otherwise. + + This static method is used to test ``xmin`` and ``ymin`` to detect if a ``BoundsI`` + instance has a constant offset. The method is attached to the ``BoundsI`` instance + so that other classes can use it to detect static inputs to ``BoundsI`` classes + consistently. + """ + return isinstance( + val, + (int, float, np.floating, np.integer), + ) + def __init__(self, *args, **kwargs): # initial setting to let stuff pass through freely self._isstatic = True @@ -515,14 +529,7 @@ def __init__(self, *args, **kwargs): self.deltay = cast_to_int(self.deltay) if not ( - isinstance( - self._xmin, - (int, float, np.floating, np.integer, np.ndarray, jnp.ndarray), - ) - and isinstance( - self._ymin, - (int, float, np.floating, np.integer, np.ndarray, jnp.ndarray), - ) + BoundsI.value_is_static(self._xmin) and BoundsI.value_is_static(self._ymin) ): self._isstatic = False diff --git a/jax_galsim/image.py b/jax_galsim/image.py index d7eac192..43702acc 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -198,12 +198,7 @@ def __init__(self, *args, **kwargs): ncol = int(ncol) nrow = int(nrow) self._array = self._make_empty(shape=(nrow, ncol), dtype=self._dtype) - if not has_tracers(xmin) and not has_tracers(ymin): - self._bounds = BoundsI( - xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True - ) - else: - self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) + self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) if init_value: self._array = self._array.at[...].add(init_value) elif bounds is not None: @@ -216,12 +211,7 @@ def __init__(self, *args, **kwargs): elif array is not None: self._array = array.view() nrow, ncol = array.shape - if not has_tracers(xmin) and not has_tracers(ymin): - self._bounds = BoundsI( - xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True - ) - else: - self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) + self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) if init_value is not None: raise _galsim.GalSimIncompatibleValuesError( "Cannot specify init_value with array", @@ -333,7 +323,9 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds - if self.bounds.isDefined() and not has_tracers(self.array): + if self.bounds.isDefined() and isinstance( + self.array, (np.ndarray, jnp.ndarray, jax.Array) + ): s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) s += ", wcs=%r" % self.wcs if self.isconst: From 7bda500b00c47053a859442c32f63136a6c8291e Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 11:05:18 -0500 Subject: [PATCH 16/37] fix: more cleanup of has_tracers --- jax_galsim/image.py | 34 +++++++--------------------------- tests/jax/test_api.py | 1 + 2 files changed, 8 insertions(+), 27 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 43702acc..2c78b5e6 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -280,22 +280,14 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): b = kwargs.pop("bounds") if not isinstance(b, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if ( - check_bounds - and b.isDefined() - and not has_tracers(b.xmin) - and not has_tracers(b.ymin) - and not has_tracers(b.xmax) - and not has_tracers(b.ymax) - ): - # We need to disable this when jitting - if b.xmax - b.xmin + 1 != array.shape[1]: + if check_bounds and b.isDefined(): + if b.deltax != array.shape[1]: raise _galsim.GalSimIncompatibleValuesError( "Shape of array is inconsistent with provided bounds", array=array, bounds=b, ) - if b.ymax - b.ymin + 1 != array.shape[0]: + if b.deltay != array.shape[0]: raise _galsim.GalSimIncompatibleValuesError( "Shape of array is inconsistent with provided bounds", array=array, @@ -588,14 +580,8 @@ def subImage(self, bounds): "Attempt to access subImage of undefined image" ) if ( - not has_tracers(self.bounds.xmin) - and not has_tracers(self.bounds.xmax) - and not has_tracers(self.bounds.ymin) - and not has_tracers(self.bounds.ymax) - and not has_tracers(bounds.xmin) - and not has_tracers(bounds.xmax) - and not has_tracers(bounds.ymin) - and not has_tracers(bounds.ymax) + self.bounds.isStatic() + and bounds.isStatic() and not self.bounds.includes(bounds) ): raise _galsim.GalSimBoundsError( @@ -633,14 +619,8 @@ def setSubImage(self, bounds, rhs): "Attempt to access values of an undefined image" ) if ( - not has_tracers(self.bounds.xmin) - and not has_tracers(self.bounds.xmax) - and not has_tracers(self.bounds.ymin) - and not has_tracers(self.bounds.ymax) - and not has_tracers(bounds.xmin) - and not has_tracers(bounds.xmax) - and not has_tracers(bounds.ymin) - and not has_tracers(bounds.ymax) + self.bounds.isStatic() + and bounds.isStatic() and not self.bounds.includes(bounds) ): raise _galsim.GalSimBoundsError( diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index e76b081c..26f20585 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -360,6 +360,7 @@ def _reg_fun(p): "xmax", "ymax", "isStatic", + "value_is_static", ]: continue From d54e844558fb6e28fe19e0586f2e56f92db24da2 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 11:22:38 -0500 Subject: [PATCH 17/37] fix: remove use of has_tracers --- jax_galsim/angle.py | 7 ++-- jax_galsim/core/utils.py | 9 ----- jax_galsim/image.py | 77 +++++++++++++++++++++++++--------------- 3 files changed, 51 insertions(+), 42 deletions(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index 9569ebd9..88b749a2 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -27,7 +27,6 @@ from jax_galsim.core.utils import ( cast_to_float, ensure_hashable, - has_tracers, implements, ) @@ -199,7 +198,7 @@ def __sub__(self, other): return _Angle(self._rad - other._rad) def __mul__(self, other): - if not (has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES)): + if isinstance(other, (Angle, AngleUnit)): raise TypeError( "Cannot multiply Angle by %s of type %s" % (other, type(other)) ) @@ -210,11 +209,11 @@ def __mul__(self, other): def __div__(self, other): if isinstance(other, AngleUnit): return self._rad / other.value - elif has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES): + elif not isinstance(other, Angle): return _Angle(self._rad / other) else: raise TypeError( - "Cannot divide Angle by %s of type %s" % (other, type(other)) + "Cannot multiply Angle by %s of type %s" % (other, type(other)) ) __truediv__ = __div__ diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 7fd99ccc..839d0658 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -7,7 +7,6 @@ import jax import jax.numpy as jnp import numpy as np -from jax.tree_util import tree_flatten def check_is_int_then_cast(val, msg): @@ -43,14 +42,6 @@ def cast_numpy_array_to_native_byte_order(arr): return arr.astype(arr.dtype.newbyteorder("=")) -def has_tracers(x): - """Return True if the input item is a JAX tracer or object, False otherwise.""" - for item in tree_flatten(x)[0]: - if isinstance(item, jax.core.Tracer) or type(item) is object: - return True - return False - - def _cast_to_type(x, typ, accept_strings=False): if isinstance(x, (int, float, np.integer, np.floating)) or ( accept_strings and isinstance(x, str) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 2c78b5e6..d94fd37a 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1,3 +1,4 @@ +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -8,7 +9,6 @@ from jax_galsim.core.utils import ( cast_numpy_array_to_native_byte_order, ensure_hashable, - has_tracers, implements, ) from jax_galsim.errors import GalSimImmutableError @@ -103,12 +103,10 @@ def __init__(self, *args, **kwargs): else: if "array" in kwargs: array = kwargs.pop("array") - if has_tracers(array) or isinstance(array, jnp.ndarray): - pass - elif isinstance(array, np.ndarray): + if isinstance(array, np.ndarray): array = jnp.array(cast_numpy_array_to_native_byte_order(array)) else: - raise TypeError("Unable to parse %s as an array." % array) + array = jnp.asarray(array) array, xmin, ymin = self._get_xmin_ymin( array, kwargs, check_bounds=_check_bounds @@ -716,37 +714,58 @@ def __setitem__(self, *args): def wrap(self, bounds, hermitian=False): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + def _raise_if_nonzero(bnds, x_or_y, msg): + if x_or_y == "x": + if bnds.isStatic() and bnds.xmin != 0: + raise _galsim.GalSimIncompatibleValuesError( + msg, + hermitian=hermitian, + bounds=bnds, + ) + else: + bnds.xmin = equinox.error_if( + bnds.xmin, + jnp.any(bnds.xmin != 0), + msg, + ) + else: + if bnds.isStatic() and bnds.ymin != 0: + raise _galsim.GalSimIncompatibleValuesError( + msg, + hermitian=hermitian, + bounds=bnds, + ) + else: + bnds.ymin = equinox.error_if( + bnds.ymin, + jnp.any(bnds.ymin != 0), + msg, + ) + + return bnds + # Get this at the start to check for invalid bounds and raise the exception before # possibly writing data past the edge of the image. if not hermitian: return self._wrap(bounds, False, False, None) elif hermitian == "x": - if not has_tracers(self.bounds.xmin) and self.bounds.xmin != 0: - raise _galsim.GalSimIncompatibleValuesError( - "hermitian == 'x' requires self.bounds.xmin == 0", - hermitian=hermitian, - bounds=self.bounds, - ) - if not has_tracers(bounds.xmin) and bounds.xmin != 0: - raise _galsim.GalSimIncompatibleValuesError( - "hermitian == 'x' requires bounds.xmin == 0", - hermitian=hermitian, - bounds=bounds, - ) + self._bounds = _raise_if_nonzero( + self.bounds, "x", "hermitian == 'x' requires self.bounds.xmin == 0" + ) + bounds = _raise_if_nonzero( + bounds, "x", "hermitian == 'x' requires self.bounds.xmin == 0" + ) + return self._wrap(bounds, True, False, 2 * bounds.xmax) elif hermitian == "y": - if not has_tracers(self.bounds.ymin) and self.bounds.ymin != 0: - raise _galsim.GalSimIncompatibleValuesError( - "hermitian == 'y' requires self.bounds.ymin == 0", - hermitian=hermitian, - bounds=self.bounds, - ) - if not has_tracers(bounds.ymin) and bounds.ymin != 0: - raise _galsim.GalSimIncompatibleValuesError( - "hermitian == 'y' requires bounds.ymin == 0", - hermitian=hermitian, - bounds=bounds, - ) + self._bounds = _raise_if_nonzero( + self.bounds, "y", "hermitian == 'y' requires self.bounds.ymin == 0" + ) + bounds = _raise_if_nonzero( + bounds, "y", "hermitian == 'y' requires self.bounds.ymin == 0" + ) + return self._wrap(bounds, False, True, 2 * bounds.ymax) else: raise _galsim.GalSimValueError( From c40a6aa47896613031e889be5da65e7a71860ca0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 11:25:46 -0500 Subject: [PATCH 18/37] fix: remove unneeded api --- jax_galsim/bounds.py | 23 ++++++++--------------- tests/jax/test_api.py | 1 - 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index bb4d9071..fcbaf779 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -499,20 +499,6 @@ def __hash__(self): class BoundsI(Bounds): _pos_class = PositionI - @staticmethod - def value_is_static(val): - """Return ``True`` if ``value`` is a static constant, ``False`` otherwise. - - This static method is used to test ``xmin`` and ``ymin`` to detect if a ``BoundsI`` - instance has a constant offset. The method is attached to the ``BoundsI`` instance - so that other classes can use it to detect static inputs to ``BoundsI`` classes - consistently. - """ - return isinstance( - val, - (int, float, np.floating, np.integer), - ) - def __init__(self, *args, **kwargs): # initial setting to let stuff pass through freely self._isstatic = True @@ -529,7 +515,14 @@ def __init__(self, *args, **kwargs): self.deltay = cast_to_int(self.deltay) if not ( - BoundsI.value_is_static(self._xmin) and BoundsI.value_is_static(self._ymin) + isinstance( + self._xmin, + (int, float, np.floating, np.integer), + ) + and isinstance( + self._ymin, + (int, float, np.floating, np.integer), + ) ): self._isstatic = False diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index 26f20585..e76b081c 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -360,7 +360,6 @@ def _reg_fun(p): "xmax", "ymax", "isStatic", - "value_is_static", ]: continue From 9e45b846c6f2fbe1b88871e9ed51796aabc9f73c Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 11:28:22 -0500 Subject: [PATCH 19/37] fix: remove extra keyword we no longer need --- jax_galsim/bounds.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index fcbaf779..69ab8746 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -21,10 +21,7 @@ Further, the JAX implementation adds a new method, ``isStatic`` to the ``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance has been instantiated with static, known values, ``isStatic()`` will -return ``True``. You can indicate to JAX-GalSim that a ``BoundsI`` -instance should be static via initializing it with the ``static`` -keyword set to the ``True``. If the object detects that it is being -initialized with non-static data, an error will be raised. +return ``True``. ``BoundsI`` objects in JAX-Galsim support an additional initialization call ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. In this case, @@ -362,10 +359,8 @@ def from_galsim(cls, galsim_bounds): """Create a jax_galsim `BoundsD/I` from a `galsim.BoundsD/I` object.""" if isinstance(galsim_bounds, _galsim.BoundsD): _cls = BoundsD - kwargs = {} elif isinstance(galsim_bounds, _galsim.BoundsI): _cls = BoundsI - kwargs = {"static": True} else: raise TypeError( "galsim_bounds must be either a %s or a %s" @@ -377,7 +372,6 @@ def from_galsim(cls, galsim_bounds): galsim_bounds.xmax, galsim_bounds.ymin, galsim_bounds.ymax, - **kwargs, ) else: return _cls() @@ -503,8 +497,6 @@ def __init__(self, *args, **kwargs): # initial setting to let stuff pass through freely self._isstatic = True - force_static = kwargs.pop("static", False) - self._parse_args(*args, **kwargs) self.deltax = cast_to_float(self.deltax) @@ -537,13 +529,6 @@ def __init__(self, *args, **kwargs): if self.deltax < 1 and self.deltay < 1: self._isdefined = False - if force_static and not self._isstatic: - raise RuntimeError( - "BoundsI initialized with non-static " - f"data (xmin,ymin = {self._xmin!r},{self._ymin!r}) " - "when static data was explicitly requested." - ) - def _check_scalar(self, x, name): try: if ( From 15b4e224ef02f30d61c064bc0a9657ac07c11184 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 18 May 2026 11:32:02 -0500 Subject: [PATCH 20/37] Apply suggestion from @beckermr --- jax_galsim/angle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index 88b749a2..fd3af1c3 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -213,7 +213,7 @@ def __div__(self, other): return _Angle(self._rad / other) else: raise TypeError( - "Cannot multiply Angle by %s of type %s" % (other, type(other)) + "Cannot divide Angle by %s of type %s" % (other, type(other)) ) __truediv__ = __div__ From 402eb30886649fa050d8db844fae796c6d2fef07 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 11:44:22 -0500 Subject: [PATCH 21/37] fix: bug in hermitian detecting checks --- jax_galsim/image.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index d94fd37a..8fb74dcf 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -105,8 +105,10 @@ def __init__(self, *args, **kwargs): array = kwargs.pop("array") if isinstance(array, np.ndarray): array = jnp.array(cast_numpy_array_to_native_byte_order(array)) + elif isinstance(array, jnp.ndarray): + pass else: - array = jnp.asarray(array) + raise TypeError(f"Unable to parse {array!r} as an array.") array, xmin, ymin = self._get_xmin_ymin( array, kwargs, check_bounds=_check_bounds @@ -717,12 +719,13 @@ def wrap(self, bounds, hermitian=False): def _raise_if_nonzero(bnds, x_or_y, msg): if x_or_y == "x": - if bnds.isStatic() and bnds.xmin != 0: - raise _galsim.GalSimIncompatibleValuesError( - msg, - hermitian=hermitian, - bounds=bnds, - ) + if bnds.isStatic(): + if bnds.xmin != 0: + raise _galsim.GalSimIncompatibleValuesError( + msg, + hermitian=hermitian, + bounds=bnds, + ) else: bnds.xmin = equinox.error_if( bnds.xmin, @@ -730,12 +733,13 @@ def _raise_if_nonzero(bnds, x_or_y, msg): msg, ) else: - if bnds.isStatic() and bnds.ymin != 0: - raise _galsim.GalSimIncompatibleValuesError( - msg, - hermitian=hermitian, - bounds=bnds, - ) + if bnds.isStatic(): + if bnds.ymin != 0: + raise _galsim.GalSimIncompatibleValuesError( + msg, + hermitian=hermitian, + bounds=bnds, + ) else: bnds.ymin = equinox.error_if( bnds.ymin, From eceafb923752e4e3e3c17cf653966df0ab0dcdc0 Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 12:25:34 -0500 Subject: [PATCH 22/37] doc: add some docs --- docs/sharp-bits.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 031fa843..5286ee09 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -61,6 +61,30 @@ does not affect the original. # JAX-GalSim — real_part is a copy real_part = complex_image.real # independent array +Scalar Types, Array Types, and Casting +-------------------------------------- + +With the use of JAX, there are now many possible types for numeric data. These include + +- **Python scalars**: Things with types that are ``float``, ``int``, or ``complex``. +- **NumPy scalars**: Things with types that are subclasses are ``np.floating`` and ``np.integer``. +- **Numpy array scalars**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) == 0``. +- **Numpy arrays**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) > 0``. +- **JAX array scalars**: Things with a type that is ``jnp.ndarray`` and has ``np.ndim(...) == 0``. +- **JAX arrays**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) > 0``. + +**JAX does not have pure scalar types like NumPY. JAX uses array scalars for those instead.** + +JAX-GalSim uses the following rules when handling data types and casting. + +- If the item is a Python numeric type (i.e., ``int`` or ``float``) or a + NumPy scalar type (i.e., ``isinstance(x, np.number)``, ``isinstance(x, np.integer)``, etc.), + convert it to a Python type of the appropriate kind. +- For all other array-like types, cast to the correct type via ``jax.numpy.astype(x, ...)``. +- For putting data into FITS headers only, JAX-GalSim converts of NumPy/JAX arrays to Python + numeric types as long as there is one element in the array (i.e., it is a NumPy scalar type, + an array scalar, or a 1D array with one element). + Random Number Generation ------------------------ From 9e668fa22c8b788af630e8d5681bd256b06447ee Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 12:31:26 -0500 Subject: [PATCH 23/37] doc: finish docs --- docs/sharp-bits.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 5286ee09..23c51ae0 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -85,6 +85,10 @@ JAX-GalSim uses the following rules when handling data types and casting. numeric types as long as there is one element in the array (i.e., it is a NumPy scalar type, an array scalar, or a 1D array with one element). +These rules allow JAX-GalSim to transparently handle JAX's tracing operations, but can result in +the code raising generic ``Exception`` instances instead of more specific ``GalSim`` exceptions in +some cases. + Random Number Generation ------------------------ From c2c609702a2733ef9daa967a2c42d90694e6f9ed Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 18 May 2026 16:37:46 -0500 Subject: [PATCH 24/37] Apply suggestion from @beckermr --- docs/sharp-bits.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 23c51ae0..9dd024aa 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -68,8 +68,8 @@ With the use of JAX, there are now many possible types for numeric data. These i - **Python scalars**: Things with types that are ``float``, ``int``, or ``complex``. - **NumPy scalars**: Things with types that are subclasses are ``np.floating`` and ``np.integer``. -- **Numpy array scalars**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) == 0``. -- **Numpy arrays**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) > 0``. +- **NumPy array scalars**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) == 0``. +- **NumPy arrays**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) > 0``. - **JAX array scalars**: Things with a type that is ``jnp.ndarray`` and has ``np.ndim(...) == 0``. - **JAX arrays**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) > 0``. From d28f0b9eb1298b0f4f159614f80f1138b7c86842 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 18 May 2026 16:59:09 -0500 Subject: [PATCH 25/37] Apply suggestion from @beckermr --- jax_galsim/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 8fb74dcf..583c1ef1 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -316,7 +316,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds if self.bounds.isDefined() and isinstance( - self.array, (np.ndarray, jnp.ndarray, jax.Array) + self.array, (np.ndarray, jax.Array) ): s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) s += ", wcs=%r" % self.wcs From 250560fae662f1594576e2c0e8b90b0fe25fe651 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 18 May 2026 17:01:17 -0500 Subject: [PATCH 26/37] Apply suggestion from @beckermr --- docs/sharp-bits.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 9dd024aa..06c5bb62 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -70,8 +70,8 @@ With the use of JAX, there are now many possible types for numeric data. These i - **NumPy scalars**: Things with types that are subclasses are ``np.floating`` and ``np.integer``. - **NumPy array scalars**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) == 0``. - **NumPy arrays**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) > 0``. -- **JAX array scalars**: Things with a type that is ``jnp.ndarray`` and has ``np.ndim(...) == 0``. -- **JAX arrays**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) > 0``. +- **JAX array scalars**: Things with a type that is ``jax.numpy.ndarray`` and has ``jax.numpy.ndim(...) == 0``. +- **JAX arrays**: Things with a type that is ``jax.numpy.ndarray`` and has ``jax.numpy.ndim(...) > 0``. **JAX does not have pure scalar types like NumPY. JAX uses array scalars for those instead.** From e686a8ab761498c74a2a4ac3c849129e459c9138 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Mon, 18 May 2026 17:05:27 -0500 Subject: [PATCH 27/37] Apply suggestion from @beckermr --- jax_galsim/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 583c1ef1..8fb74dcf 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -316,7 +316,7 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds if self.bounds.isDefined() and isinstance( - self.array, (np.ndarray, jax.Array) + self.array, (np.ndarray, jnp.ndarray, jax.Array) ): s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) s += ", wcs=%r" % self.wcs From db8310994022a49a8a89b2ae9283caf9323feacf Mon Sep 17 00:00:00 2001 From: beckermr Date: Mon, 18 May 2026 17:18:18 -0500 Subject: [PATCH 28/37] fix: no need to ensure it is hashable here --- jax_galsim/image.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 8fb74dcf..7c78a178 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -318,7 +318,12 @@ def __repr__(self): if self.bounds.isDefined() and isinstance( self.array, (np.ndarray, jnp.ndarray, jax.Array) ): - s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) + try: + np.array(self.array) + except Exception: + pass + else: + s += ", array=\n%r" % (np.array(self.array),) s += ", wcs=%r" % self.wcs if self.isconst: s += ", make_const=True" From 0677f1af49109dd84608787d08706bbfc4c4dc09 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 06:56:17 -0500 Subject: [PATCH 29/37] Apply suggestion from @ismael-mendoza Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- docs/sharp-bits.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 06c5bb62..7bb28644 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -73,7 +73,7 @@ With the use of JAX, there are now many possible types for numeric data. These i - **JAX array scalars**: Things with a type that is ``jax.numpy.ndarray`` and has ``jax.numpy.ndim(...) == 0``. - **JAX arrays**: Things with a type that is ``jax.numpy.ndarray`` and has ``jax.numpy.ndim(...) > 0``. -**JAX does not have pure scalar types like NumPY. JAX uses array scalars for those instead.** +**JAX does not have pure scalar types like NumPy. JAX uses array scalars for those instead.** JAX-GalSim uses the following rules when handling data types and casting. From b848ae0b58b01f06b98a2d50390b461ff552965c Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 06:56:29 -0500 Subject: [PATCH 30/37] Apply suggestion from @ismael-mendoza Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- docs/sharp-bits.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 7bb28644..b9087e69 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -67,7 +67,7 @@ Scalar Types, Array Types, and Casting With the use of JAX, there are now many possible types for numeric data. These include - **Python scalars**: Things with types that are ``float``, ``int``, or ``complex``. -- **NumPy scalars**: Things with types that are subclasses are ``np.floating`` and ``np.integer``. +- **NumPy scalars**: Things with types that are subclasses of ``np.floating`` and ``np.integer``. - **NumPy array scalars**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) == 0``. - **NumPy arrays**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) > 0``. - **JAX array scalars**: Things with a type that is ``jax.numpy.ndarray`` and has ``jax.numpy.ndim(...) == 0``. From be78240417bec93e0b80daf12ca2eb53bb24c627 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 06:56:51 -0500 Subject: [PATCH 31/37] Apply suggestion from @ismael-mendoza Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- docs/sharp-bits.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index b9087e69..9697fd3a 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -264,7 +264,7 @@ should not affect scientific conclusions. In the :doc:`api/index` you will find **🔪 JAX-GalSim - The Sharp Bits 🔪** blocks highlighting additional important caveats for specific classes and or methods. These could include things like: -- Some classes do not perform some of Galsim's test for correctness during initialization (e.g., :meth:`~jax_galsim.GSObject.drawImage`). +- Some classes do not perform some of Galsim's test for correctness during initialization (e.g., :meth:`~jax_galsim.InterpolatedImage`). - Certain profiles might not be auto-differentiable with respect to some of their parameters (e.g., :class:`~jax_galsim.Spergel`, :class:`~jax_galsim.Moffat`) - Limitations regarding what types of inputes are handled (e.g., :meth:`~jax_galsim.Image.calculate_fft` does not accept complex dtypes.) From f12c9f5dbc1666501d0fe5ac6ec77191bb754e9e Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 06:57:09 -0500 Subject: [PATCH 32/37] Apply suggestion from @ismael-mendoza Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- jax_galsim/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 7c78a178..3d139dc6 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -763,7 +763,7 @@ def _raise_if_nonzero(bnds, x_or_y, msg): self.bounds, "x", "hermitian == 'x' requires self.bounds.xmin == 0" ) bounds = _raise_if_nonzero( - bounds, "x", "hermitian == 'x' requires self.bounds.xmin == 0" + bounds, "x", "hermitian == 'x' requires bounds.xmin == 0" ) return self._wrap(bounds, True, False, 2 * bounds.xmax) From cacca74d66d1e8fb0552d43c2fda611ba9226039 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 06:57:22 -0500 Subject: [PATCH 33/37] Apply suggestion from @ismael-mendoza Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com> --- jax_galsim/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 3d139dc6..63e4da7a 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -772,7 +772,7 @@ def _raise_if_nonzero(bnds, x_or_y, msg): self.bounds, "y", "hermitian == 'y' requires self.bounds.ymin == 0" ) bounds = _raise_if_nonzero( - bounds, "y", "hermitian == 'y' requires self.bounds.ymin == 0" + bounds, "y", "hermitian == 'y' requires bounds.ymin == 0" ) return self._wrap(bounds, False, True, 2 * bounds.ymax) From f9f0613ae08fb87a913d8cbf3b0424d3523d1f62 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 07:06:46 -0500 Subject: [PATCH 34/37] Update descriptions of numeric data types in JAX Clarified descriptions of numeric data types in JAX. --- docs/sharp-bits.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 9697fd3a..ff28546e 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -66,12 +66,12 @@ Scalar Types, Array Types, and Casting With the use of JAX, there are now many possible types for numeric data. These include -- **Python scalars**: Things with types that are ``float``, ``int``, or ``complex``. -- **NumPy scalars**: Things with types that are subclasses of ``np.floating`` and ``np.integer``. -- **NumPy array scalars**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) == 0``. -- **NumPy arrays**: Things with a type that is ``np.ndarray`` and has ``np.ndim(...) > 0``. -- **JAX array scalars**: Things with a type that is ``jax.numpy.ndarray`` and has ``jax.numpy.ndim(...) == 0``. -- **JAX arrays**: Things with a type that is ``jax.numpy.ndarray`` and has ``jax.numpy.ndim(...) > 0``. +- **Python scalars**: Objects with types that are ``float``, ``int``, or ``complex``. +- **NumPy scalars**: Objects with types that are subclasses of ``np.floating``, ``np.integer``, etc. +- **NumPy array scalars**: Objects with a type that is ``np.ndarray`` and have ``np.ndim(...) == 0``. +- **NumPy arrays**: Objects with a type that is ``np.ndarray`` and have ``np.ndim(...) > 0``. +- **JAX array scalars**: Objects with a type that is ``jax.numpy.ndarray`` and have ``jax.numpy.ndim(...) == 0``. +- **JAX arrays**: Objects with a type that is ``jax.numpy.ndarray`` and have ``jax.numpy.ndim(...) > 0``. **JAX does not have pure scalar types like NumPy. JAX uses array scalars for those instead.** From b1329d7518a1e83c60c860f74d500ec61b5b0ea4 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 07:07:34 -0500 Subject: [PATCH 35/37] Fix condition to check if array shape is 1 --- jax_galsim/wcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 6dcec13b..ea3a9dbc 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -29,7 +29,7 @@ def _cast_to_static_numeric_scalar(x, msg=None): if x.ndim == 0: return x.item() - if all(sv for sv in x.shape == 1): + if all(sv == 1 for sv in x.shape): return x.ravel()[0].item() msg = msg or f"Cannot convert input {x!r} to a static, numeric scalar." From 3624566e47e166d01d129d7672c3ee58dbc6e2a8 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 21 May 2026 07:11:29 -0500 Subject: [PATCH 36/37] Implement size check for photon array Add error handling for photon array size exceeding limit --- jax_galsim/photon_array.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 2c8ff23f..1bbce816 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -93,6 +93,14 @@ def __init__( self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N self._Ntot = cast_to_int(self._Ntot) + if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None: + equinox.error_if( + jnp.array(N), + jnp.array(N > _JAX_GALSIM_PHOTON_ARRAY_SIZE), + f"The given photon array size {N} is larger than " + f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}.", + ) + if _nokeep is not None: self._nokeep = _nokeep else: From 1aeb448e4378a8dc641c7fbb227f809d9ab96e19 Mon Sep 17 00:00:00 2001 From: beckermr Date: Thu, 21 May 2026 07:13:36 -0500 Subject: [PATCH 37/37] fix: missed an import --- jax_galsim/photon_array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 1bbce816..e97ee2b7 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -1,5 +1,6 @@ from contextlib import contextmanager +import equinox import galsim as _galsim import jax import jax.numpy as jnp