Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions examples/demo1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2012-2026 by the GalSim developers team on GitHub
# https://github.com/GalSim-developers
#
# This file is part of GalSim: The modular galaxy image simulation toolkit.
# https://github.com/GalSim-developers/GalSim
#
# GalSim is free software: redistribution and use in source and binary forms,
# with or without modification, are permitted provided that the following
# conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions, and the disclaimer given in the accompanying LICENSE
# file.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the disclaimer given in the documentation
# and/or other materials provided with the distribution.
#
"""
Demo #1

This is the first script in our tutorial about using JAX-GalSim in python scripts: examples/demo*.py.
(This file is designed to be viewed in a window 100 characters wide.)

Each of these demo*.py files are designed to be equivalent to the corresponding demo*.yaml file
(or demo*.json -- found in the json directory). If you are new to python, you should probably
look at those files first as they will probably have a quicker learning curve for you. Then you
can look through these python scripts, which show how to do the same thing. Of course, experienced
pythonistas may prefer to start with these scripts and then look at the corresponding YAML files.

To run this script, simply write:

python demo1.py


This first script is about as simple as it gets. We draw an image of a single galaxy convolved
with a PSF and write it to disk. We use a circular Gaussian profile for both the PSF and the
galaxy, and add a constant level of Gaussian noise to the image.

In each demo, we list the new features introduced in that demo file. These will differ somewhat
between the .py and .yaml (or .json) versions, since the two methods implement things in different
ways. (demo*.py are python scripts, while demo*.yaml and demo*.json are configuration files.)

New features introduced in this demo:

- obj = jax_galsim.Gaussian(flux, sigma)
- obj = jax_galsim.Convolve([list of objects])
- image = obj.drawImage(scale)
- image.added_flux (Only present after a drawImage command.)
- noise = jax_galsim.GaussianNoise(sigma)
- image.addNoise(noise)
- image.write(file_name)
- image.FindAdaptiveMom()
"""

import logging
import math
import os
import sys

import jax_galsim


def main(argv):
"""
About as simple as it gets:
- Use a circular Gaussian profile for the galaxy.
- Convolve it by a circular Gaussian PSF.
- Add Gaussian noise to the image.
"""
# In non-script code, use getLogger(__name__) at module scope instead.
logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger("demo1")

gal_flux = 1.0e5 # total counts on the image
gal_sigma = 2.0 # arcsec
psf_sigma = 1.0 # arcsec
pixel_scale = 0.2 # arcsec / pixel
noise = 30.0 # standard deviation of the counts in each pixel

logger.info("Starting demo script 1 using:")
logger.info(
" - circular Gaussian galaxy (flux = %.1e, sigma = %.1f),",
gal_flux,
gal_sigma,
)
logger.info(" - circular Gaussian PSF (sigma = %.1f),", psf_sigma)
logger.info(" - pixel scale = %.2f,", pixel_scale)
logger.info(" - Gaussian noise (sigma = %.2f).", noise)

# Define the galaxy profile
gal = jax_galsim.Gaussian(flux=gal_flux, sigma=gal_sigma)
logger.debug("Made galaxy profile")

# Define the PSF profile
psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma) # PSF flux should always = 1
logger.debug("Made PSF profile")

# Final profile is the convolution of these
# Can include any number of things in the list, all of which are convolved
# together to make the final flux profile.
final = jax_galsim.Convolve([gal, psf])
logger.debug("Convolved components into final profile")

# Draw the image with a particular pixel scale, given in arcsec/pixel.
# The returned image has a member, added_flux, which is gives the total flux actually added to
# the image. One could use this value to check if the image is large enough for some desired
# accuracy level. Here, we just ignore it.
image = final.drawImage(scale=pixel_scale)
logger.debug(
"Made image of the profile: flux = %f, added_flux = %f",
gal_flux,
image.added_flux,
)

# Add Gaussian noise to the image with specified sigma
image.addNoise(jax_galsim.GaussianNoise(sigma=noise))
logger.debug("Added Gaussian noise")

# Write the image to a file
if not os.path.isdir("output"):
os.mkdir("output")
file_name = os.path.join("output", "demo1.fits")
# Note: if the file already exists, this will overwrite it.
image.write(file_name)
logger.info(
"Wrote image to %r" % file_name
) # using %r adds quotes around filename for us

results = image.FindAdaptiveMom()

logger.info("HSM reports that the image has observed shape and size:")
logger.info(
" e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)",
results.observed_shape.e1,
results.observed_shape.e2,
results.moments_sigma,
)
logger.info(
"Expected values in the limit that pixel response and noise are negligible:"
)
logger.info(
" e1 = %.3f, e2 = %.3f, sigma = %.3f",
0.0,
0.0,
math.sqrt(gal_sigma**2 + psf_sigma**2) / pixel_scale,
)


if __name__ == "__main__":
main(sys.argv)
175 changes: 175 additions & 0 deletions examples/demo2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright (c) 2012-2026 by the GalSim developers team on GitHub
# https://github.com/GalSim-developers
#
# This file is part of GalSim: The modular galaxy image simulation toolkit.
# https://github.com/GalSim-developers/GalSim
#
# GalSim is free software: redistribution and use in source and binary forms,
# with or without modification, are permitted provided that the following
# conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions, and the disclaimer given in the accompanying LICENSE
# file.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the disclaimer given in the documentation
# and/or other materials provided with the distribution.
#
"""
Demo #2

The second script in our tutorial about using JAX-GalSim in python scripts: examples/demo*.py.
(This file is designed to be viewed in a window 100 characters wide.)

This script is a bit more sophisticated, but still pretty basic. We're still only making
a single image, but now the galaxy has an exponential radial profile and is sheared.
The PSF is a circular Moffat profile. The noise is drawn from a Poisson distribution
using the flux from both the object and a background sky level to determine the
variance in each pixel.

New features introduced in this demo:

- obj = jax_galsim.Exponential(flux, scale_radius)
- obj = jax_galsim.Moffat(beta, flux, half_light_radius)
- obj = obj.shear(g1, g2) -- with explanation of other ways to specify shear
- rng = jax_galsim.BaseDeviate(seed)
- noise = jax_galsim.PoissonNoise(rng, sky_level)
- galsim.hsm.EstimateShear(image, image_epsf)
"""

import logging
import os
import sys

import galsim

import jax_galsim


def main(argv):
"""
A little bit more sophisticated, but still pretty basic:
- Use a sheared, exponential profile for the galaxy.
- Convolve it by a circular Moffat PSF.
- Add Poisson noise to the image.
"""
# In non-script code, use getLogger(__name__) at module scope instead.
logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger("demo2")

gal_flux = 1.0e5 # counts
gal_r0 = 2.7 # arcsec
g1 = 0.1 #
g2 = 0.2 #
psf_beta = 5 #
psf_re = 1.0 # arcsec
pixel_scale = 0.2 # arcsec / pixel
sky_level = 2.5e3 # counts / arcsec^2

# This time use a particular seed, so the image is deterministic.
# This is the same seed that is used in demo2.yaml, which means the images
# produced by the two methods will be precisely identical.
random_seed = 1534225

# The first thing the config layer does with the random seed is to scramble
# it a bit. Specifically, it makes a random number generator (BaseDeviate)
# using that seed and asks for a raw value. This becomes the seed that
# actually gets used.
# The reason for this extra step is that eventually (cf. demo4) the config
# layer will want to increment these seed values when building multiple
# objects or images. If the user is likewise incrementing seed values for
# multiple runs of a given config file, these can interfere leading to
# surprising (and typically bad) results.
random_seed = jax_galsim.BaseDeviate(random_seed).raw()

logger.info("Starting demo script 2 using:")
logger.info(
" - sheared (%.2f,%.2f) exponential galaxy (flux = %.1e, scale radius = %.2f),",
g1,
g2,
gal_flux,
gal_r0,
)
logger.info(" - circular Moffat PSF (beta = %.1f, re = %.2f),", psf_beta, psf_re)
logger.info(" - pixel scale = %.2f,", pixel_scale)
logger.info(" - Poisson noise (sky level = %.1e).", sky_level)

# Initialize the (pseudo-)random number generator that we will be using below.
# For a technical reason that will be explained later (demo9.py), we add 1 to the
# given random seed here.
rng = jax_galsim.BaseDeviate(random_seed + 1)

# Define the galaxy profile.
gal = jax_galsim.Exponential(flux=gal_flux, scale_radius=gal_r0)

# Shear the galaxy by some value.
# There are quite a few ways you can use to specify a shape.
# q, beta Axis ratio and position angle: q = b/a, 0 < q < 1
# e, beta Ellipticity and position angle: |e| = (1-q^2)/(1+q^2)
# g, beta ("Reduced") Shear and position angle: |g| = (1-q)/(1+q)
# eta, beta Conformal shear and position angle: eta = ln(1/q)
# e1,e2 Ellipticity components: e1 = e cos(2 beta), e2 = e sin(2 beta)
# g1,g2 ("Reduced") shear components: g1 = g cos(2 beta), g2 = g sin(2 beta)
# eta1,eta2 Conformal shear components: eta1 = eta cos(2 beta), eta2 = eta sin(2 beta)
gal = gal.shear(g1=g1, g2=g2)
logger.debug("Made galaxy profile")

# Define the PSF profile.
psf = jax_galsim.Moffat(beta=psf_beta, flux=1.0, half_light_radius=psf_re)
logger.debug("Made PSF profile")

# Final profile is the convolution of these.
final = jax_galsim.Convolve([gal, psf])
logger.debug("Convolved components into final profile")

# Draw the image with a particular pixel scale.
image = final.drawImage(scale=pixel_scale)
# The "effective PSF" is the PSF as drawn on an image, which includes the convolution
# by the pixel response. We label it epsf here.
image_epsf = psf.drawImage(scale=pixel_scale)
logger.debug("Made image of the profile")

# To get Poisson noise on the image, we will use a class called PoissonNoise.
# However, we want the noise to correspond to what you would get with a significant
# flux from tke sky. This is done by telling PoissonNoise to add noise from a
# sky level in addition to the counts currently in the image.
#
# One wrinkle here is that the PoissonNoise class needs the sky level in each pixel,
# while we have a sky_level in counts per arcsec^2. So we need to convert:
sky_level_pixel = sky_level * pixel_scale**2
noise = jax_galsim.PoissonNoise(rng, sky_level=sky_level_pixel)
image.addNoise(noise)
logger.debug("Added Poisson noise")

# Write the image to a file.
if not os.path.isdir("output"):
os.mkdir("output")
file_name = os.path.join("output", "demo2.fits")
file_name_epsf = os.path.join("output", "demo2_epsf.fits")
image.write(file_name)
image_epsf.write(file_name_epsf)
logger.info("Wrote image to %r", file_name)
logger.info("Wrote effective PSF image to %r", file_name_epsf)

results = galsim.hsm.EstimateShear(image.to_galsim(), image_epsf.to_galsim())

logger.info("HSM reports that the image has observed shape and size:")
logger.info(
" e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)",
results.observed_shape.e1,
results.observed_shape.e2,
results.moments_sigma,
)
logger.info(
"When carrying out Regaussianization PSF correction, HSM reports distortions"
)
logger.info(" e1, e2 = %.3f, %.3f", results.corrected_e1, results.corrected_e2)
logger.info(
"Expected values in the limit that noise and non-Gaussianity are negligible:"
)
exp_shear = galsim.Shear(g1=g1, g2=g2)
logger.info(" g1, g2 = %.3f, %.3f", exp_shear.e1, exp_shear.e2)


if __name__ == "__main__":
main(sys.argv)
32 changes: 31 additions & 1 deletion jax_galsim/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,15 +1094,45 @@ def tree_unflatten(cls, aux_data, children):
@classmethod
def from_galsim(cls, galsim_image):
"""Create a `Image` from a `galsim.Image` instance."""
wcs = (
BaseWCS.from_galsim(galsim_image.wcs)
if galsim_image.wcs is not None
else None
)
im = cls(
array=galsim_image.array,
wcs=BaseWCS.from_galsim(galsim_image.wcs),
wcs=wcs,
bounds=Bounds.from_galsim(galsim_image.bounds),
)
if hasattr(galsim_image, "header"):
im.header = galsim_image.header
return im

def to_galsim(self):
"""Create a galsim `Image` from a `jax_galsim.Image` object."""
wcs = self.wcs.to_galsim() if self.wcs is not None else None
return _galsim.Image(
np.asarray(self.array), bounds=self.bounds.to_galsim(), wcs=wcs
)

@implements(
_galsim.Image.FindAdaptiveMom,
lax_description=(
"This method converts the current `jax_galsim.Image` to a native "
"`galsim.Image` and delegates the computation to "
"`galsim.hsm.FindAdaptiveMom`. The returned object is GalSim's "
"`ShapeData`."
),
)
def FindAdaptiveMom(self, *args, **kwargs):
Comment thread
beckermr marked this conversation as resolved.
args_ = [arg.to_galsim() if hasattr(arg, "to_galsim") else arg for arg in args]
kwargs_ = {
key: val.to_galsim() if hasattr(val, "to_galsim") else val
for key, val in kwargs.items()
}
gs_image = self.to_galsim()
return gs_image.FindAdaptiveMom(*args_, **kwargs_)


@implements(
_galsim._Image,
Expand Down
2 changes: 1 addition & 1 deletion tests/GalSim
Submodule GalSim updated 1 files
+15 −15 tests/test_wcs.py
1 change: 1 addition & 0 deletions tests/SBProfile_comparison_images
1 change: 1 addition & 0 deletions tests/fits_file
1 change: 0 additions & 1 deletion tests/galsim_tests_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ allowed_failures:
- "'Image' object has no attribute 'bin'"
- "module 'jax_galsim' has no attribute 'InterpolatedKImage'"
- "module 'jax_galsim' has no attribute 'CorrelatedNoise'"
- "'Image' object has no attribute 'FindAdaptiveMom'"
- "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes
- "ValueError not raised by from_xyz"
- "ValueError not raised by greatCirclePoint"
Expand Down
1 change: 1 addition & 0 deletions tests/jax/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ def _reg_sfun(g1):
def test_api_image(obj):
_run_object_checks(obj, obj.__class__, "docs-methods")
_run_object_checks(obj, obj.__class__, "pickle-eval-repr-img")
_run_object_checks(obj, obj.__class__, "to-from-galsim")

# JAX tracing should be an identity
assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj
Expand Down