Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions jax_galsim/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,23 @@ def _max_sb(self):
return self.flux / (self.width * self.height)

def _xValue(self, pos):
return self._xValue_array(pos.x, pos.y)

def _xValue_array(self, x, y):
norm = self.flux / (self.width * self.height)
return jnp.where(
2.0 * jnp.abs(pos.x) < self.width,
jnp.where(2.0 * jnp.abs(pos.y) < self.height, norm, 0.0),
2.0 * jnp.abs(x) < self.width,
jnp.where(2.0 * jnp.abs(y) < self.height, norm, 0.0),
0.0,
)

def _kValue(self, kpos):
return self._kValue_array(kpos.x, kpos.y)

def _kValue_array(self, kx, ky):
_wo2pi = self.width / (2.0 * jnp.pi)
_ho2pi = self.height / (2.0 * jnp.pi)
return self.flux * jnp.sinc(kpos.x * _wo2pi) * jnp.sinc(kpos.y * _ho2pi)
return self.flux * jnp.sinc(kx * _wo2pi) * jnp.sinc(ky * _ho2pi)

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
_jac = jnp.eye(2) if jac is None else jac
Expand Down
31 changes: 25 additions & 6 deletions jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,13 @@ def _xValue(self, pos):
raise NotImplementedError("Real-space convolutions are not implemented")

def _kValue(self, kpos):
kv_list = [
obj.kValue(kpos) for obj in self.obj_list
] # In GalSim one uses obj.kValue
return jnp.prod(jnp.array(kv_list))
return self._kValue_array(kpos.x, kpos.y)

def _kValue_array(self, kx, ky):
result = self.obj_list[0]._kValue_array(kx, ky)
for obj in self.obj_list[1:]:
result = result * obj._kValue_array(kx, ky)
return result

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
raise NotImplementedError("Real-space convolutions are not implemented")
Expand All @@ -313,10 +316,23 @@ def _shoot(self, photons, rng):
photons.convolve(p1, rng)

def _drawKImage(self, image, jac=None):
from jax_galsim import Image

image = self.obj_list[0]._drawKImage(image, jac)
if len(self.obj_list) > 1:
for obj in self.obj_list[1:]:
image *= obj._drawKImage(image, jac)
# Use a fresh blank image with matching metadata to sever
# the false AD dependency on the galaxy-filled array.
# The PSF's _drawKImage only uses image metadata (bounds,
# scale, wcs), never the array data.
blank = Image(bounds=image.bounds, dtype=image.dtype, scale=image.scale)
obj_kimage = obj._drawKImage(blank, jac)
image = Image(
array=image.array * obj_kimage.array,
bounds=image.bounds,
wcs=image.wcs,
_check_bounds=False,
)
return image

def tree_flatten(self):
Expand Down Expand Up @@ -474,10 +490,13 @@ def _max_sb(self):
return -self.orig_obj.max_sb / self.orig_obj.flux**2

def _kValue(self, pos):
return self._kValue_array(pos.x, pos.y)

def _kValue_array(self, kx, ky):
# Really, for very low original kvalues, this gets very high, which can be unstable
# in the presence of noise. So if the original value is less than min_acc_kvalue,
# we instead just return 1/min_acc_kvalue rather than the real inverse.
kval = self.orig_obj._kValue(pos)
kval = self.orig_obj._kValue_array(kx, ky)
return jnp.where(
jnp.abs(kval) < self._min_acc_kvalue,
self._inv_min_acc_kvalue,
Expand Down
29 changes: 10 additions & 19 deletions jax_galsim/core/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def draw_by_xValue(
):
"""Utility function to draw a real-space GSObject into an Image."""
# putting the import here to avoid circular imports
from jax_galsim import Image, PositionD
from jax_galsim import Image

# Applies flux scaling to compensate for pixel scale
# See SBProfile.draw()
Expand All @@ -29,9 +29,7 @@ def draw_by_xValue(
flux_scaling *= jnp.exp(logdet)

# Draw the object
im = jax.vmap(lambda *args: gsobject._xValue(PositionD(*args)))(
coords[..., 0], coords[..., 1]
)
im = gsobject._xValue_array(coords[..., 0], coords[..., 1])

# Apply the flux scaling
im = (im * flux_scaling).astype(image.dtype)
Expand All @@ -42,42 +40,35 @@ def draw_by_xValue(

def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)):
# putting the import here to avoid circular imports
from jax_galsim import Image, PositionD
from jax_galsim import Image

# Create an array of coordinates
coords = jnp.stack(image.get_pixel_centers(), axis=-1)
coords = coords * image.scale # Scale by the image pixel scale
coords = jnp.dot(coords, jacobian)

# Draw the object
im = jax.vmap(lambda *args: gsobject._kValue(PositionD(*args)))(
coords[..., 0], coords[..., 1]
)
im = (im).astype(image.dtype)
im = gsobject._kValue_array(coords[..., 0], coords[..., 1])
im = im.astype(image.dtype)

# Return an image
return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)


def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)):
# putting the import here to avoid circular imports
from jax_galsim import Image, PositionD
from jax_galsim import Image

# Create an array of coordinates
kcoords = jnp.stack(image.get_pixel_centers(), axis=-1)
kcoords = kcoords * image.scale # Scale by the image pixel scale
kcoords = jnp.dot(kcoords, jacobian)
cenx, ceny = offset.x, offset.y

# flux Exp(-i (kx cx + kxy cx + kyx cy + ky cy ) )
# NB: seems that tere is no jax.lax.polar equivalent to c++ std::polar function
def phase(kpos):
arg = -(kpos.x * cenx + kpos.y * ceny)
return jnp.cos(arg) + 1j * jnp.sin(arg)
# flux Exp(-i (kx cx + ky cy) )
kx, ky = kcoords[..., 0], kcoords[..., 1]
arg = -(kx * offset.x + ky * offset.y)
im_phase = jnp.cos(arg) + 1j * jnp.sin(arg)

im_phase = jax.vmap(lambda *args: phase(PositionD(*args)))(
kcoords[..., 0], kcoords[..., 1]
)
return Image(
array=image.array * im_phase,
bounds=image.bounds,
Expand Down
17 changes: 8 additions & 9 deletions jax_galsim/deltafunction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

Expand Down Expand Up @@ -54,16 +53,16 @@ def _max_sb(self):
return DeltaFunction._mock_inf

def _xValue(self, pos):
return jax.lax.cond(
jnp.array(pos.x == 0.0, dtype=bool) & jnp.array(pos.y == 0.0, dtype=bool),
lambda *a: DeltaFunction._mock_inf,
lambda *a: 0.0,
)
return self._xValue_array(pos.x, pos.y)

def _xValue_array(self, x, y):
return jnp.where((x == 0.0) & (y == 0.0), DeltaFunction._mock_inf, 0.0)

def _kValue(self, kpos):
# this is a wasteful and fancy way to get the shape to broadcast to
# to match the input kpos
return self.flux + kpos.x * (0.0 + 0.0j)
return self._kValue_array(kpos.x, kpos.y)

def _kValue_array(self, kx, ky):
return self.flux + kx * (0.0 + 0.0j)

@implements(_galsim.DeltaFunction._shoot)
def _shoot(self, photons, rng):
Expand Down
10 changes: 8 additions & 2 deletions jax_galsim/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,17 @@ def _max_sb(self):
return self._norm

def _xValue(self, pos):
r = jnp.sqrt(pos.x**2 + pos.y**2)
return self._xValue_array(pos.x, pos.y)

def _xValue_array(self, x, y):
r = jnp.sqrt(x**2 + y**2)
return self._norm * jnp.exp(-r * self._inv_r0)

def _kValue(self, kpos):
ksqp1 = (kpos.x**2 + kpos.y**2) * self._r0**2 + 1.0
return self._kValue_array(kpos.x, kpos.y)

def _kValue_array(self, kx, ky):
ksqp1 = (kx**2 + ky**2) * self._r0**2 + 1.0
return self.flux / (ksqp1 * jnp.sqrt(ksqp1))

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
Expand Down
10 changes: 8 additions & 2 deletions jax_galsim/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,17 @@ def _max_sb(self):
return self._norm

def _xValue(self, pos):
rsq = pos.x**2 + pos.y**2
return self._xValue_array(pos.x, pos.y)

def _xValue_array(self, x, y):
rsq = x**2 + y**2
return self._norm * jnp.exp(-0.5 * rsq * self._inv_sigsq)

def _kValue(self, kpos):
ksq = (kpos.x**2 + kpos.y**2) * self._sigsq
return self._kValue_array(kpos.x, kpos.y)

def _kValue_array(self, kx, ky):
ksq = (kx**2 + ky**2) * self._sigsq
return self.flux * jnp.exp(-0.5 * ksq)

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
Expand Down
26 changes: 26 additions & 0 deletions jax_galsim/gsobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,32 @@ def _kValue(self, kpos):
"%s does not implement kValue" % self.__class__.__name__
)

def _kValue_array(self, kx, ky):
"""Evaluate the k-space profile at arrays of kx, ky coordinates.

By default falls back to vmap over _kValue with PositionD. Concrete
profiles should override this with direct array operations for better
performance (avoids per-pixel PositionD construction).
"""
out_shape = jnp.shape(kx)
kx = jnp.atleast_1d(kx).ravel()
ky = jnp.atleast_1d(ky).ravel()
result = jax.vmap(lambda x, y: self._kValue(PositionD(x, y)))(kx, ky)
return result.reshape(out_shape)

def _xValue_array(self, x, y):
"""Evaluate the real-space profile at arrays of x, y coordinates.

By default falls back to vmap over _xValue with PositionD. Concrete
profiles should override this with direct array operations for better
performance (avoids per-pixel PositionD construction).
"""
out_shape = jnp.shape(x)
x = jnp.atleast_1d(x).ravel()
y = jnp.atleast_1d(y).ravel()
result = jax.vmap(lambda xx, yy: self._xValue(PositionD(xx, yy)))(x, y)
return result.reshape(out_shape)

@implements(_galsim.GSObject.withGSParams)
def withGSParams(self, gsparams=None, **kwargs):
if gsparams == self.gsparams:
Expand Down
26 changes: 26 additions & 0 deletions jax_galsim/interpolatedimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,18 @@ def _xValue(self, pos):
self._x_interpolant,
)[0]

def _xValue_array(self, x, y):
return _xValue_arr(
x,
y,
self._offset.x,
self._offset.y,
self._pad_image.bounds.xmin,
self._pad_image.bounds.ymin,
self._pad_image.array,
self._x_interpolant,
)

def _kValue(self, kpos):
kx = jnp.array([kpos.x], dtype=float)
ky = jnp.array([kpos.y], dtype=float)
Expand All @@ -779,6 +791,20 @@ def _kValue(self, kpos):
self._k_interpolant,
)[0]

def _kValue_array(self, kx, ky):
return _kValue_arr(
kx,
ky,
self._offset.x,
self._offset.y,
self._kim.bounds.xmin,
self._kim.bounds.ymin,
self._kim.array,
self._kim.scale,
self._x_interpolant,
self._k_interpolant,
)

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
jacobian = jnp.eye(2) if jac is None else jac

Expand Down
19 changes: 10 additions & 9 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,11 @@ def _has_hard_edges(self):
def _max_sb(self):
return self._norm

@jax.jit
def _xValue(self, pos):
rsq = (pos.x**2 + pos.y**2) * self._inv_r0_sq
# trunc if r>maxR with r0 scaled version
return self._xValue_array(pos.x, pos.y)

def _xValue_array(self, x, y):
rsq = (x**2 + y**2) * self._inv_r0_sq
return jnp.where(
rsq > self._maxRrD_sq, 0.0, self._norm * jnp.power(1.0 + rsq, -self.beta)
)
Expand All @@ -352,14 +353,14 @@ def _kValue_trunc(self, k):
0.0,
)

@jax.jit
def _kValue(self, kpos):
"""computation of the Moffat response in k-space with switch of truncated/untracated case
kpos can be a scalar or a vector (typically, scalar for debug and 2D considering an image)
"""
k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq)
return self._kValue_array(kpos.x, kpos.y)

def _kValue_array(self, kx, ky):
k = jnp.sqrt((kx**2 + ky**2) * self._r0_sq)
out_shape = jnp.shape(k)
k = jnp.atleast_1d(k)
# Flatten to 1D for _hankel's vmap which only maps over axis 0
k = jnp.atleast_1d(k).ravel()
res = jax.lax.cond(
self.trunc > 0,
lambda x: self._kValue_trunc(x),
Expand Down
12 changes: 8 additions & 4 deletions jax_galsim/spergel.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,19 @@ def _max_sb(self):
# from SBSpergelImpl.h
return jnp.abs(self._xnorm) * self._xnorm0

@jax.jit
def _xValue(self, pos):
r = jnp.sqrt(pos.x**2 + pos.y**2) * self._inv_r0
return self._xValue_array(pos.x, pos.y)

def _xValue_array(self, x, y):
r = jnp.sqrt(x**2 + y**2) * self._inv_r0
res = jnp.where(r == 0, self._xnorm0, fz_nu(r, self.nu))
return self._xnorm * res

@jax.jit
def _kValue(self, kpos):
ksq = (kpos.x**2 + kpos.y**2) * self._r0_sq
return self._kValue_array(kpos.x, kpos.y)

def _kValue_array(self, kx, ky):
ksq = (kx**2 + ky**2) * self._r0_sq
return self.flux * jnp.power(1.0 + ksq, -1.0 - self.nu)

def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
Expand Down
Loading