diff --git a/dev/notebooks/moffat_maxk_interp.ipynb b/dev/notebooks/moffat_maxk_interp.ipynb new file mode 100644 index 00000000..457a32c1 --- /dev/null +++ b/dev/notebooks/moffat_maxk_interp.ipynb @@ -0,0 +1,1008 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "9ac27263", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "from functools import partial # noqa: E402\n", + "\n", + "import galsim # noqa: E402\n", + "import jax.numpy as jnp # noqa: E402\n", + "import matplotlib.pyplot as plt # noqa: E402\n", + "import numpy as np # noqa: E402" + ] + }, + { + "cell_type": "markdown", + "id": "92bf2a2e", + "metadata": {}, + "source": [ + "# Fit a Psuedo-Pade Approximation" + ] + }, + { + "cell_type": "markdown", + "id": "940187b9", + "metadata": {}, + "source": [ + "## Define the Approximation and Fitting Range" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3e0d47ce", + "metadata": {}, + "outputs": [], + "source": [ + "# order of rational function in log(maxk_threshold), log(beta)\n", + "PADE_ORDERS = [9, 11]\n", + "\n", + "N_PARAMS_MKTS = PADE_ORDERS[0] * 2 - 1\n", + "N_PARAMS_BETA = PADE_ORDERS[1] * 2 - 1\n", + "N_PARAMS = N_PARAMS_MKTS * N_PARAMS_BETA\n", + "\n", + "LOG_BETA_MIN = np.log(1.1 + 1e-6)\n", + "LOG_BETA_MAX = np.log(100)\n", + "LOG_MKTS_MIN = np.log(1e-12)\n", + "LOG_MKTS_MAX = np.log(0.1)\n", + "\n", + "\n", + "def _pade_func(coeffs, x):\n", + " order = (coeffs.shape[0] - 1) // 2\n", + " p = jnp.polyval(coeffs[:order], x)\n", + " q = jnp.polyval(\n", + " jnp.concatenate([coeffs[order:], jnp.ones(1)], axis=0),\n", + " x,\n", + " )\n", + " return p / q\n", + "\n", + "\n", + "@jax.jit\n", + "@partial(jax.vmap, in_axes=(0, 0, None))\n", + "def _logmaxk_psuedo_pade_approx(log_beta, log_mkts, coeffs):\n", + " log_beta = (log_beta - LOG_BETA_MIN) / (LOG_BETA_MAX - LOG_BETA_MIN)\n", + " log_mkts = (log_mkts - LOG_MKTS_MIN) / (LOG_MKTS_MAX - LOG_MKTS_MIN)\n", + " coeffs = coeffs.reshape(N_PARAMS_MKTS, N_PARAMS_BETA)\n", + " pqvals = jax.vmap(_pade_func, in_axes=(0, None))(coeffs, log_beta)\n", + " return _pade_func(pqvals, log_mkts)" + ] + }, + { + "cell_type": "markdown", + "id": "46e05ca8", + "metadata": {}, + "source": [ + "## Do the Fit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "664de5ff", + "metadata": {}, + "outputs": [], + "source": [ + "n_beta = 50\n", + "n_mkts = 50\n", + "\n", + "\n", + "# this is the function we are interpolating\n", + "def _fun(beta, mkt):\n", + " return galsim.Moffat(\n", + " beta,\n", + " scale_radius=1.0\n", + " ).withGSParams(maxk_threshold=mkt).maxk\n", + "\n", + "\n", + "_betas = np.logspace(np.log10(np.exp(LOG_BETA_MIN)), np.log10(np.exp(LOG_BETA_MAX)), n_beta)\n", + "_mkts = np.logspace(np.log10(np.exp(LOG_MKTS_MIN)), np.log10(np.exp(LOG_MKTS_MAX)), n_mkts)\n", + "\n", + "betas = []\n", + "mkts = []\n", + "maxks = []\n", + "for beta in _betas:\n", + " for mkt in _mkts:\n", + " betas.append(beta)\n", + " mkts.append(mkt)\n", + " maxks.append(\n", + " _fun(beta, mkt)\n", + " )\n", + "betas = jnp.array(betas)\n", + "mkts = jnp.array(mkts)\n", + "maxks = jnp.array(maxks)\n", + "\n", + "\n", + "@jax.jit\n", + "def _loss(coeffs, lnbetas, lnmaxk_thresholds, lnmaxks):\n", + " pvals = _logmaxk_psuedo_pade_approx(lnbetas, lnmaxk_thresholds, coeffs)\n", + " return jnp.mean((pvals - lnmaxks)**2)\n", + "\n", + "\n", + "_vag_loss = jax.jit(jax.value_and_grad(_loss))\n", + "_g_loss = jax.jit(jax.grad(_loss))\n", + "_h_loss = jax.jit(jax.hessian(_loss))\n", + "\n", + "# generate an initial guess\n", + "coeffs = jnp.ones(N_PARAMS) * 1e-6\n", + "\n", + "# args for loss\n", + "lnb = jnp.log(betas)\n", + "lnmkts = jnp.log(mkts)\n", + "lnmaxks = jnp.log(maxks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a57bd84", + "metadata": {}, + "outputs": [], + "source": [ + "import jax.scipy.optimize as jspop\n", + "import optax\n", + "import tqdm\n", + "\n", + "\n", + "def _min_optax(\n", + " fun,\n", + " x0,\n", + " args=None,\n", + " maxiter=100_000,\n", + " learning_rate=1e-1,\n", + " method=\"adan\",\n", + " optimizer=None,\n", + " opt_state=None,\n", + " update_prog_iter=100,\n", + "):\n", + " args = args or tuple()\n", + " _vag_fun = jax.jit(jax.value_and_grad(fun))\n", + "\n", + " if optimizer is None:\n", + " optimizer = getattr(optax, method)(learning_rate)\n", + " opt_state = optimizer.init(x0)\n", + "\n", + " @jax.jit\n", + " def _update_func(coeffs, opt_state):\n", + " loss, grads = _vag_fun(coeffs, *args)\n", + " updates, opt_state = optimizer.update(grads, opt_state, params=coeffs)\n", + " coeffs = optax.apply_updates(coeffs, updates)\n", + " return coeffs, opt_state, loss\n", + "\n", + " loss, _ = _vag_fun(x0, *args)\n", + "\n", + " prev_loss = None\n", + " coeffs = x0\n", + "\n", + " with tqdm.trange(maxiter) as pbar:\n", + " for i in pbar:\n", + " coeffs, opt_state, loss = _update_func(coeffs, opt_state)\n", + "\n", + " if i % update_prog_iter == 0 or i == 0:\n", + " if prev_loss is not None:\n", + " dloss = loss - prev_loss\n", + " else:\n", + " dloss = np.nan\n", + "\n", + " pbar.set_description(f\"{method}: {loss:12.8e} ({dloss:+9.2e} delta)\")\n", + "\n", + " prev_loss = loss\n", + "\n", + " return coeffs, (optimizer, opt_state)\n", + "\n", + "\n", + "def _min_bfgs(\n", + " fun,\n", + " x0,\n", + " args=None,\n", + " maxiter=100,\n", + "):\n", + " args = args or tuple()\n", + "\n", + " coeffs = x0\n", + " prev_loss = None\n", + " tol = 1e-16\n", + " with tqdm.trange(maxiter) as pbar:\n", + " for _ in pbar:\n", + " res = jspop.minimize(\n", + " fun,\n", + " coeffs,\n", + " method=\"BFGS\",\n", + " args=args,\n", + " tol=tol,\n", + " options={\"maxiter\": 10000, \"gtol\": tol, \"line_search_maxiter\": 40},\n", + " )\n", + "\n", + " if np.all(coeffs == res.x):\n", + " coeffs = coeffs * (1.0 + (np.random.uniform(size=coeffs.shape[0]) - 0.5) * 1e-10)\n", + " else:\n", + " coeffs = res.x\n", + "\n", + " if prev_loss is not None:\n", + " dloss = res.fun - prev_loss\n", + " else:\n", + " dloss = np.nan\n", + "\n", + " prev_loss = res.fun\n", + "\n", + " pbar.set_description(\n", + " f\"bfgs: {res.fun:12.8e} ({dloss:+9.2e} delta, status {res.status}, nit {res.nit:6d})\"\n", + " )\n", + "\n", + " return res.x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f60b417d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "initial loss: 1.14907850e+01\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "bfgs: 3.34256538e-09 (-1.36e-10 delta, status 3, nit 1056): 100%|██████████| 100/100 [03:20<00:00, 2.01s/it]\n", + "bfgs: 5.03995873e-10 (-2.06e-12 delta, status 3, nit 884): 100%|██████████| 100/100 [04:04<00:00, 2.45s/it]\n", + "bfgs: 2.26712315e-10 (-1.09e-12 delta, status 3, nit 972): 94%|█████████▍| 94/100 [04:31<00:16, 2.82s/it]" + ] + } + ], + "source": [ + "args = (lnb, lnmkts, lnmaxks)\n", + "\n", + "loss = _loss(coeffs, *args)\n", + "print(f\"initial loss: {loss:12.8e}\", flush=True)\n", + "\n", + "for _ in range(10):\n", + " # coeffs, _ = _min_optax(\n", + " # _loss,\n", + " # coeffs,\n", + " # args=args,\n", + " # learning_rate=1e-4,\n", + " # maxiter=100_000,\n", + " # update_prog_iter=1000,\n", + " # )\n", + " coeffs = _min_bfgs(\n", + " _loss,\n", + " coeffs,\n", + " args=args,\n", + " maxiter=100,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "bbab658c", + "metadata": {}, + "source": [ + "## Print the Coeffs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4e3047c", + "metadata": {}, + "outputs": [], + "source": [ + "import textwrap\n", + "\n", + "pstr = textwrap.indent(\n", + " np.array2string(np.array(coeffs), floatmode=\"unique\", threshold=100000000, separator=\", \", max_line_width=120, sign=\"+\"),\n", + " \" \",\n", + ")\n", + "\n", + "code_str = \"\"\"\\\n", + "# RATIONAL_POLY_VALS is the array of rational function\n", + "# polynomial coefficients that define the approximation\n", + "# fmt: off\n", + "RATIONAL_POLY_VALS = np.array(\n", + "{pstr},\n", + " dtype=np.float64,\n", + ")\n", + "# fmt: on\n" + ] + }, + { + "cell_type": "markdown", + "id": "108951d9", + "metadata": {}, + "source": [ + "## Testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1686f118", + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng()\n", + "n_test = 100000\n", + "\n", + "tbetas = 10**rng.uniform(low=np.log10(1.1 + 1e-6), high=np.log10(100), size=n_test)\n", + "tmaxk_thresholds = 10**rng.uniform(low=-12, high=-1, size=n_test)\n", + "apprx = jnp.exp(_logmaxk_psuedo_pade_approx(jnp.log(tbetas), jnp.log(tmaxk_thresholds), coeffs))\n", + "true = np.array([\n", + " _fun(tbetas[i], tmaxk_thresholds[i])\n", + " for i in range(n_test)\n", + "])\n", + "eps = np.abs(apprx / true - 1)\n", + "c_func = np.max\n", + "eps_label = \"max|approx/true - 1|\"\n", + "print(c_func(eps))\n", + "\n", + "msk = tmaxk_thresholds <= 0.01\n", + "print(c_func(eps[msk]))\n", + "\n", + "# plt.hist(true / apprx - 1, bins=25, log=True)\n", + "# ax = plt.gca()\n", + "# ax.set_xlabel(\"fractional error in maxk approx.\")\n", + "# ax.set_ylabel(\"# of points\")\n", + "\n", + "fig, ax = plt.subplots()\n", + "hb = ax.hexbin(\n", + " np.log10(tbetas),\n", + " np.log10(tmaxk_thresholds),\n", + " C=eps,\n", + " reduce_C_function=c_func,\n", + " extent=(np.log10(1.1), np.log10(100), -12, -1),\n", + " gridsize=50,\n", + " bins=\"log\",\n", + ")\n", + "ax.set_xlim(np.log10(1.1), np.log10(100))\n", + "ax.set_ylim(-12, -1)\n", + "ax.set_xlabel(\"log10(beta)\")\n", + "ax.set_ylabel(\"log10(maxk_threshold)\")\n", + "fig.colorbar(hb, label=eps_label)" + ] + }, + { + "cell_type": "markdown", + "id": "0ea1bd2c", + "metadata": {}, + "source": [ + "# Hacking and Testing Code Below" + ] + }, + { + "cell_type": "markdown", + "id": "96dff6ba", + "metadata": {}, + "source": [ + "## Define Range of Interpolant and Spacings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5b33ede", + "metadata": {}, + "outputs": [], + "source": [ + "beta_min = 1.1 + 1e-6\n", + "beta_max = 100\n", + "n_beta = 100 # used to fit the rational function approx\n", + "mkts_min = 1e-12\n", + "mkts_max = 0.1\n", + "n_mkts = 100 # we build this many rational function approximations\n", + "RATNL_ORDER = 11\n", + "\n", + "betas = np.logspace(np.log10(beta_min), np.log10(beta_max), n_beta)\n", + "mkts = np.logspace(jnp.log10(mkts_min), jnp.log10(mkts_max), n_mkts)\n", + "\n", + "# this is the function we are interpolating\n", + "def _fun(beta, sr, mkt):\n", + " return galsim.Moffat(\n", + " beta,\n", + " scale_radius=sr\n", + " ).withGSParams(maxk_threshold=mkt).maxk" + ] + }, + { + "cell_type": "markdown", + "id": "e54bf87a", + "metadata": {}, + "source": [ + "## Build rational function apprx. in beta at fixed maxk_threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c3d3535", + "metadata": {}, + "outputs": [], + "source": [ + "import scipy.optimize\n", + "import numpy.polynomial\n", + "from numpy.polynomial import Polynomial\n", + "import jax_galsim.core.interpolate\n", + "\n", + "numpy.polynomial.set_default_printstyle(\"ascii\")\n", + "\n", + "\n", + "def get_ratnl_func_polys(coeff):\n", + " p_coeff = coeff[0 : RATNL_ORDER + 1]\n", + " q_coeff = np.concatenate([[1], coeff[RATNL_ORDER + 1 :]])\n", + " pm = Polynomial(p_coeff)\n", + " qm = Polynomial(q_coeff)\n", + " return pm, qm\n", + "\n", + "\n", + "def get_ratnl_func_coeffs(coeff):\n", + " p_coeff = coeff[0 : RATNL_ORDER + 1]\n", + " q_coeff = np.concatenate([[1], coeff[RATNL_ORDER + 1 :]])\n", + " return p_coeff, q_coeff\n", + "\n", + "\n", + "def ratnl_func(x, *coeff):\n", + " pm, qm = get_ratnl_func_polys(coeff)\n", + " return pm(x) / qm(x)\n", + "\n", + "\n", + "def make_poly_code(pm, head=\"\", base_indent=0):\n", + " res = \"\"\n", + " indent = base_indent\n", + " for c in pm.coef:\n", + " if c == pm.coef[-1]:\n", + " end = \"\"\n", + " else:\n", + " end = \" + x * (\"\n", + "\n", + " if c == pm.coef[0]:\n", + " _hd = head\n", + " else:\n", + " _hd = \"\"\n", + " res += \" \" * 4 * indent + f\"{_hd}{c}{end}\\n\"\n", + " if c != pm.coef[-1]:\n", + " indent += 1\n", + "\n", + " for _ in pm.coef[:-1]:\n", + " indent -= 1\n", + " res += \" \" * 4 * indent + \")\\n\"\n", + "\n", + " return res" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be97104b", + "metadata": {}, + "outputs": [], + "source": [ + "eps = 1e-6\n", + "\n", + "poly_res = []\n", + "\n", + "for mkt in mkts:\n", + " vals = np.array([_fun(beta, 1.0, mkt) for beta in betas])\n", + "\n", + " res = scipy.optimize.curve_fit(\n", + " ratnl_func,\n", + " np.log(betas),\n", + " np.log(vals),\n", + " p0=np.ones(2 * RATNL_ORDER + 1),\n", + " full_output=True,\n", + " maxfev=100000,\n", + " ftol=eps,\n", + " xtol=eps,\n", + " )\n", + "\n", + " coeff = res[0]\n", + "\n", + " pm, qm = get_ratnl_func_coeffs(coeff)\n", + "\n", + " poly_res.append((pm[::-1], qm[::-1]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e12b5c4f", + "metadata": {}, + "outputs": [], + "source": [ + "import textwrap\n", + "\n", + "pstr = textwrap.indent(\n", + " np.array2string(np.array(poly_res), floatmode=\"unique\", threshold=100000000, separator=\", \", max_line_width=120, sign=\"+\"),\n", + " \" \",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7da4a8c5", + "metadata": {}, + "outputs": [], + "source": [ + "code = f\"\"\"\\\n", + "# START OF GENERATED CODE\n", + "# The code in this block is generated by the notebook dev/notebooks/moffat_maxk_interp.ipynb.\n", + "\n", + "MKTS_MIN = {mkts_min}\n", + "MKTS_MAX = {mkts_max}\n", + "N_MKTS = {n_mkts}\n", + "LOG_MKTS = np.log(np.logspace(jnp.log10(MKTS_MIN), jnp.log10(MKTS_MAX), N_MKTS))\n", + "\n", + "# RATIONAL_POLY_VALS is an array of 7-th order ration function approximations\n", + "# for maxk as a function of log(beta) at fixed maxk_threshold values. the coeffs\n", + "# are stored from highest degree to lowest. The shape of the array is\n", + "# ({n_mkts}, 2, {RATNL_ORDER + 1}).\n", + "# fmt: off\n", + "RATIONAL_POLY_VALS = np.array(\n", + "{pstr},\n", + " dtype=np.float64,\n", + ")\n", + "# fmt: on\n", + "\n", + "\n", + "@jax.jit\n", + "def _moffat_maxk(beta, maxk_threshold, r0):\n", + " log_beta = jnp.log(beta)\n", + " log_maxk_threshold = jnp.log(maxk_threshold)\n", + " maxk_vals = jnp.array(\n", + " [\n", + " jnp.exp(\n", + " jnp.polyval(RATIONAL_POLY_VALS[i, 0, :], log_beta)\n", + " / jnp.polyval(RATIONAL_POLY_VALS[i, 1, :], log_beta)\n", + " )\n", + " for i in range(N_MKTS)\n", + " ]\n", + " )\n", + " coeffs = akima_interp_coeffs(LOG_MKTS, maxk_vals)\n", + " return akima_interp(log_maxk_threshold, LOG_MKTS, maxk_vals, coeffs) / r0\n", + "\n", + "\n", + "# END OF GENERATED CODE\n", + "\"\"\"\n", + "\n", + "print(code)" + ] + }, + { + "cell_type": "markdown", + "id": "c252b268", + "metadata": {}, + "source": [ + "## Testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6072c527", + "metadata": {}, + "outputs": [], + "source": [ + "from jax_galsim.core.interpolate import akima_interp, akima_interp_coeffs\n", + "\n", + "exec(code)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e87bba5", + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng(seed=10)\n", + "n_test = 100000\n", + "\n", + "betas = rng.uniform(low=1.1 + 1e-6, high=100, size=n_test)\n", + "maxk_thresholds = 10**rng.uniform(low=-12, high=-1, size=n_test)\n", + "apprx = np.array([\n", + " _moffat_maxk(betas[i], maxk_thresholds[i], 1.0)\n", + " for i in range(n_test)\n", + "])\n", + "true = np.array([\n", + " _fun(betas[i], 1.0, maxk_thresholds[i])\n", + " for i in range(n_test)\n", + "])\n", + "\n", + "# plt.hist(true / apprx - 1, bins=25, log=True)\n", + "# ax = plt.gca()\n", + "# ax.set_xlabel(\"fractional error in maxk approx.\")\n", + "# ax.set_ylabel(\"# of points\")\n", + "\n", + "fig, ax = plt.subplots()\n", + "hb = ax.hexbin(\n", + " betas,\n", + " np.log10(maxk_thresholds),\n", + " C=np.log10(np.abs(apprx-true)),\n", + " extent=(1.1, 100, -12, -1),\n", + " gridsize=50,\n", + " vmin=-7,\n", + " vmax=-2\n", + ")\n", + "ax.set_xlim(1.1, 100)\n", + "ax.set_ylim(-12, -1)\n", + "ax.set_xlabel(\"beta\")\n", + "ax.set_ylabel(\"log10(maxk_threshold)\")\n", + "fig.colorbar(hb, label=\"log10(|approx - true|)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b396f567", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "82279445", + "metadata": {}, + "source": [ + "## Code to Minimize w/ SGD" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df2c1691", + "metadata": {}, + "outputs": [], + "source": [ + "import optax\n", + "import tqdm\n", + "\n", + "optimizer = None\n", + "learning_rate = 1e-5\n", + "opt_attr = \"adan\"\n", + "\n", + "if optimizer is None:\n", + " optimizer = getattr(optax, opt_attr)(learning_rate)\n", + " opt_state = optimizer.init(coeffs)\n", + "\n", + " @jax.jit\n", + " def _update_func(coeffs, opt_state):\n", + " loss, grads = _vag_loss(coeffs, lnb, lnmkts, lnmaxks)\n", + " updates, opt_state = optimizer.update(grads, opt_state, params=coeffs)\n", + " coeffs = optax.apply_updates(coeffs, updates)\n", + " return coeffs, opt_state, loss\n", + "\n", + "lnb = jnp.log(betas)\n", + "lnmkts = jnp.log(mkts)\n", + "lnmaxks = jnp.log(maxks)\n", + "\n", + "loss, _ = _vag_loss(coeffs, lnb, lnmkts, lnmaxks)\n", + "print(\"initial loss:\", jnp.power(loss, 1.0 / lval), flush=True)\n", + "\n", + "prev_loss = None\n", + "n_epoch = 200_000\n", + "ditr = 1000\n", + "\n", + "with tqdm.trange(n_epoch) as pbar:\n", + " for i in pbar:\n", + " coeffs, opt_state, loss = _update_func(coeffs, opt_state)\n", + "\n", + " if i % ditr == 0:\n", + " if prev_loss is not None:\n", + " dloss = (jnp.power(loss, 1 / lval) - jnp.power(prev_loss, 1 / lval))\n", + " pbar.set_description(f\"loss: {jnp.power(loss, 1 / lval):10.4e} ({dloss:+9.2e} delta)\")\n", + " else:\n", + " pbar.set_description(f\"loss: {jnp.power(loss, 1 / lval):10.4e} (--------- delta)\")\n", + "\n", + " prev_loss = loss\n", + "\n", + "print(f\"{i:04d}: {jnp.power(loss, 1 / lval):10.4e}\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "id": "2f152994", + "metadata": {}, + "source": [ + "## Interpolation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58a1ba12", + "metadata": {}, + "outputs": [], + "source": [ + "# this is the function we are interpolating\n", + "def _fun(beta, mkt):\n", + " return galsim.Moffat(\n", + " beta,\n", + " scale_radius=1.0\n", + " ).withGSParams(maxk_threshold=mkt).maxk\n", + "\n", + "\n", + "beta_min = 1.1 + 1e-6\n", + "beta_max = 100\n", + "n_beta = 500 # used to fit the rational function approx\n", + "mkts_min = 1e-12\n", + "mkts_max = 0.1\n", + "n_mkts = 200 # we build this many rational function approximations\n", + "\n", + "_betas = np.logspace(np.log10(beta_min), np.log10(beta_max), n_beta)\n", + "_mkts = np.logspace(jnp.log10(mkts_min), jnp.log10(mkts_max), n_mkts)\n", + "\n", + "betas = []\n", + "mkts = []\n", + "maxks = []\n", + "for beta in _betas:\n", + " for mkt in _mkts:\n", + " betas.append(beta)\n", + " mkts.append(mkt)\n", + " maxks.append(\n", + " _fun(beta, mkt)\n", + " )\n", + "betas = jnp.array(betas)\n", + "mkts = jnp.array(mkts)\n", + "maxks = jnp.array(maxks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c48c5ffb", + "metadata": {}, + "outputs": [], + "source": [ + "import interpax\n", + "\n", + "rng = np.random.default_rng()\n", + "n_test = 100000\n", + "\n", + "tbetas = 10**rng.uniform(low=np.log10(1.1 + 1e-6), high=np.log10(100), size=n_test)\n", + "tmaxk_thresholds = 10**rng.uniform(low=-12, high=-1, size=n_test)\n", + "apprx = jnp.exp(\n", + " interpax.interp2d(\n", + " jnp.log(tbetas),\n", + " jnp.log(tmaxk_thresholds),\n", + " jnp.log(_betas),\n", + " jnp.log(_mkts),\n", + " jnp.log(maxks).reshape(n_beta, n_mkts),\n", + " method=\"akima\"\n", + " )\n", + ")\n", + "true = np.array([\n", + " _fun(tbetas[i], tmaxk_thresholds[i])\n", + " for i in range(n_test)\n", + "])\n", + "eps = np.abs(apprx / true - 1)\n", + "c_func = np.max\n", + "eps_label = \"max|approx/true - 1|\"\n", + "print(c_func(eps))\n", + "\n", + "msk = tmaxk_thresholds <= 0.03\n", + "print(c_func(eps[msk]))\n", + "# plt.hist(true / apprx - 1, bins=25, log=True)\n", + "# ax = plt.gca()\n", + "# ax.set_xlabel(\"fractional error in maxk approx.\")\n", + "# ax.set_ylabel(\"# of points\")\n", + "\n", + "fig, ax = plt.subplots()\n", + "hb = ax.hexbin(\n", + " np.log10(tbetas),\n", + " np.log10(tmaxk_thresholds),\n", + " C=eps,\n", + " reduce_C_function=c_func,\n", + " extent=(np.log10(1.1), np.log10(100), -12, -1),\n", + " gridsize=50,\n", + " bins=\"log\",\n", + ")\n", + "ax.set_xlim(np.log10(1.1), np.log10(100))\n", + "ax.set_ylim(-12, -1)\n", + "ax.set_xlabel(\"log10(beta)\")\n", + "ax.set_ylabel(\"log10(maxk_threshold)\")\n", + "fig.colorbar(hb, label=eps_label)" + ] + }, + { + "cell_type": "markdown", + "id": "618a5c3e", + "metadata": {}, + "source": [ + "## Symbolic Regression" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91fec48d", + "metadata": {}, + "outputs": [], + "source": [ + "# this is the function we are interpolating\n", + "def _fun(beta, mkt):\n", + " return galsim.Moffat(\n", + " beta,\n", + " scale_radius=1.0\n", + " ).withGSParams(maxk_threshold=mkt).maxk\n", + "\n", + "\n", + "beta_min = 1.1 + 1e-6\n", + "beta_max = 100\n", + "n_beta = 50 # used to fit the rational function approx\n", + "mkts_min = 1e-12\n", + "mkts_max = 0.1\n", + "n_mkts = 50 # we build this many rational function approximations\n", + "\n", + "_betas = np.logspace(np.log10(beta_min), np.log10(beta_max), n_beta)\n", + "_mkts = np.logspace(jnp.log10(mkts_min), jnp.log10(mkts_max), n_mkts)\n", + "\n", + "betas = []\n", + "mkts = []\n", + "maxks = []\n", + "for beta in _betas:\n", + " for mkt in _mkts:\n", + " betas.append(beta)\n", + " mkts.append(mkt)\n", + " maxks.append(\n", + " _fun(beta, mkt)\n", + " )\n", + "betas = jnp.array(betas)\n", + "mkts = jnp.array(mkts)\n", + "maxks = jnp.array(maxks)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e5298f2", + "metadata": {}, + "outputs": [], + "source": [ + "from pysr import PySRRegressor\n", + "\n", + "X = np.stack([np.log(betas), np.log(mkts)], axis=1)\n", + "y = np.log(maxks)\n", + "\n", + "model = PySRRegressor(\n", + " maxsize=50,\n", + " niterations=100,\n", + " binary_operators=[\"+\", \"*\", \"/\", \"^\"],\n", + " constraints={'^': (-1, 1)},\n", + " elementwise_loss=\"loss(prediction, target) = abs(prediction - target)\",\n", + " model_selection='accuracy',\n", + ")\n", + "\n", + "model.fit(X, y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdbecb27", + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng()\n", + "n_test = 100000\n", + "\n", + "tbetas = 10**rng.uniform(low=np.log10(1.1 + 1e-6), high=np.log10(100), size=n_test)\n", + "tmaxk_thresholds = 10**rng.uniform(low=-12, high=-1, size=n_test)\n", + "tX = np.stack([np.log(tbetas), np.log(tmaxk_thresholds)], axis=1)\n", + "\n", + "apprx = np.exp(model.predict(tX))\n", + "true = np.array([\n", + " _fun(tbetas[i], tmaxk_thresholds[i])\n", + " for i in range(n_test)\n", + "])\n", + "eps = np.abs(apprx / true - 1)\n", + "c_func = np.max\n", + "eps_label = \"max|approx/true - 1|\"\n", + "print(c_func(eps))\n", + "\n", + "fig, ax = plt.subplots()\n", + "hb = ax.hexbin(\n", + " np.log10(tbetas),\n", + " np.log10(tmaxk_thresholds),\n", + " C=eps,\n", + " reduce_C_function=c_func,\n", + " extent=(np.log10(1.1), np.log10(100), -12, -1),\n", + " gridsize=50,\n", + " bins=\"log\",\n", + ")\n", + "ax.set_xlim(np.log10(1.1), np.log10(100))\n", + "ax.set_ylim(-12, -1)\n", + "ax.set_xlabel(\"log10(beta)\")\n", + "ax.set_ylabel(\"log10(maxk_threshold)\")\n", + "fig.colorbar(hb, label=eps_label)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b549f6cd", + "metadata": {}, + "outputs": [], + "source": [ + "x = jnp.log(_betas * 1.01)\n", + "y = jnp.ones_like(x) * jnp.log(1e-4)\n", + "true = jnp.array([\n", + " _fun(jnp.exp(x[i]), jnp.exp(y[i]))\n", + " for i in range(x.shape[0])\n", + "])\n", + "\n", + "approx = (\n", + " interpax.interp2d(\n", + " x,\n", + " y,\n", + " jnp.log(_betas),\n", + " jnp.log(_mkts),\n", + " jnp.log(maxks).reshape(n_beta, n_mkts),\n", + " method=\"akima\"\n", + " )\n", + ")\n", + "\n", + "plt.plot(\n", + " x,\n", + " approx,\n", + ")\n", + "plt.plot(\n", + " x,\n", + " jnp.log(true),\n", + ")\n", + "\n", + "# plt.plot(x, jnp.log(true))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ddef78ae", + "metadata": {}, + "outputs": [], + "source": [ + "interpax.interp2d?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb2be7de", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jax-galsim", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index fe596398..04a94248 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -1,13 +1,18 @@ +from functools import partial + import galsim as _galsim import jax import jax.numpy as jnp -from jax.tree_util import Partial as partial +import numpy as np from jax.tree_util import register_pytree_node_class -from jax_galsim.bessel import j0, kv +from jax_galsim.bessel import kv from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue -from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral -from jax_galsim.core.utils import bisect_for_root, ensure_hashable, implements +from jax_galsim.core.utils import ( + ensure_hashable, + has_tracers, + implements, +) from jax_galsim.gsobject import GSObject from jax_galsim.position import PositionD from jax_galsim.random import UniformDeviate @@ -19,65 +24,21 @@ def _Knu(nu, x): return kv(nu, x) -@jax.jit -def _MoffatIntegrant(x, k, beta): - """For truncated Hankel used in truncated Moffat""" - return x * jnp.power(1 + x**2, -beta) * j0(k * x) - - -def _xMoffatIntegrant(k, beta, rmax, quad): - return quad_integral(partial(_MoffatIntegrant, k=k, beta=beta), 0.0, rmax, quad) - +@implements( + _galsim.Moffat, + lax_description="""\ +The LAX version of the Moffat profile -@jax.jit -def _hankel(k, beta, rmax): - quad = ClenshawCurtisQuad.init(150) - g = partial(_xMoffatIntegrant, beta=beta, rmax=rmax, quad=quad) - return jax.vmap(g)(k) - - -@jax.jit -def _bodymi(xcur, rm, re, beta): - x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2 - x = jnp.power(x, 1 / (1 - beta)) - x = jnp.sqrt(x - 1) - return re / x - - -@partial(jax.jit, static_argnames=("nitr",)) -def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=100): - """ - The basic equation that is relevant here is the flux of a Moffat profile - out to some radius. - - flux(R) = int( (1+r^2/rd^2 )^(-beta) 2pi r dr, r=0..R ) - = (pi rd^2 / (beta-1)) (1 - (1+R^2/rd^2)^(1-beta) ) - For now, we can ignore the first factor. We call the second factor fluxfactor below, - or in this function f(R). - We are given two values of R for which we know that the ratio of their fluxes is 1/2: - f(re) = 0.5 * f(rm) - - nb1. rd aka r0 aka the scale radius - nb2. In GalSim definition rm = 0 (ex. no truncated Moffat) means in reality rm=+Inf. - BUT the case rm==0 is already done, so HERE rm != 0 - """ - - # fix loop iteration is faster and reach eps=1e-6 (single precision) - def body(i, xcur): - x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2 - x = jnp.power(x, 1 / (1 - beta)) - x = jnp.sqrt(x - 1) - return re / x - - return jax.lax.fori_loop(0, 100, body, re, unroll=True) - - -@implements(_galsim.Moffat) +- does not support truncation or beta < 1.1 +- does not support gsparams.maxk_thresholds > 0.1 +""", +) @register_pytree_node_class class Moffat(GSObject): _is_axisymmetric = True _is_analytic_x = True _is_analytic_k = True + _has_hard_edges = False def __init__( self, @@ -93,6 +54,15 @@ def __init__( # let define beta_thr a threshold to trigger the truncature self._beta_thr = 1.1 + if has_tracers(trunc) or ( + isinstance(trunc, (np.ndarray, float, jnp.ndarray, int)) + and np.any(trunc != 0) + ): + raise ValueError( + "JAX-GalSim does not support truncated Moffat profiles " + f"(got trunc={repr(trunc)}, always pass the constant 0.0)!" + ) + # Parse the radius options if half_light_radius is not None: if scale_radius is not None or fwhm is not None: @@ -106,14 +76,9 @@ def __init__( super().__init__( beta=beta, scale_radius=( - jax.lax.select( - trunc > 0, - _MoffatCalculateSRFromHLR(half_light_radius, trunc, beta), - half_light_radius - / jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0), - ) + half_light_radius + / jnp.sqrt(jnp.power(0.5, 1.0 / (1.0 - beta)) - 1.0) ), - trunc=trunc, flux=flux, gsparams=gsparams, ) @@ -129,7 +94,6 @@ def __init__( super().__init__( beta=beta, scale_radius=fwhm / (2.0 * jnp.sqrt(2.0 ** (1.0 / beta) - 1.0)), - trunc=trunc, flux=flux, gsparams=gsparams, ) @@ -144,11 +108,15 @@ def __init__( super().__init__( beta=beta, scale_radius=scale_radius, - trunc=trunc, flux=flux, gsparams=gsparams, ) + if self.gsparams.maxk_threshold > 0.1: + raise ValueError( + "JAX-GalSim Moffat profiles do not support gsparams.maxk_threshold values greater than 0.1!" + ) + @property @implements(_galsim.moffat.Moffat.beta) def beta(self): @@ -157,7 +125,7 @@ def beta(self): @property @implements(_galsim.moffat.Moffat.trunc) def trunc(self): - return self._params["trunc"] + return 0.0 @property @implements(_galsim.moffat.Moffat.scale_radius) @@ -183,12 +151,8 @@ def _inv_r0_sq(self): @property def _maxRrD(self): """maxR/rd ; fluxFactor Integral of total flux in terms of 'rD' units.""" - return jax.lax.select( - self.trunc > 0.0, - self.trunc * self._inv_r0, - jnp.sqrt( - jnp.power(self.gsparams.xvalue_accuracy, 1.0 / (1.0 - self.beta)) - 1.0 - ), + return jnp.sqrt( + jnp.power(self.gsparams.xvalue_accuracy, 1.0 / (1.0 - self.beta)) - 1.0 ) @property @@ -202,11 +166,7 @@ def _maxRrD_sq(self): @property def _fluxFactor(self): - return jax.lax.select( - self.trunc > 0.0, - 1.0 - jnp.power(1 + self._maxRrD * self._maxRrD, (1.0 - self.beta)), - 1.0, - ) + return 1.0 @property @implements(_galsim.moffat.Moffat.half_light_radius) @@ -288,7 +248,16 @@ def _maxk_func(self, k): @property @jax.jit def _maxk(self): - return bisect_for_root(partial(self._maxk_func), 0.0, 1e5, niter=75) + return ( + jnp.exp( + _logmaxk_psuedo_pade_approx( + jnp.atleast_1d(jnp.log(self.beta)), + jnp.atleast_1d(jnp.log(self.gsparams.maxk_threshold)), + RATIONAL_POLY_VALS, + ) + )[0] + / self._r0 + ) @property def _stepk_lowbeta(self): @@ -320,10 +289,6 @@ def _stepk(self): self.beta <= self._beta_thr, self._stepk_lowbeta, self._stepk_highbeta ) - @property - def _has_hard_edges(self): - return self.trunc != 0.0 - @property def _max_sb(self): return self._norm @@ -344,14 +309,6 @@ def _kValue_untrunc(self, k): self._knorm, ) - def _kValue_trunc(self, k): - """Truncated version of _kValue""" - return jnp.where( - k <= 50.0, - self._knorm * self._prefactor * _hankel(k, self.beta, self._maxRrD), - 0.0, - ) - @jax.jit def _kValue(self, kpos): """computation of the Moffat response in k-space with switch of truncated/untracated case @@ -360,12 +317,7 @@ def _kValue(self, kpos): k = jnp.sqrt((kpos.x**2 + kpos.y**2) * self._r0_sq) out_shape = jnp.shape(k) k = jnp.atleast_1d(k) - res = jax.lax.cond( - self.trunc > 0, - lambda x: self._kValue_trunc(x), - lambda x: self._kValue_untrunc(x), - k, - ) + res = self._kValue_untrunc(k) return res.reshape(out_shape) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): @@ -405,3 +357,137 @@ def _shoot(self, photons, rng): photons.x = r * cost photons.y = r * sint photons.flux = self.flux / photons.size() + + +# order of rational function in log(maxk_threshold), log(beta) +PADE_ORDERS = [9, 11] + +N_PARAMS_MKTS = PADE_ORDERS[0] * 2 - 1 +N_PARAMS_BETA = PADE_ORDERS[1] * 2 - 1 +N_PARAMS = N_PARAMS_MKTS * N_PARAMS_BETA + +LOG_BETA_MIN = np.log(1.1 + 1e-6) +LOG_BETA_MAX = np.log(100) +LOG_MKTS_MIN = np.log(1e-12) +LOG_MKTS_MAX = np.log(0.1) + + +def _pade_func(coeffs, x): + order = (coeffs.shape[0] - 1) // 2 + p = jnp.polyval(coeffs[:order], x) + q = jnp.polyval( + jnp.concatenate([coeffs[order:], jnp.ones(1)], axis=0), + x, + ) + return p / q + + +@jax.jit +@partial(jax.vmap, in_axes=(0, 0, None)) +def _logmaxk_psuedo_pade_approx(log_beta, log_mkts, coeffs): + log_beta = (log_beta - LOG_BETA_MIN) / (LOG_BETA_MAX - LOG_BETA_MIN) + log_mkts = (log_mkts - LOG_MKTS_MIN) / (LOG_MKTS_MAX - LOG_MKTS_MIN) + coeffs = coeffs.reshape(N_PARAMS_MKTS, N_PARAMS_BETA) + pqvals = jax.vmap(_pade_func, in_axes=(0, None))(coeffs, log_beta) + return _pade_func(pqvals, log_mkts) + + +# START OF GENERATED CODE +# RATIONAL_POLY_VALS is the array of rational function +# polynomial coefficients that define the approximation +# fmt: off +RATIONAL_POLY_VALS = np.array( + [+4.0377541235164999e-01, +9.8573979309710097e-02, -8.8368998636191423e-02, -1.4404058874465467e-01, + -1.8722517103965541e-01, -2.3941575929900452e-01, +1.9477051520522798e-01, +2.5174893659382911e+00, + +6.9802569884628065e+00, +2.9528987005934546e+00, -9.1832169346703629e-01, +4.9286238397646115e-01, + +1.0005636301164393e+00, +7.0392335018807339e-01, -1.4054536940247431e-01, -8.5218622931551169e-01, + -6.7621128905401928e-01, -2.9537613003541291e-01, -1.2854667245219107e+00, +4.0189909948806379e+00, + +2.1850570724764290e-01, -4.2274342642823717e-02, -2.2450115304011090e-01, -3.6887180044787632e-01, + -4.3603364254842064e-01, -4.9256905759091729e-01, -6.6398873219847576e-01, -5.9558712629992638e-01, + +1.1837909921308221e+00, -4.6138529248538136e+00, +1.3450469324602885e+00, +4.9458187528754460e-01, + +6.0273293491308400e-01, +6.3962989463396580e-01, +6.1582284694766809e-01, +5.6781212563865269e-01, + +5.5125443702360621e-01, +5.9619266285882933e-01, +5.4745878470377551e-01, -1.2351160388207373e-01, + -7.0107993183023398e-01, +9.5935634414374444e+00, -8.7283833589376003e-01, -1.0255510475210847e+00, + -1.0929211542319643e+00, -9.1020529616651413e-01, -6.0023870397444312e-01, -6.5507195560618903e-01, + -1.0722148851554705e+00, +7.3885075419617319e-01, -4.0294754110685673e+00, -7.8297431020829418e+00, + -6.6474640833734255e-01, +7.7380626162435606e-01, +6.7528838101327693e-01, -1.7804564435440101e-01, + -6.2398848498466120e-01, +5.7643808537703685e-02, +1.1923835092489283e+00, +4.9233103375211917e-01, + -2.9426949991492894e+00, +2.6998628292637314e+00, -1.4093483711909682e-01, +4.0372810590505115e-01, + +2.4647318964152784e-01, +1.1417136722445211e-02, -2.9679657567820844e-01, -6.8704346690711138e-01, + -1.2937569996243186e+00, -1.7597197158870368e+00, +1.1143935878967266e+00, -3.5782107847819544e+00, + +4.1199620228132250e-01, +1.1970601681985499e-01, +2.0114099603243733e-02, -6.5434136381943390e-03, + +8.5736126942115937e-02, +2.7261328153414083e-01, +4.8960392670307473e-01, +9.7834457064666291e-01, + +2.7168180915113544e+00, +5.8280454184534474e+00, +4.3537429070024833e+00, -3.6175915101171152e-01, + +3.4258789295460745e-01, +1.5673518908599102e-01, +1.0666667233357530e-01, +2.7402242443574487e-01, + +5.7450989722739154e-01, +6.5865111299117973e-01, +1.0619471254256168e+00, +5.0281107124390561e+00, + +4.3769392931878642e-01, +1.3797757705398774e+00, -4.0642060782490819e-01, -4.8107106379014103e-01, + -4.6545384216554275e-01, -3.2668532926704019e-01, +1.8685658280818033e-02, +7.3138884860456699e-01, + +2.0590374557083453e+00, +4.2319105687511795e+00, +6.3443456809823573e+00, +3.1050925854801519e+00, + +2.4566683840045755e-01, -1.4272027049584994e+00, -1.2769555499840839e+00, -1.1702505921157993e+00, + -1.0897538422282065e+00, -9.9570783824994358e-01, -7.6663044360462829e-01, +2.4280904621074063e-02, + -6.2825977906654862e-01, +4.7748459315512886e+00, -3.4479672270932647e-02, +1.5810745647692885e+00, + +1.0592324253022773e+00, +6.5192719848377234e-01, +3.4733038965947277e-01, +1.3710269537809816e-01, + +1.7457176684719813e-02, -1.0579281455920138e-02, +1.1577313256892235e-01, +6.3498619537927703e-01, + -3.5636713339300936e+00, +2.7473701713859548e+01, -8.1471628900480619e-01, -6.2163518452759314e-01, + -6.3079895204862313e-01, -8.7260751681315663e-01, -7.9737042513719381e-01, +4.1369737806800105e-01, + +2.2584145837329177e+00, +6.9141061101150347e-01, +1.2238567163847529e+01, -1.4881240397019797e+00, + +2.5346008210149407e+00, +3.4144215256062256e+00, +3.9404077941580811e+00, +4.0697631657277480e+00, + +4.0020162458839401e+00, +3.6956516567833777e+00, +1.5471753010948357e+00, -4.5096733243409268e+00, + +6.2197356650587894e+00, +1.3720454142006250e+01, +7.1789808873106802e+00, +1.2878469385318689e+00, + +2.5159257373932378e+00, +1.9825772948334588e+00, -2.0124795293398445e-01, +7.3072182794828144e-01, + +4.3788775079638773e+00, +2.4851710519494072e+00, +3.9571444680295214e+00, +7.9890557503685585e+01, + +3.2163012590954909e+00, -1.3055299196086150e+00, +2.4047152551806925e+00, +1.1869982622608657e+00, + -2.6782591519389305e+00, -2.1351447178930658e+00, +4.7064345711296918e+00, +3.7186007845229394e+00, + -6.2948721775023833e+00, +1.9222396464692157e+00, -5.2929135394148226e+00, +2.4064161945473813e+01, + +4.6699620000159325e-01, -2.8824483958610392e-01, -5.2987053311717192e-01, -2.5266220838257281e-01, + +3.3359802997166604e-01, +3.4368754731009948e-01, -1.7259824215429349e+00, +1.1875712562662768e+00, + -6.4125482087187963e-03, -3.0986991304825956e-01, +1.7012128546438319e+00, +1.4041294008838343e+00, + +1.1944956562005928e+00, +1.0382744098886449e+00, +9.6541559432919521e-01, +1.0484758969468237e+00, + +1.1394063481569940e+00, +3.4269611648067827e-01, -1.9133248804297314e+00, +1.0420905503308806e+01, + +2.1179961925447843e+00, -1.1710420652576292e-01, +1.4435488943498453e-01, +3.1217816268483334e-01, + +4.9427150925041208e-01, +9.8917710459391761e-01, +1.6007479958753712e+00, +1.7741348266216928e-01, + -2.3968752594096783e+00, +2.7544604324111326e+00, +1.3892657174473839e+00, +8.7578407666511837e-01, + -3.5073692239590564e-03, -5.2224301780733773e-01, -3.6262978387803235e-01, +4.3774916805850367e-01, + +1.2253499947430535e+00, +1.1500998128760624e+00, +1.1268239524120995e+00, +5.8797742250985641e+00, + -6.6184466647212081e-01, +5.6721441744825265e+00, -5.6157472743129566e-01, -4.3196098885679984e-01, + -3.7715669435929966e-01, -2.5883322619683380e-01, +4.1549169648999079e-01, +1.9625009395677049e+00, + +2.4486960792180841e+00, -1.7975615666661942e+00, +1.9835147875960215e+00, -1.2784851851614869e+00, + +5.8984320559814285e-01, +7.4229958232406057e-01, +7.4162021876167439e-01, +4.0929957187021387e-01, + -2.0185674746649859e-01, -5.9468903144747898e-01, +2.5498400735845517e-01, +3.1515519885019567e+00, + +4.4864718840647067e+00, -7.0240004332775219e-01, +4.6010907702840802e+00, -7.1516760160570758e-01, + -6.0415971544731641e-02, +2.8908143240634510e-01, +2.9629615720773861e-01, +1.4493278876410129e-01, + +1.4899706167753710e-01, -2.5624565717364778e-01, -5.4356503353058161e-01, +2.1871925270524719e+00, + +1.2543898090398127e+00, -1.0185718997820286e+00, -6.7596194102120388e-01, -3.8041180369333510e-01, + -1.0978898835205055e-01, +1.5099465724664984e-01, +4.2016694973194896e-01, +7.3351664383189963e-01, + +1.0551205409739042e+00, +9.1024490373168820e-01, -6.3763153414131846e-01, +6.8893802543718774e+00, + -4.5971718133357287e-01, -6.9284977440148543e-01, -8.6796268450746850e-01, -1.0950384320348023e+00, + -1.1966360910997955e+00, -6.7707548888538704e-01, -1.3280469312410206e-01, -3.2278394553456100e+00, + -1.4343263133851598e+00, -2.6555777667727885e+00, +8.7311541655505295e-02, +2.4177522001854126e-01, + +4.5981899324422781e-01, +5.3346538819879574e-01, +5.2250263059044200e-01, +6.7501024530644282e-01, + +7.0657219206455624e-01, -3.9412109509589072e-01, -1.0587888360932021e+00, +4.2457441911909166e+00, + -2.3140489272869411e-01, -2.2861844257468347e+00, -1.3329768983397348e+00, -5.3454494370389308e-02, + +1.0612438826511756e+00, +1.6464548555599556e+00, +2.1981962147620164e+00, +2.1218914660603314e+00, + -3.7578534604190112e+00, +8.2081595576296251e-01, +4.3296094785753092e-01, -1.0816112906194235e+00, + -3.6484327546078127e-01, +6.0909104075857867e-01, +1.6358031456615689e+00, +2.3229646816108964e+00, + +2.2118727909599141e+00, +1.3469163868086866e+00, +1.2994214055102531e+00, +4.3813763657608504e+00, + +4.9975532252885566e+00, +1.3903679245689864e+00, +4.8414327285872227e-01, +7.0224530774924843e-01, + +1.0803339536568390e+00, +1.3373166223466377e+00, +9.5760508245182152e-01, -6.1863962466179623e-01, + -3.3019470393120369e+00, -5.9049244231580120e+00, +9.2207101115694030e+00, +5.3444749163645511e-01, + +1.6108800350605517e+00, +1.3165642014926429e+00, +9.0587909506896747e-01, +4.3594752341224680e-01, + +3.6916356990756842e-02, -8.2038474944938655e-02, +3.2166115760007336e-01, +1.4564827647673204e+00, + +2.8544422054339025e+00, +2.6552473716396197e+00, +1.4919279500221457e+01, -2.5271129085382765e-01, + -8.8992556214129459e-01, -6.9246476700165105e-01, +1.8480538936798771e-01, +1.2640503721685272e+00, + +1.4458856404463873e+00, -3.1286650751120088e-01, -3.3813513244865669e+00, +3.7081909895009160e+00, + +9.5461952229348529e-02, +1.8316337726387808e+00, +1.9081909769641892e+00, +1.7882437073633883e+00, + +1.4886753215997492e+00, +1.0530609759478995e+00, +6.0216177299874951e-01, +4.6111555601369725e-01, + +1.2362626622408777e+00, +2.6589804765640928e+00, +3.0218939157874125e+00, +9.6956758605654425e+00, + +3.0084649428700154e+00, -1.6285708118911306e+00, -3.3636750849697350e+00, -1.3583162670326305e+00, + +2.5494015893551629e+00, +3.0980878905749085e+00, -3.0085452437932623e+00, +1.5124259940551708e+00, + +5.5515261212099150e+00, -1.5205550733351489e-01, -5.5411686182748421e-01, +4.9613391570372412e-01, + +1.5851717222454447e+00, +2.4378677019392678e+00, +2.4074271019318774e+00, +1.3514419949232819e+00, + +1.7063949677974886e+00, +7.4087472372617151e+00, +1.7585074429003971e+00, +1.3341690752552770e+01, + +7.0750414351312099e+00], + dtype=np.float64, +) +# fmt: on +# END OF GENERATED CODE diff --git a/tests/GalSim b/tests/GalSim index 3251a393..04918b11 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 3251a393bf7ea94fe9ccda3508bc7db722eca1cf +Subproject commit 04918b118926eafc01ec9403b8afed29fb918d51 diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 2d532570..2f96d201 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -145,3 +145,4 @@ allowed_failures: - "module 'jax_galsim' has no attribute 'LookupTable2D'" - "module 'jax_galsim' has no attribute 'zernike'" - "Invalid TFORM4: 1PE(7)" # see https://github.com/astropy/astropy/issues/15477 + - "JAX-GalSim does not support truncated Moffat" diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 02a1dce2..303631bd 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -292,7 +292,7 @@ def test_benchmark_invert_ab_noraise(benchmark, kind): def _run_benchmark_moffat_init(): - return jgs.Moffat(beta=2.5, half_light_radius=0.6, trunc=1.2).scale_radius + return jgs.Moffat(beta=2.5, half_light_radius=0.6).scale_radius @pytest.mark.parametrize("kind", ["run"]) diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index 2f4c44e3..295c137c 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -55,13 +55,11 @@ def test_eq(self, other): def test_moffat_jitting(): # Test Moffat objects - fwhm_backwards_compatible = 1.3178976627539716 objects = [ galsim.Moffat(beta=5.0, flux=0.2, scale_radius=1.0, gsparams=gsparams), galsim.Moffat( beta=2.0, half_light_radius=1.0, - trunc=5 * fwhm_backwards_compatible, flux=1.0, gsparams=gsparams, ), diff --git a/tests/jax/test_moffat_comp_galsim.py b/tests/jax/test_moffat_comp_galsim.py index 04376cd3..4b8549c4 100644 --- a/tests/jax/test_moffat_comp_galsim.py +++ b/tests/jax/test_moffat_comp_galsim.py @@ -22,9 +22,6 @@ def test_moffat_comp_galsim_maxk(): galsim.Moffat(beta=1.22, scale_radius=23, flux=23), galsim.Moffat(beta=3.6, scale_radius=2, flux=23), galsim.Moffat(beta=12.9, scale_radius=5, flux=23), - galsim.Moffat(beta=1.22, scale_radius=7, flux=23, trunc=30), - galsim.Moffat(beta=3.6, scale_radius=9, flux=23, trunc=50), - galsim.Moffat(beta=12.9, scale_radius=11, flux=23, trunc=1000), ] threshs = [1.0e-3, 1.0e-4, 0.03] print("\nbeta \t trunc \t thresh \t kValue(maxk) \t jgs-maxk \t gs-maxk") diff --git a/tests/jax/test_vmapping.py b/tests/jax/test_vmapping.py index 28db011e..6b8d7c40 100644 --- a/tests/jax/test_vmapping.py +++ b/tests/jax/test_vmapping.py @@ -55,13 +55,11 @@ def test_eq(self, other): def test_moffat_vmapping(): # Test Moffat objects - fwhm_backwards_compatible = 1.3178976627539716 objects = [ galsim.Moffat(beta=5.0, flux=0.2, scale_radius=1.0), galsim.Moffat( beta=2.0, half_light_radius=1.0, - trunc=5 * fwhm_backwards_compatible, flux=1.0, ), ]