diff --git a/tests/test_config_image.py b/tests/test_config_image.py index 169f1be527..80436a56d7 100644 --- a/tests/test_config_image.py +++ b/tests/test_config_image.py @@ -1367,7 +1367,11 @@ def test_tiled(): gal = galsim.Gaussian(sigma=sigma, flux=flux) gal.drawImage(stamp) stamp.addNoise(galsim.GaussianNoise(sigma=0.5, rng=ud)) - im1a[stamp.bounds] = stamp + if is_jax_galsim(): + # jax-galsim uses the JAX .at API for inplace ops + im1a = im1a.at[stamp.bounds].set(stamp) + else: + im1a[stamp.bounds] = stamp # Compare to what config builds im1b = galsim.config.BuildImage(config) diff --git a/tests/test_image.py b/tests/test_image.py index c80b47ff25..a820d48110 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -235,8 +235,16 @@ def test_Image_basic(): value2 = 53 + 12*x - 19*y if tchar[i] in ['US', 'UI']: value2 = abs(value2) - im1[x,y] = value2 - im2_view[galsim.PositionI(x,y)] = value2 + if is_jax_galsim(): + # jax-galsim uses the JAX .at API for inplace ops + im1 = im1.at[x,y].set(value2) + else: + im1[x,y] = value2 + if is_jax_galsim(): + # jax-galsim uses the JAX .at API for inplace ops + im2_view = im2_view.at[galsim.PositionI(x,y)].set(value2) + else: + im2_view[galsim.PositionI(x,y)] = value2 assert im1.getValue(x,y) == value2 assert im1.view().getValue(x=x, y=y) == value2 assert im1.view(make_const=True).getValue(x,y) == value2 @@ -278,7 +286,11 @@ def test_Image_basic(): else: value3 = 10*x + y im1.addValue(x,y, np.int64(value3-value2)) - im2_view[x,y] += np.int64(value3-value2) + if is_jax_galsim(): + # jax-galsim uses the JAX .at API for inplace ops + im2_view = im2_view.at[x,y].add(np.int64(value3-value2)) + else: + im2_view[x,y] += np.int64(value3-value2) assert im1[galsim.PositionI(x,y)] == value3 assert im1.view()[x,y] == value3 assert im1.view(make_const=True)[galsim.PositionI(x,y)] == value3 @@ -299,11 +311,19 @@ def test_Image_basic(): assert_raises(galsim.GalSimBoundsError,im1.addValue,0,0,1) assert_raises(galsim.GalSimBoundsError,im1.__call__,0,0) assert_raises(galsim.GalSimBoundsError,im1.__getitem__,0,0) - assert_raises(galsim.GalSimBoundsError,im1.__setitem__,0,0,1) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + assert_raises(galsim.GalSimBoundsError,lambda x, y, v: im1.at[x, y].set(v),0,0,1) + else: + assert_raises(galsim.GalSimBoundsError,im1.__setitem__,0,0,1) assert_raises(galsim.GalSimBoundsError,im1.view().setValue,0,0,1) assert_raises(galsim.GalSimBoundsError,im1.view().__call__,0,0) assert_raises(galsim.GalSimBoundsError,im1.view().__getitem__,0,0) - assert_raises(galsim.GalSimBoundsError,im1.view().__setitem__,0,0,1) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + assert_raises(galsim.GalSimBoundsError,lambda x, y, v: im1.view().at[x, y].set(v),0,0,1) + else: + assert_raises(galsim.GalSimBoundsError,im1.view().__setitem__,0,0,1) assert_raises(galsim.GalSimBoundsError,im1.setValue,ncol+1,0,1) assert_raises(galsim.GalSimBoundsError,im1.addValue,ncol+1,0,1) @@ -344,16 +364,29 @@ def test_Image_basic(): galsim.Image(ncol+1,nrow, init_value=10)) assert_raises(galsim.GalSimBoundsError,im1.setSubImage,galsim.BoundsI(0,ncol+1,0,nrow+1), galsim.Image(ncol+2,nrow+2, init_value=10)) - assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(0,ncol,1,nrow), - galsim.Image(ncol+1,nrow, init_value=10)) - assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol,0,nrow), - galsim.Image(ncol+1,nrow, init_value=10)) - assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol+1,1,nrow), - galsim.Image(ncol+1,nrow, init_value=10)) - assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol,1,nrow+1), - galsim.Image(ncol+1,nrow, init_value=10)) - assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(0,ncol+1,0,nrow+1), - galsim.Image(ncol+2,nrow+2, init_value=10)) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(0,ncol,1,nrow), + galsim.Image(ncol+1,nrow, init_value=10)) + assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(1,ncol,0,nrow), + galsim.Image(ncol+1,nrow, init_value=10)) + assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(1,ncol+1,1,nrow), + galsim.Image(ncol+1,nrow, init_value=10)) + assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(1,ncol,1,nrow+1), + galsim.Image(ncol+1,nrow, init_value=10)) + assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(0,ncol+1,0,nrow+1), + galsim.Image(ncol+2,nrow+2, init_value=10)) + else: + assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(0,ncol,1,nrow), + galsim.Image(ncol+1,nrow, init_value=10)) + assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol,0,nrow), + galsim.Image(ncol+1,nrow, init_value=10)) + assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol+1,1,nrow), + galsim.Image(ncol+1,nrow, init_value=10)) + assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol,1,nrow+1), + galsim.Image(ncol+1,nrow, init_value=10)) + assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(0,ncol+1,0,nrow+1), + galsim.Image(ncol+2,nrow+2, init_value=10)) # Also, setting values in something that should be const assert_raises(galsim.GalSimImmutableError,im1.view(make_const=True).setValue,1,1,1) @@ -364,9 +397,17 @@ def test_Image_basic(): # Finally check for the wrong number of arguments in get/setitem assert_raises(TypeError,im1.__getitem__,1) - assert_raises(TypeError,im1.__setitem__,1,1) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + assert_raises(TypeError,lambda b, v: im1.at[b].set(v),1,1) + else: + assert_raises(TypeError,im1.__setitem__,1,1) assert_raises(TypeError,im1.__getitem__,1,2,3) - assert_raises(TypeError,im1.__setitem__,1,2,3,4) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + assert_raises(TypeError,lambda x, y, z, v: im1.at[x, y, z].set(v),1,2,3,4) + else: + assert_raises(TypeError,im1.__setitem__,1,2,3,4) # Check view of given data im3_view = galsim.Image(ref_array.astype(np_array_type)) @@ -519,8 +560,13 @@ def test_undefined_image(): assert_raises(galsim.GalSimUndefinedBoundsError,im1.setSubImage,galsim.BoundsI(1,2,1,2), galsim.Image(2,2, init_value=10)) - assert_raises(galsim.GalSimUndefinedBoundsError,im1.__setitem__,galsim.BoundsI(1,2,1,2), - galsim.Image(2,2, init_value=10)) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + assert_raises(galsim.GalSimUndefinedBoundsError,lambda b,v: im1.at[b].set(v),galsim.BoundsI(1,2,1,2), + galsim.Image(2,2, init_value=10)) + else: + assert_raises(galsim.GalSimUndefinedBoundsError,im1.__setitem__,galsim.BoundsI(1,2,1,2), + galsim.Image(2,2, init_value=10)) im1.scale = 1. assert_raises(galsim.GalSimUndefinedBoundsError,im1.calculate_fft) @@ -2097,7 +2143,11 @@ def test_Image_subImage(): err_msg="image.subImage(bounds) does not match reference for dtype = "+str(types[i])) np.testing.assert_array_equal(image[bounds].array, sub_array, err_msg="image[bounds] does not match reference for dtype = "+str(types[i])) - image[bounds] = galsim.Image(sub_array+100) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].set(galsim.Image(sub_array+100)) + else: + image[bounds] = galsim.Image(sub_array+100) np.testing.assert_array_equal(image[bounds].array, (sub_array+100), err_msg="image[bounds] = im2 does not set correctly for dtype = "+str(types[i])) for xpos in range(1,test_shape[0]+1): @@ -2111,67 +2161,131 @@ def test_Image_subImage(): "image[bounds] = im2 set wrong locations for dtype = "+str(types[i]) image = galsim.Image(ref_array.astype(types[i])) - image[bounds] += 100 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].add(100) + else: + image[bounds] += 100 np.testing.assert_array_equal(image[bounds].array, (sub_array+100), err_msg="image[bounds] += 100 does not set correctly for dtype = "+str(types[i])) - image[bounds] = galsim.Image(sub_array) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].set(galsim.Image(sub_array)) + else: + image[bounds] = galsim.Image(sub_array) np.testing.assert_array_equal(image.array, ref_array, err_msg="image[bounds] += 100 set wrong locations for dtype = "+str(types[i])) image = galsim.Image(ref_array.astype(types[i])) - image[bounds] -= 100 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].subtract(100) + else: + image[bounds] -= 100 np.testing.assert_array_equal(image[bounds].array, (sub_array-100), err_msg="image[bounds] -= 100 does not set correctly for dtype = "+str(types[i])) - image[bounds] = galsim.Image(sub_array) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].set(galsim.Image(sub_array)) + else: + image[bounds] = galsim.Image(sub_array) np.testing.assert_array_equal(image.array, ref_array, err_msg="image[bounds] -= 100 set wrong locations for dtype = "+str(types[i])) image = galsim.Image(ref_array.astype(types[i])) - image[bounds] *= 100 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].multiply(100) + else: + image[bounds] *= 100 np.testing.assert_array_equal(image[bounds].array, (sub_array*100), err_msg="image[bounds] *= 100 does not set correctly for dtype = "+str(types[i])) - image[bounds] = galsim.Image(sub_array) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].set(galsim.Image(sub_array)) + else: + image[bounds] = galsim.Image(sub_array) np.testing.assert_array_equal(image.array, ref_array, err_msg="image[bounds] *= 100 set wrong locations for dtype = "+str(types[i])) image = galsim.Image((100*ref_array).astype(types[i])) - image[bounds] /= 100 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].divide(100) + else: + image[bounds] /= 100 np.testing.assert_array_equal(image[bounds].array, (sub_array), err_msg="image[bounds] /= 100 does not set correctly for dtype = "+str(types[i])) - image[bounds] = galsim.Image((100*sub_array).astype(types[i])) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].set(galsim.Image((100*sub_array).astype(types[i]))) + else: + image[bounds] = galsim.Image((100*sub_array).astype(types[i])) np.testing.assert_array_equal(image.array, (100*ref_array), err_msg="image[bounds] /= 100 set wrong locations for dtype = "+str(types[i])) im2 = galsim.Image(sub_array) image = galsim.Image(ref_array.astype(types[i])) - image[bounds] += im2 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].add(im2) + else: + image[bounds] += im2 np.testing.assert_array_equal(image[bounds].array, (2*sub_array), err_msg="image[bounds] += im2 does not set correctly for dtype = "+str(types[i])) - image[bounds] = galsim.Image(sub_array) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].set(galsim.Image(sub_array)) + else: + image[bounds] = galsim.Image(sub_array) np.testing.assert_array_equal(image.array, ref_array, err_msg="image[bounds] += im2 set wrong locations for dtype = "+str(types[i])) image = galsim.Image(2*ref_array.astype(types[i])) - image[bounds] -= im2 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].subtract(im2) + else: + image[bounds] -= im2 np.testing.assert_array_equal(image[bounds].array, sub_array, err_msg="image[bounds] -= im2 does not set correctly for dtype = "+str(types[i])) - image[bounds] = galsim.Image((2*sub_array).astype(types[i])) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].set(galsim.Image((2*sub_array).astype(types[i]))) + else: + image[bounds] = galsim.Image((2*sub_array).astype(types[i])) np.testing.assert_array_equal(image.array, (2*ref_array), err_msg="image[bounds] -= im2 set wrong locations for dtype = "+str(types[i])) image = galsim.Image(ref_array.astype(types[i])) - image[bounds] *= im2 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].multiply(im2) + else: + image[bounds] *= im2 np.testing.assert_array_equal(image[bounds].array, (sub_array**2), err_msg="image[bounds] *= im2 does not set correctly for dtype = "+str(types[i])) - image[bounds] = galsim.Image(sub_array) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].set(galsim.Image(sub_array)) + else: + image[bounds] = galsim.Image(sub_array) np.testing.assert_array_equal(image.array, ref_array, err_msg="image[bounds] *= im2 set wrong locations for dtype = "+str(types[i])) image = galsim.Image((2 * ref_array**2).astype(types[i])) - image[bounds] /= im2 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].divide(im2) + else: + image[bounds] /= im2 np.testing.assert_array_equal(image[bounds].array, (2*sub_array), err_msg="image[bounds] /= im2 does not set correctly for dtype = "+str(types[i])) - image[bounds] = galsim.Image((2*sub_array**2).astype(types[i])) + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + image = image.at[bounds].set(galsim.Image((2*sub_array**2).astype(types[i]))) + else: + image[bounds] = galsim.Image((2*sub_array**2).astype(types[i])) np.testing.assert_array_equal(image.array, (2*ref_array**2), err_msg="image[bounds] /= im2 set wrong locations for dtype = "+str(types[i])) @@ -2728,7 +2842,11 @@ def test_copy(): assert im10b.wcs == im.wcs assert im10b.bounds == im.bounds np.testing.assert_array_equal(im10b.array, im.array) - im10b[2,3] = 27 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + im10b = im10b.at[2,3].set(27) + else: + im10b[2,3] = 27 assert im10b(2,3) == 27. assert im(2,3) != 27. @@ -2738,7 +2856,11 @@ def test_copy(): assert im5.bounds == im8.bounds np.testing.assert_array_equal(im5.array, im8.array) assert im5(3,8) == 11. - im8[3,8] = 15 + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + im8 = im8.at[3,8].set(15) + else: + im8[3,8] = 15 assert im5(3,8) == 11. assert_raises(TypeError, im5.copyFrom, im8.array) @@ -3429,7 +3551,11 @@ def test_wrap(): for i in range(17): for j in range(23): val = np.exp(i/7.3) + (j/12.9)**3 # Something randomly complicated... - im[i,j] = val + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + im = im.at[i,j].set(val) + else: + im[i,j] = val # Find the location in the sub-image for this point. ii = (i-b.xmin) % (b.xmax-b.xmin+1) + b.xmin jj = (j-b.ymin) % (b.ymax-b.ymin+1) + b.ymin @@ -3463,12 +3589,19 @@ def test_wrap(): # An arbitrary, complicated Hermitian function. val = np.exp((i/(2.3*M))**2 + 1j*(2.8*i-1.3*j)) + ((2 + 3j*j)/(1.9*N))**3 #val = 2*(i-j)**2 + 3j*(i+j) - - im[i,j] = val - if j >= 0: - im2[i,j] = val - if i >= 0: - im3[i,j] = val + if is_jax_galsim(): + # jax-galsim uses .at syntax for setting items + im = im.at[i,j].set(val) + if j >= 0: + im2 = im2.at[i,j].set(val) + if i >= 0: + im3 = im3.at[i,j].set(val) + else: + im[i,j] = val + if j >= 0: + im2[i,j] = val + if i >= 0: + im3[i,j] = val ii = (i-b.xmin) % (b.xmax-b.xmin+1) + b.xmin jj = (j-b.ymin) % (b.ymax-b.ymin+1) + b.ymin