diff --git a/docs/sharp-bits.rst b/docs/sharp-bits.rst index 7901176e..ff28546e 100644 --- a/docs/sharp-bits.rst +++ b/docs/sharp-bits.rst @@ -61,6 +61,34 @@ does not affect the original. # JAX-GalSim — real_part is a copy real_part = complex_image.real # independent array +Scalar Types, Array Types, and Casting +-------------------------------------- + +With the use of JAX, there are now many possible types for numeric data. These include + +- **Python scalars**: Objects with types that are ``float``, ``int``, or ``complex``. +- **NumPy scalars**: Objects with types that are subclasses of ``np.floating``, ``np.integer``, etc. +- **NumPy array scalars**: Objects with a type that is ``np.ndarray`` and have ``np.ndim(...) == 0``. +- **NumPy arrays**: Objects with a type that is ``np.ndarray`` and have ``np.ndim(...) > 0``. +- **JAX array scalars**: Objects with a type that is ``jax.numpy.ndarray`` and have ``jax.numpy.ndim(...) == 0``. +- **JAX arrays**: Objects with a type that is ``jax.numpy.ndarray`` and have ``jax.numpy.ndim(...) > 0``. + +**JAX does not have pure scalar types like NumPy. JAX uses array scalars for those instead.** + +JAX-GalSim uses the following rules when handling data types and casting. + +- If the item is a Python numeric type (i.e., ``int`` or ``float``) or a + NumPy scalar type (i.e., ``isinstance(x, np.number)``, ``isinstance(x, np.integer)``, etc.), + convert it to a Python type of the appropriate kind. +- For all other array-like types, cast to the correct type via ``jax.numpy.astype(x, ...)``. +- For putting data into FITS headers only, JAX-GalSim converts of NumPy/JAX arrays to Python + numeric types as long as there is one element in the array (i.e., it is a NumPy scalar type, + an array scalar, or a 1D array with one element). + +These rules allow JAX-GalSim to transparently handle JAX's tracing operations, but can result in +the code raising generic ``Exception`` instances instead of more specific ``GalSim`` exceptions in +some cases. + Random Number Generation ------------------------ @@ -163,9 +191,6 @@ profile parameters passed into a ``jit``-compiled function): def good(sigma): return jax.lax.cond(sigma > 1.0, lambda s: s * 2, lambda s: s, sigma) -JAX-GalSim uses an internal ``has_tracers()`` utility to detect tracing and -avoid problematic control flow in its own implementations. - Fixed output shapes ^^^^^^^^^^^^^^^^^^^ @@ -197,20 +222,9 @@ The ``__init__`` gotcha During ``jit`` tracing, JAX calls constructors with **tracer objects** rather than concrete Python numbers. Type checks like ``isinstance(sigma, float)`` will -fail on tracers. JAX-GalSim handles this internally, but if you subclass any -JAX-GalSim object, be aware that ``__init__`` may receive tracers: - -.. code-block:: python - - from jax_galsim.core.utils import has_tracers - - class MyProfile(jax_galsim.GSObject): - def __init__(self, sigma, gsparams=None): - if not has_tracers(sigma): - # Only validate with concrete values - if sigma <= 0: - raise ValueError("sigma must be positive") - ... +return ``False`` on tracers, and you cannot check correctness of values (e.g., +``if sigma > 0: ...```). JAX-GalSim handles this internally, but if you subclass any +JAX-GalSim object, be aware that ``__init__`` may receive tracers. Profile Restrictions -------------------- @@ -221,12 +235,9 @@ Some GalSim features are not yet implemented in JAX-GalSim: - **ChromaticObject**: All chromatic functionality (wavelength-dependent profiles) is not available. - **InterpolatedKImage**: Not implemented. -- **Airy, Kolmogorov, OpticalPSF, RealGalaxy**: See :doc:`api-coverage` for +- **Airy, Kolmogorov, OpticalPSF, RealGalaxy, etc.**: See :doc:`api-coverage` for the full list. -The project currently implements **22.5 %** of the GalSim public API, focused -on the most commonly used profiles and operations. - Numerical Precision ------------------- @@ -249,11 +260,11 @@ These differences are typically at the level of floating-point round-off should not affect scientific conclusions. ⚠️ Additional Sharp Bits --------------------------- +------------------------ In the :doc:`api/index` you will find **🔪 JAX-GalSim - The Sharp Bits 🔪** blocks highlighting additional important caveats for specific classes and or methods. These could include things like: -- Many classes do not perform some of Galsim's test for correctness during initialization (e.g., :meth:`~jax_galsim.GSObject.drawImage`). +- Some classes do not perform some of Galsim's test for correctness during initialization (e.g., :meth:`~jax_galsim.InterpolatedImage`). - Certain profiles might not be auto-differentiable with respect to some of their parameters (e.g., :class:`~jax_galsim.Spergel`, :class:`~jax_galsim.Moffat`) - Limitations regarding what types of inputes are handled (e.g., :meth:`~jax_galsim.Image.calculate_fft` does not accept complex dtypes.) diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index aec36d90..fd3af1c3 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -27,7 +27,6 @@ from jax_galsim.core.utils import ( cast_to_float, ensure_hashable, - has_tracers, implements, ) @@ -55,7 +54,7 @@ class AngleUnit(object): def __init__(self, value): if isinstance(value, AngleUnit): raise TypeError("Cannot construct AngleUnit from another AngleUnit") - self._value = cast_to_float(value) + self._value = cast_to_float(value, accept_strings=True) @property @implements(_galsim.AngleUnit.value) @@ -199,7 +198,7 @@ def __sub__(self, other): return _Angle(self._rad - other._rad) def __mul__(self, other): - if not (has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES)): + if isinstance(other, (Angle, AngleUnit)): raise TypeError( "Cannot multiply Angle by %s of type %s" % (other, type(other)) ) @@ -210,7 +209,7 @@ def __mul__(self, other): def __div__(self, other): if isinstance(other, AngleUnit): return self._rad / other.value - elif has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES): + elif not isinstance(other, Angle): return _Angle(self._rad / other) else: raise TypeError( diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 21b6ba8a..69ab8746 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,16 +1,14 @@ import galsim as _galsim import jax import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( - CONST_TYPES, cast_to_float, cast_to_int, - cast_to_python_float, check_is_int_then_cast, ensure_hashable, - has_tracers, implements, ) from jax_galsim.position import Position, PositionD, PositionI @@ -23,10 +21,7 @@ Further, the JAX implementation adds a new method, ``isStatic`` to the ``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance has been instantiated with static, known values, ``isStatic()`` will -return ``True``. You can indicate to JAX-GalSim that a ``BoundsI`` -instance should be static via initializing it with the ``static`` -keyword set to the ``True``. If the object detects that it is being -initialized with non-static data, an error will be raised. +return ``True``. ``BoundsI`` objects in JAX-Galsim support an additional initialization call ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. In this case, @@ -139,8 +134,8 @@ def _parse_args(self, *args, **kwargs): else: max_delta = 1 if ( - isinstance(self.deltax, CONST_TYPES) - and isinstance(self.deltay, CONST_TYPES) + isinstance(self.deltax, (int, float, np.integer, np.floating)) + and isinstance(self.deltay, (int, float, np.integer, np.floating)) and (self.deltax < max_delta or self.deltay < max_delta) ): self._isdefined = False @@ -364,10 +359,8 @@ def from_galsim(cls, galsim_bounds): """Create a jax_galsim `BoundsD/I` from a `galsim.BoundsD/I` object.""" if isinstance(galsim_bounds, _galsim.BoundsD): _cls = BoundsD - kwargs = {} elif isinstance(galsim_bounds, _galsim.BoundsI): _cls = BoundsI - kwargs = {"static": True} else: raise TypeError( "galsim_bounds must be either a %s or a %s" @@ -379,7 +372,6 @@ def from_galsim(cls, galsim_bounds): galsim_bounds.xmax, galsim_bounds.ymin, galsim_bounds.ymax, - **kwargs, ) else: return _cls() @@ -505,24 +497,25 @@ def __init__(self, *args, **kwargs): # initial setting to let stuff pass through freely self._isstatic = True - force_static = kwargs.pop("static", False) - self._parse_args(*args, **kwargs) - if has_tracers(self.deltax) or has_tracers(self.deltay): - raise RuntimeError( - "Jax-GalSim BoundsI instances must have a fixed width! " - f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." - ) - - self.deltax = cast_to_python_float(self.deltax) - self.deltay = cast_to_python_float(self.deltay) + self.deltax = cast_to_float(self.deltax) + self.deltay = cast_to_float(self.deltay) if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): raise TypeError("BoundsI must be initialized with integer values") - self.deltax = int(cast_to_int(self.deltax)) - self.deltay = int(cast_to_int(self.deltay)) + self.deltax = cast_to_int(self.deltax) + self.deltay = cast_to_int(self.deltay) - if has_tracers(self._xmin) or has_tracers(self._ymin): + if not ( + isinstance( + self._xmin, + (int, float, np.floating, np.integer), + ) + and isinstance( + self._ymin, + (int, float, np.floating, np.integer), + ) + ): self._isstatic = False # validate inputs are ints @@ -536,13 +529,6 @@ def __init__(self, *args, **kwargs): if self.deltax < 1 and self.deltay < 1: self._isdefined = False - if force_static and not self._isstatic: - raise RuntimeError( - "BoundsI initialized with non-static " - f"data (xmin,ymin = {self._xmin},{self._yminb}) " - "when static data was explicitly requested." - ) - def _check_scalar(self, x, name): try: if ( diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index f6785ceb..839d0658 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -7,41 +7,14 @@ import jax import jax.numpy as jnp import numpy as np -from jax.tree_util import tree_flatten - -CONST_TYPES = ( - float, - int, - np.ndarray, - np.int8, - np.int16, - np.int32, - np.int64, - np.float16, - np.float32, - np.float64, - np.complex64, - np.complex128, -) -CONST_TYPES_WITH_JAX = CONST_TYPES + ( - jax.Array, - jnp.ndarray, - jnp.int8, - jnp.int16, - jnp.int32, - jnp.int64, - jnp.float32, - jnp.float64, - jnp.complex64, - jnp.complex128, -) def check_is_int_then_cast(val, msg): """Check if `val` is an integer, raise if not, otherwise cast to int.""" - # for simple inputs, we can check direct in python - if isinstance(val, CONST_TYPES) and not has_tracers(val): - val = cast_to_python_float(val) + val = cast_to_float(val) + + if isinstance(val, (int, float, np.integer, np.floating)): + # for simple inputs, we can check direct in python if val != int(val): raise TypeError(msg) val = int(val) @@ -69,130 +42,39 @@ def cast_numpy_array_to_native_byte_order(arr): return arr.astype(arr.dtype.newbyteorder("=")) -def has_tracers(x): - """Return True if the input item is a JAX tracer or object, False otherwise.""" - for item in tree_flatten(x)[0]: - if isinstance(item, jax.core.Tracer) or type(item) is object: - return True - return False - - -@jax.jit -def compute_major_minor_from_jacobian(jac): - h1 = jnp.hypot(jac[0, 0] + jac[1, 1], jac[0, 1] - jac[1, 0]) - h2 = jnp.hypot(jac[0, 0] - jac[1, 1], jac[0, 1] + jac[1, 0]) - major = 0.5 * jnp.abs(h1 + h2) - minor = 0.5 * jnp.abs(h1 - h2) - return major, minor - - -def _cast_to_array_scalar(x, dtype=None): - """Cast the input to an array scalar. Works on python scalars, iterables and jax arrays. - For iterables it always takes the first element after a call to .ravel()""" - if dtype is None: - if hasattr(x, "dtype"): - dtype = x.dtype - else: - dtype = float - - if isinstance(x, jax.Array): - return jnp.atleast_1d(x).astype(dtype).ravel()[0] - elif hasattr(x, "astype"): - return x.astype(dtype).ravel()[0] +def _cast_to_type(x, typ, accept_strings=False): + if isinstance(x, (int, float, np.integer, np.floating)) or ( + accept_strings and isinstance(x, str) + ): + return typ(x) else: - return jnp.atleast_1d(jnp.array(x, dtype=dtype)).ravel()[0] + return jnp.astype(x, typ) -def cast_to_python_float(x): - """Cast the input to a python float. Works on python floats and jax arrays. - For jax arrays it always takes the first element after a call to .ravel()""" - if isinstance(x, jax.Array): - return _cast_to_array_scalar(x, dtype=float).item() - else: - try: - return float(x) - except TypeError: - # this will return the same value for anything float-like that - # cannot be cast to float - # however, it will raise an error if something is not float-like - return 1.0 * x - except ValueError as e: - # we let NaNs through - if " NaN " in str(e): - # this will return the same value for anything float-like that - # cannot be cast to float - # however, it will raise an error if something is not float-like - return 1.0 * x - else: - raise e +def cast_to_float(x, accept_strings=False): + """Cast the input to a float. Works on python floats/ints, numpy scalars, and jax/numpy arrays. + Parameters: + accept_strings: If True, allow string to ``float`` conversion. [default: False] -def cast_to_python_int(x): - """Cast the input to a python int. Works on python ints and jax arrays. - For jax arrays it always takes the first element after a call to .ravel()""" - if isinstance(x, jax.Array): - return _cast_to_array_scalar(x, dtype=int).item() - else: - try: - return int(x) - except TypeError: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - return 1 * x - except ValueError as e: - # we let NaNs through - if " NaN " in str(e): - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - return 1 * x - else: - raise e + Returns: + Input value ``x`` casted to a ``float``. + """ + # use the python `float` const/func here to promote to the highest + # precision available without emitting a warning in JAX + return _cast_to_type(x, float, accept_strings=accept_strings) -def cast_to_float(x): - """Cast the input to a float. Works on python floats and jax arrays.""" - try: - return float(x) - except Exception: - try: - return jnp.asarray(x, dtype=float) - except Exception: - # this will return the same value for anything float-like that - # cannot be cast to float - # however, it will raise an error if something is not float-like - # we exclude object types since they are used in JAX tracing - if type(x) is object: - return x - else: - return 1.0 * x +def cast_to_int(x, accept_strings=False): + """Cast the input to an int. Works on python floats/ints, numpy scalars, and jax/numpy arrays. + Parameters: + accept_strings: If True, allow string to ``int`` conversion. [default: False] -def cast_to_int(x): - """Cast the input to an int. Works on python floats/ints and jax arrays.""" - try: - return int(x) - except Exception: - try: - if not jnp.any(jnp.isnan(x)): - return jnp.asarray(x, dtype=int) - else: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - if type(x) is object: - return x - else: - return 1 * x - except Exception: - # this will return the same value for anything int-like that - # cannot be cast to int - # however, it will raise an error if something is not int-like - if type(x) is object: - return x - else: - return 1 * x + Returns: + Input value ``x`` casted to an ``int``. + """ + return _cast_to_type(x, int, accept_strings=accept_strings) def is_equal_with_arrays(x, y): diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index d64564ab..01654860 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -13,7 +13,6 @@ from jax_galsim.celestial import CelestialCoord from jax_galsim.core.utils import ( cast_to_float, - cast_to_python_float, ensure_hashable, implements, ) @@ -32,6 +31,7 @@ JacobianWCS, OffsetWCS, PixelScale, + _cast_to_python_float, ) ######################################################################################### @@ -754,16 +754,16 @@ def _writeHeader(self, header, bounds): header["GS_WCS"] = ("GSFitsWCS", "GalSim WCS name") header["CTYPE1"] = "RA---" + self.wcs_type header["CTYPE2"] = "DEC--" + self.wcs_type - header["CRPIX1"] = cast_to_python_float(self.crpix[0]) - header["CRPIX2"] = cast_to_python_float(self.crpix[1]) - header["CD1_1"] = cast_to_python_float(self.cd[0][0]) - header["CD1_2"] = cast_to_python_float(self.cd[0][1]) - header["CD2_1"] = cast_to_python_float(self.cd[1][0]) - header["CD2_2"] = cast_to_python_float(self.cd[1][1]) + header["CRPIX1"] = _cast_to_python_float(self.crpix[0]) + header["CRPIX2"] = _cast_to_python_float(self.crpix[1]) + header["CD1_1"] = _cast_to_python_float(self.cd[0][0]) + header["CD1_2"] = _cast_to_python_float(self.cd[0][1]) + header["CD2_1"] = _cast_to_python_float(self.cd[1][0]) + header["CD2_2"] = _cast_to_python_float(self.cd[1][1]) header["CUNIT1"] = "deg" header["CUNIT2"] = "deg" - header["CRVAL1"] = cast_to_python_float(self.center.ra / degrees) - header["CRVAL2"] = cast_to_python_float(self.center.dec / degrees) + header["CRVAL1"] = _cast_to_python_float(self.center.ra / degrees) + header["CRVAL2"] = _cast_to_python_float(self.center.dec / degrees) if self.pv is not None: order = len(self.pv[0]) - 1 k = 0 @@ -771,8 +771,8 @@ def _writeHeader(self, header, bounds): for n in range(order + 1): for j in range(n + 1): i = n - j - header["PV1_" + str(k)] = cast_to_python_float(self.pv[0, i, j]) - header["PV2_" + str(k)] = cast_to_python_float(self.pv[1, j, i]) + header["PV1_" + str(k)] = _cast_to_python_float(self.pv[0, i, j]) + header["PV2_" + str(k)] = _cast_to_python_float(self.pv[1, j, i]) k = k + 1 if k in odd_indices: k = k + 1 @@ -785,7 +785,9 @@ def _writeHeader(self, header, bounds): if i == 1 and j == 0: aij -= 1 # Turn back into standard form. if aij != 0.0: - header["A_" + str(i) + "_" + str(j)] = cast_to_python_float(aij) + header["A_" + str(i) + "_" + str(j)] = _cast_to_python_float( + aij + ) header["B_ORDER"] = order for i in range(order + 1): for j in range(order + 1): @@ -793,7 +795,9 @@ def _writeHeader(self, header, bounds): if i == 0 and j == 1: bij -= 1 if bij != 0.0: - header["B_" + str(i) + "_" + str(j)] = cast_to_python_float(bij) + header["B_" + str(i) + "_" + str(j)] = _cast_to_python_float( + bij + ) if self.abp is not None: order = len(self.abp[0]) - 1 header["AP_ORDER"] = order @@ -803,7 +807,7 @@ def _writeHeader(self, header, bounds): if i == 1 and j == 0: apij -= 1 if apij != 0.0: - header["AP_" + str(i) + "_" + str(j)] = cast_to_python_float( + header["AP_" + str(i) + "_" + str(j)] = _cast_to_python_float( apij ) header["BP_ORDER"] = order @@ -813,7 +817,7 @@ def _writeHeader(self, header, bounds): if i == 0 and j == 1: bpij -= 1 if bpij != 0.0: - header["BP_" + str(i) + "_" + str(j)] = cast_to_python_float( + header["BP_" + str(i) + "_" + str(j)] = _cast_to_python_float( bpij ) return header diff --git a/jax_galsim/image.py b/jax_galsim/image.py index d7eac192..63e4da7a 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1,3 +1,4 @@ +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -8,7 +9,6 @@ from jax_galsim.core.utils import ( cast_numpy_array_to_native_byte_order, ensure_hashable, - has_tracers, implements, ) from jax_galsim.errors import GalSimImmutableError @@ -103,12 +103,12 @@ def __init__(self, *args, **kwargs): else: if "array" in kwargs: array = kwargs.pop("array") - if has_tracers(array) or isinstance(array, jnp.ndarray): - pass - elif isinstance(array, np.ndarray): + if isinstance(array, np.ndarray): array = jnp.array(cast_numpy_array_to_native_byte_order(array)) + elif isinstance(array, jnp.ndarray): + pass else: - raise TypeError("Unable to parse %s as an array." % array) + raise TypeError(f"Unable to parse {array!r} as an array.") array, xmin, ymin = self._get_xmin_ymin( array, kwargs, check_bounds=_check_bounds @@ -198,12 +198,7 @@ def __init__(self, *args, **kwargs): ncol = int(ncol) nrow = int(nrow) self._array = self._make_empty(shape=(nrow, ncol), dtype=self._dtype) - if not has_tracers(xmin) and not has_tracers(ymin): - self._bounds = BoundsI( - xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True - ) - else: - self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) + self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) if init_value: self._array = self._array.at[...].add(init_value) elif bounds is not None: @@ -216,12 +211,7 @@ def __init__(self, *args, **kwargs): elif array is not None: self._array = array.view() nrow, ncol = array.shape - if not has_tracers(xmin) and not has_tracers(ymin): - self._bounds = BoundsI( - xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True - ) - else: - self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) + self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) if init_value is not None: raise _galsim.GalSimIncompatibleValuesError( "Cannot specify init_value with array", @@ -290,22 +280,14 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): b = kwargs.pop("bounds") if not isinstance(b, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if ( - check_bounds - and b.isDefined() - and not has_tracers(b.xmin) - and not has_tracers(b.ymin) - and not has_tracers(b.xmax) - and not has_tracers(b.ymax) - ): - # We need to disable this when jitting - if b.xmax - b.xmin + 1 != array.shape[1]: + if check_bounds and b.isDefined(): + if b.deltax != array.shape[1]: raise _galsim.GalSimIncompatibleValuesError( "Shape of array is inconsistent with provided bounds", array=array, bounds=b, ) - if b.ymax - b.ymin + 1 != array.shape[0]: + if b.deltay != array.shape[0]: raise _galsim.GalSimIncompatibleValuesError( "Shape of array is inconsistent with provided bounds", array=array, @@ -333,8 +315,15 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): def __repr__(self): s = "galsim.Image(bounds=%r" % self.bounds - if self.bounds.isDefined() and not has_tracers(self.array): - s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),) + if self.bounds.isDefined() and isinstance( + self.array, (np.ndarray, jnp.ndarray, jax.Array) + ): + try: + np.array(self.array) + except Exception: + pass + else: + s += ", array=\n%r" % (np.array(self.array),) s += ", wcs=%r" % self.wcs if self.isconst: s += ", make_const=True" @@ -596,14 +585,8 @@ def subImage(self, bounds): "Attempt to access subImage of undefined image" ) if ( - not has_tracers(self.bounds.xmin) - and not has_tracers(self.bounds.xmax) - and not has_tracers(self.bounds.ymin) - and not has_tracers(self.bounds.ymax) - and not has_tracers(bounds.xmin) - and not has_tracers(bounds.xmax) - and not has_tracers(bounds.ymin) - and not has_tracers(bounds.ymax) + self.bounds.isStatic() + and bounds.isStatic() and not self.bounds.includes(bounds) ): raise _galsim.GalSimBoundsError( @@ -641,14 +624,8 @@ def setSubImage(self, bounds, rhs): "Attempt to access values of an undefined image" ) if ( - not has_tracers(self.bounds.xmin) - and not has_tracers(self.bounds.xmax) - and not has_tracers(self.bounds.ymin) - and not has_tracers(self.bounds.ymax) - and not has_tracers(bounds.xmin) - and not has_tracers(bounds.xmax) - and not has_tracers(bounds.ymin) - and not has_tracers(bounds.ymax) + self.bounds.isStatic() + and bounds.isStatic() and not self.bounds.includes(bounds) ): raise _galsim.GalSimBoundsError( @@ -744,37 +721,60 @@ def __setitem__(self, *args): def wrap(self, bounds, hermitian=False): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + def _raise_if_nonzero(bnds, x_or_y, msg): + if x_or_y == "x": + if bnds.isStatic(): + if bnds.xmin != 0: + raise _galsim.GalSimIncompatibleValuesError( + msg, + hermitian=hermitian, + bounds=bnds, + ) + else: + bnds.xmin = equinox.error_if( + bnds.xmin, + jnp.any(bnds.xmin != 0), + msg, + ) + else: + if bnds.isStatic(): + if bnds.ymin != 0: + raise _galsim.GalSimIncompatibleValuesError( + msg, + hermitian=hermitian, + bounds=bnds, + ) + else: + bnds.ymin = equinox.error_if( + bnds.ymin, + jnp.any(bnds.ymin != 0), + msg, + ) + + return bnds + # Get this at the start to check for invalid bounds and raise the exception before # possibly writing data past the edge of the image. if not hermitian: return self._wrap(bounds, False, False, None) elif hermitian == "x": - if not has_tracers(self.bounds.xmin) and self.bounds.xmin != 0: - raise _galsim.GalSimIncompatibleValuesError( - "hermitian == 'x' requires self.bounds.xmin == 0", - hermitian=hermitian, - bounds=self.bounds, - ) - if not has_tracers(bounds.xmin) and bounds.xmin != 0: - raise _galsim.GalSimIncompatibleValuesError( - "hermitian == 'x' requires bounds.xmin == 0", - hermitian=hermitian, - bounds=bounds, - ) + self._bounds = _raise_if_nonzero( + self.bounds, "x", "hermitian == 'x' requires self.bounds.xmin == 0" + ) + bounds = _raise_if_nonzero( + bounds, "x", "hermitian == 'x' requires bounds.xmin == 0" + ) + return self._wrap(bounds, True, False, 2 * bounds.xmax) elif hermitian == "y": - if not has_tracers(self.bounds.ymin) and self.bounds.ymin != 0: - raise _galsim.GalSimIncompatibleValuesError( - "hermitian == 'y' requires self.bounds.ymin == 0", - hermitian=hermitian, - bounds=self.bounds, - ) - if not has_tracers(bounds.ymin) and bounds.ymin != 0: - raise _galsim.GalSimIncompatibleValuesError( - "hermitian == 'y' requires bounds.ymin == 0", - hermitian=hermitian, - bounds=bounds, - ) + self._bounds = _raise_if_nonzero( + self.bounds, "y", "hermitian == 'y' requires self.bounds.ymin == 0" + ) + bounds = _raise_if_nonzero( + bounds, "y", "hermitian == 'y' requires bounds.ymin == 0" + ) + return self._wrap(bounds, False, True, 2 * bounds.ymax) else: raise _galsim.GalSimValueError( diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 932c48b8..13ea2a37 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -13,7 +13,6 @@ from jax_galsim.core.math import safe_sqrt from jax_galsim.core.utils import ( ensure_hashable, - has_tracers, implements, ) from jax_galsim.gsobject import GSObject @@ -59,13 +58,10 @@ def __init__( # let define beta_thr a threshold to trigger the truncature self._beta_thr = 1.1 - if has_tracers(trunc) or ( - isinstance(trunc, (np.ndarray, float, jnp.ndarray, int)) - and np.any(trunc != 0) - ): + if (not isinstance(trunc, (float, int))) or trunc != 0: raise ValueError( "JAX-GalSim does not support truncated Moffat profiles " - f"(got trunc={repr(trunc)}, always pass the constant 0.0)!" + f"(got trunc={trunc!r}, always pass the constant 0.0)!" ) if isinstance(beta, (float, int)): diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 967f55c3..e97ee2b7 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -1,5 +1,6 @@ from contextlib import contextmanager +import equinox import galsim as _galsim import jax import jax.numpy as jnp @@ -8,7 +9,7 @@ from jax_galsim.core.utils import ( cast_numpy_array_to_native_byte_order, - cast_to_python_int, + cast_to_int, implements, ) from jax_galsim.errors import ( @@ -91,19 +92,16 @@ def __init__( _nokeep=None, ): self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N + self._Ntot = cast_to_int(self._Ntot) + if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None: - try: - # this will raise a boolean conversion error in JAX - # which we swallow - err_cond = (N > _JAX_GALSIM_PHOTON_ARRAY_SIZE) or False - except Exception: - err_cond = False - - if err_cond: - raise GalSimValueError( - f"The given photon array size {N} is larger than " - f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." - ) + equinox.error_if( + jnp.array(N), + jnp.array(N > _JAX_GALSIM_PHOTON_ARRAY_SIZE), + f"The given photon array size {N} is larger than " + f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}.", + ) + if _nokeep is not None: self._nokeep = _nokeep else: @@ -820,7 +818,7 @@ def __repr__(self): import numpy as np s = "galsim.PhotonArray(%r, x=array(%r), y=array(%r), flux=array(%r)" % ( - cast_to_python_int(self.size()), + self.size(), np.array(self.x).tolist(), np.array(self.y).tolist(), np.array(self.flux).tolist(), @@ -844,7 +842,7 @@ def __repr__(self): return s def __str__(self): - return "galsim.PhotonArray(%r)" % cast_to_python_int(self.size()) + return "galsim.PhotonArray(%r)" % self.size() __hash__ = None diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index bf9f4f6d..e78a67e3 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -1,9 +1,9 @@ import galsim as _galsim +import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( - compute_major_minor_from_jacobian, ensure_hashable, implements, ) @@ -12,6 +12,15 @@ from jax_galsim.position import PositionD +@jax.jit +def _compute_major_minor_from_jacobian(jac): + h1 = jnp.hypot(jac[0, 0] + jac[1, 1], jac[0, 1] - jac[1, 0]) + h2 = jnp.hypot(jac[0, 0] - jac[1, 1], jac[0, 1] + jac[1, 0]) + major = 0.5 * jnp.abs(h1 + h2) + minor = 0.5 * jnp.abs(h1 - h2) + return major, minor + + @implements( _galsim.Transform, lax_description="Does not support Chromatic Objects or Convolutions.", @@ -277,12 +286,12 @@ def _kfactor(self, kx, ky): @property def _maxk(self): - _, minor = compute_major_minor_from_jacobian(self._jac) + _, minor = _compute_major_minor_from_jacobian(self._jac) return self._original.maxk / minor @property def _stepk(self): - major, _ = compute_major_minor_from_jacobian(self._jac) + major, _ = _compute_major_minor_from_jacobian(self._jac) stepk = self._original.stepk / major # If we have a shift, we need to further modify stepk # stepk = Pi/R diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 20e31eb3..ea3a9dbc 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -1,11 +1,16 @@ import galsim as _galsim +import jax import jax.numpy as jnp import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.angle import AngleUnit, arcsec, radians from jax_galsim.celestial import CelestialCoord -from jax_galsim.core.utils import cast_to_python_float, ensure_hashable, implements +from jax_galsim.core.utils import ( + cast_to_float, + ensure_hashable, + implements, +) from jax_galsim.errors import GalSimValueError from jax_galsim.gsobject import GSObject from jax_galsim.position import Position, PositionD, PositionI @@ -13,6 +18,28 @@ from jax_galsim.transform import _Transform +# this function casts input values to python numeric values +# this kind of casting is only done for writing FITS headers +# and should never be done anywhere else in the code base +def _cast_to_static_numeric_scalar(x, msg=None): + if isinstance(x, (int, float, np.integer, np.floating)): + return x + + if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)): + if x.ndim == 0: + return x.item() + + if all(sv == 1 for sv in x.shape): + return x.ravel()[0].item() + + msg = msg or f"Cannot convert input {x!r} to a static, numeric scalar." + raise RuntimeError(msg) + + +def _cast_to_python_float(x): + return cast_to_float(_cast_to_static_numeric_scalar(x)) + + # We inherit from the reference BaseWCS and only redefine the methods that # make references to jax_galsim objects. @implements(_galsim.BaseWCS) @@ -910,7 +937,7 @@ def _toJacobian(self): def _writeHeader(self, header, bounds): header["GS_WCS"] = ("PixelScale", "GalSim WCS name") - header["GS_SCALE"] = (cast_to_python_float(self.scale), "GalSim image scale") + header["GS_SCALE"] = (_cast_to_python_float(self.scale), "GalSim image scale") return self.affine()._writeLinearWCS(header, bounds) @staticmethod @@ -1032,9 +1059,15 @@ def _newOrigin(self, origin, world_origin): def _writeHeader(self, header, bounds): header["GS_WCS"] = ("ShearWCS", "GalSim WCS name") - header["GS_SCALE"] = (cast_to_python_float(self.scale), "GalSim image scale") - header["GS_G1"] = (cast_to_python_float(self.shear.g1), "GalSim image shear g1") - header["GS_G2"] = (cast_to_python_float(self.shear.g2), "GalSim image shear g2") + header["GS_SCALE"] = (_cast_to_python_float(self.scale), "GalSim image scale") + header["GS_G1"] = ( + _cast_to_python_float(self.shear.g1), + "GalSim image shear g1", + ) + header["GS_G2"] = ( + _cast_to_python_float(self.shear.g2), + "GalSim image shear g2", + ) return self.affine()._writeLinearWCS(header, bounds) @implements(_galsim.wcs.ShearWCS.copy) @@ -1326,15 +1359,21 @@ def world_origin(self): def _writeHeader(self, header, bounds): header["GS_WCS"] = ("OffsetWCS", "GalSim WCS name") - header["GS_SCALE"] = (cast_to_python_float(self.scale), "GalSim image scale") - header["GS_X0"] = (cast_to_python_float(self.origin.x), "GalSim image origin x") - header["GS_Y0"] = (cast_to_python_float(self.origin.y), "GalSim image origin y") + header["GS_SCALE"] = (_cast_to_python_float(self.scale), "GalSim image scale") + header["GS_X0"] = ( + _cast_to_python_float(self.origin.x), + "GalSim image origin x", + ) + header["GS_Y0"] = ( + _cast_to_python_float(self.origin.y), + "GalSim image origin y", + ) header["GS_U0"] = ( - cast_to_python_float(self.world_origin.x), + _cast_to_python_float(self.world_origin.x), "GalSim world origin u", ) header["GS_V0"] = ( - cast_to_python_float(self.world_origin.y), + _cast_to_python_float(self.world_origin.y), "GalSim world origin v", ) return self.affine()._writeLinearWCS(header, bounds) @@ -1404,23 +1443,29 @@ def _newOrigin(self, origin, world_origin): def _writeHeader(self, header, bounds): header["GS_WCS"] = ("OffsetShearWCS", "GalSim WCS name") - header["GS_SCALE"] = (cast_to_python_float(self.scale), "GalSim image scale") - header["GS_G1"] = (cast_to_python_float(self.shear.g1), "GalSim image shear g1") - header["GS_G2"] = (cast_to_python_float(self.shear.g2), "GalSim image shear g2") + header["GS_SCALE"] = (_cast_to_python_float(self.scale), "GalSim image scale") + header["GS_G1"] = ( + _cast_to_python_float(self.shear.g1), + "GalSim image shear g1", + ) + header["GS_G2"] = ( + _cast_to_python_float(self.shear.g2), + "GalSim image shear g2", + ) header["GS_X0"] = ( - cast_to_python_float(self.origin.x), + _cast_to_python_float(self.origin.x), "GalSim image origin x coordinate", ) header["GS_Y0"] = ( - cast_to_python_float(self.origin.y), + _cast_to_python_float(self.origin.y), "GalSim image origin y coordinate", ) header["GS_U0"] = ( - cast_to_python_float(self.world_origin.x), + _cast_to_python_float(self.world_origin.x), "GalSim world origin u coordinate", ) header["GS_V0"] = ( - cast_to_python_float(self.world_origin.y), + _cast_to_python_float(self.world_origin.y), "GalSim world origin v coordinate", ) return self.affine()._writeLinearWCS(header, bounds) @@ -1504,25 +1549,25 @@ def _writeLinearWCS(self, header, bounds): header["CTYPE1"] = ("LINEAR", "name of the world coordinate axis") header["CTYPE2"] = ("LINEAR", "name of the world coordinate axis") header["CRVAL1"] = ( - cast_to_python_float(self.u0), + _cast_to_python_float(self.u0), "world coordinate at reference pixel = u0", ) header["CRVAL2"] = ( - cast_to_python_float(self.v0), + _cast_to_python_float(self.v0), "world coordinate at reference pixel = v0", ) header["CRPIX1"] = ( - cast_to_python_float(self.x0), + _cast_to_python_float(self.x0), "image coordinate of reference pixel = x0", ) header["CRPIX2"] = ( - cast_to_python_float(self.y0), + _cast_to_python_float(self.y0), "image coordinate of reference pixel = y0", ) - header["CD1_1"] = (cast_to_python_float(self.dudx), "CD1_1 = dudx") - header["CD1_2"] = (cast_to_python_float(self.dudy), "CD1_2 = dudy") - header["CD2_1"] = (cast_to_python_float(self.dvdx), "CD2_1 = dvdx") - header["CD2_2"] = (cast_to_python_float(self.dvdy), "CD2_2 = dvdy") + header["CD1_1"] = (_cast_to_python_float(self.dudx), "CD1_1 = dudx") + header["CD1_2"] = (_cast_to_python_float(self.dudy), "CD1_2 = dudy") + header["CD2_1"] = (_cast_to_python_float(self.dvdx), "CD2_1 = dvdx") + header["CD2_2"] = (_cast_to_python_float(self.dvdy), "CD2_2 = dvdy") return header @staticmethod