diff --git a/mapcat/alembic/versions/a1b2c3d4e5f6_add_residual_stats_to_pointing_residuals.py b/mapcat/alembic/versions/a1b2c3d4e5f6_add_residual_stats_to_pointing_residuals.py index 1aaa0a1..225cd64 100644 --- a/mapcat/alembic/versions/a1b2c3d4e5f6_add_residual_stats_to_pointing_residuals.py +++ b/mapcat/alembic/versions/a1b2c3d4e5f6_add_residual_stats_to_pointing_residuals.py @@ -6,16 +6,16 @@ """ -from typing import Sequence, Union +from collections.abc import Sequence import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. revision: str = "a1b2c3d4e5f6" -down_revision: Union[str, None] = "cd9bc4ba5bc0" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str | None = "cd9bc4ba5bc0" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: diff --git a/mapcat/alembic/versions/cd9bc4ba5bc0_change_pointing_model_to_pydantic_model.py b/mapcat/alembic/versions/cd9bc4ba5bc0_change_pointing_model_to_pydantic_model.py index 1a1ed8c..1723b40 100644 --- a/mapcat/alembic/versions/cd9bc4ba5bc0_change_pointing_model_to_pydantic_model.py +++ b/mapcat/alembic/versions/cd9bc4ba5bc0_change_pointing_model_to_pydantic_model.py @@ -6,16 +6,16 @@ """ -from typing import Sequence, Union +from collections.abc import Sequence import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. revision: str = "cd9bc4ba5bc0" -down_revision: Union[str, None] = "6ce7e94dfd2d" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None +down_revision: str | None = "6ce7e94dfd2d" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None def upgrade() -> None: diff --git a/mapcat/database/json.py b/mapcat/database/json.py index bcf9cee..0c6f2b8 100644 --- a/mapcat/database/json.py +++ b/mapcat/database/json.py @@ -4,6 +4,7 @@ See: https://github.com/fastapi/sqlmodel/pull/1324 - this can be removed at some point. """ +from pydantic import TypeAdapter from sqlalchemy.types import JSON, TypeDecorator @@ -14,6 +15,7 @@ class JSONEncodedPydantic(TypeDecorator): def __init__(self, pydantic_class, *args, **kwargs): super().__init__(*args, **kwargs) self.pydantic_class = pydantic_class + self._adapter = TypeAdapter(pydantic_class) def process_bind_param(self, value, dialect): if value is None: @@ -23,4 +25,4 @@ def process_bind_param(self, value, dialect): def process_result_value(self, value, dialect): if value is None: return None - return self.pydantic_class.model_validate(value) + return self._adapter.validate_python(value) diff --git a/mapcat/database/pointing_residual.py b/mapcat/database/pointing_residual.py index a4925c8..7f9a14e 100644 --- a/mapcat/database/pointing_residual.py +++ b/mapcat/database/pointing_residual.py @@ -6,10 +6,13 @@ from mapcat.pointing.base import PointingModelStats from mapcat.pointing.const import ConstantPointingModel +from mapcat.pointing.poly import PolynomialPointingModel from .depth_one_map import DepthOneMapTable from .json import JSONEncodedPydantic +PointingModel = ConstantPointingModel | PolynomialPointingModel + class PointingResidualTable(SQLModel, table=True): """ @@ -21,7 +24,7 @@ class PointingResidualTable(SQLModel, table=True): ---------- map_id : int Internal ID of the depth one map - residual_model: ConstantPointingModel + residual_model: ConstantPointingModel | PolynomialPointingModel The pointing model to actually store in the database. residual_stats: PointingModelStats Statistics about the pointing residuals, such as mean and stddev of RA and Dec offsets @@ -36,10 +39,8 @@ class PointingResidualTable(SQLModel, table=True): foreign_key="depth_one_maps.map_id", ondelete="CASCADE", ) - residual_model: ConstantPointingModel = Field( - discriminator="model_type", sa_type=JSONEncodedPydantic(ConstantPointingModel) - ) + residual_model: PointingModel = Field(sa_type=JSONEncodedPydantic(PointingModel)) residual_stats: PointingModelStats | None = Field( - nullable=True, sa_type=JSONEncodedPydantic(PointingModelStats) + nullable=True, default=None, sa_type=JSONEncodedPydantic(PointingModelStats) ) map: DepthOneMapTable = Relationship(back_populates="pointing_residual") diff --git a/mapcat/pointing/const.py b/mapcat/pointing/const.py index ae1b305..50cefb0 100644 --- a/mapcat/pointing/const.py +++ b/mapcat/pointing/const.py @@ -18,7 +18,7 @@ class ConstantPointingModel(PointingModelProtocol): dec_offset: AstroPydanticQuantity[u.deg] def predict(self, pos: SkyCoord) -> SkyCoord: - ra = pos.ra + self.ra_offset - dec = pos.dec + self.dec_offset + ra = pos.ra - self.ra_offset + dec = pos.dec - self.dec_offset return SkyCoord(ra=ra, dec=dec, frame=pos.frame) diff --git a/mapcat/pointing/poly.py b/mapcat/pointing/poly.py new file mode 100644 index 0000000..38d200a --- /dev/null +++ b/mapcat/pointing/poly.py @@ -0,0 +1,203 @@ +""" +Polynomial pointing model. +""" + +from typing import Literal + +import numpy as np +from astropy import units as u +from astropy.coordinates import SkyCoord +from astropydantic import AstroPydanticUnit +from pydantic import BaseModel + +from mapcat.pointing.base import PointingModelProtocol, PointingModelStats + + +class PolynomialCoefficients(BaseModel): + """ + Coefficients for a polynomial pointing model. + for example a 2D polynomial of order 2, + coeffs={'x^2':1, + 'y^2':1, + 'xy':1, + 'y': 2, + 'x':2, + 'constant':3 + } + labels = {'x':'ra', 'y':'dec'} + """ + + coeffs: dict[str, float] + labels: dict[str, str] + unit: AstroPydanticUnit = u.deg + poly_order: int + + +class PolynomialPointingModel(PointingModelProtocol): + model_type: Literal["polynomial"] = "polynomial" + + poly_order: int + ra_model_coefficients: PolynomialCoefficients | None = None + dec_model_coefficients: PolynomialCoefficients | None = None + + ## Basis terms for 2D polynomial fit + def _poly_terms(self, x, y): + terms = [] + for i in range(self.poly_order + 1): + for j in range(self.poly_order + 1 - i): + terms.append((x**i) * (y**j)) + return np.vstack(terms).T + + def _poly_keys(self): + keys = [] + for i in range(self.poly_order + 1): + for j in range(self.poly_order + 1 - i): + keys.append(f"x^{i}y^{j}") + return keys + + def build_model( + self, + measured_positions: SkyCoord, + expected_positions: SkyCoord, + weights: tuple[list[float], list[float]] | list[float] | None = None, + ): + """ + Calculate and set the polynomial coefficients for the pointing model + using the measured and expected positions. + + weights can be provided as a tuple of (ra_weights, dec_weights) + or a single list that applies to both. + + + Raises + ------ + ValueError + If no positions are provided for model calculation. + ValueError + If the lengths of weights do not match the number of positions. + ValueError + If model coefficients have not been calculated yet when extracting coefficients. + """ + # Calculate offsets + ra_offsets = measured_positions.ra - expected_positions.ra + dec_offsets = measured_positions.dec - expected_positions.dec + n = len(ra_offsets) + if n == 0: + raise ValueError("No positions provided for model calculation.") + + # Unpack weights into ra_weights, dec_weights + if isinstance(weights, tuple): + ra_weights, dec_weights = weights + else: + ra_weights = dec_weights = weights # None or single list applied to both + + # Lots of logic to check if weights exist, etc. + # Resolve weights from uncertainties if not provided + if ra_weights is None and dec_weights is None: + # No weights provided — uniform + ra_weights = dec_weights = np.ones(n) + elif ra_weights is None: + ra_weights = dec_weights + elif dec_weights is None: + dec_weights = ra_weights + + ra_weights = np.asarray(ra_weights) + dec_weights = np.asarray(dec_weights) + assert len(ra_weights) == n, ( + "Length of ra_weights must match number of positions" + ) + assert len(dec_weights) == n, ( + "Length of dec_weights must match number of positions" + ) + + ras = measured_positions.ra.to_value(u.deg) + decs = measured_positions.dec.to_value(u.deg) + ## RA polynomial fit + A_ra = self._poly_terms(ras, decs) + y_ra = ra_offsets.to_value(u.deg) + w_ra = ra_weights + + ## Apply weights + Aw = A_ra * w_ra[:, None] + yw = y_ra * w_ra + coeffs_ra, *_ = np.linalg.lstsq(Aw, yw, rcond=None) + + ## Dec polynomial fit + A_dec = self._poly_terms(ras, decs) + y_dec = dec_offsets.to_value(u.deg) + w_dec = dec_weights + + Aw = A_dec * w_dec[:, None] + yw = y_dec * w_dec + coeffs_dec, *_ = np.linalg.lstsq(Aw, yw, rcond=None) + + ra_coeff_dict = {key: coeff for key, coeff in zip(self._poly_keys(), coeffs_ra)} + dec_coeff_dict = { + key: coeff for key, coeff in zip(self._poly_keys(), coeffs_dec) + } + + self.ra_model_coefficients = PolynomialCoefficients( + coeffs=ra_coeff_dict, + labels={"x": "ra", "y": "dec"}, + unit=u.deg, + poly_order=self.poly_order, + ) + self.dec_model_coefficients = PolynomialCoefficients( + coeffs=dec_coeff_dict, + labels={"x": "ra", "y": "dec"}, + unit=u.deg, + poly_order=self.poly_order, + ) + + def model_fn(self, x: u.Quantity, y: u.Quantity, coeffs: np.ndarray) -> u.Quantity: + x = np.atleast_1d(x.to_value(u.deg)) + y = np.atleast_1d(y.to_value(u.deg)) + T = self._poly_terms(x, y) + return (T @ coeffs) * u.deg + + def extract_coefficients(self) -> tuple[np.ndarray, np.ndarray]: + """ + Extract the coefficients from the PolynomialCoefficients dataclasss and + return them as arrays in the correct order for the model function. + + Raises + ------ + ValueError + If model coefficients have not been calculated yet. + """ + if self.ra_model_coefficients is None or self.dec_model_coefficients is None: + raise ValueError("Model coefficients have not been calculated yet.") + + ra_coeff_array = np.zeros(len(self.ra_model_coefficients.coeffs)) + for i, key in enumerate(self._poly_keys()): + ra_coeff_array[i] = self.ra_model_coefficients.coeffs.get(key, 0) + + dec_coeff_array = np.zeros(len(self.dec_model_coefficients.coeffs)) + for i, key in enumerate(self._poly_keys()): + dec_coeff_array[i] = self.dec_model_coefficients.coeffs.get(key, 0) + + return ra_coeff_array, dec_coeff_array + + def predict(self, pos: SkyCoord) -> SkyCoord: + racoeffs, deccoeffs = self.extract_coefficients() + ra_offset = self.model_fn(pos.ra, pos.dec, racoeffs) + dec_offset = self.model_fn(pos.ra, pos.dec, deccoeffs) + ra = pos.ra - ra_offset + dec = pos.dec - dec_offset + + return SkyCoord(ra=ra, dec=dec, frame=pos.frame) + + def calculate_statistics(self, positions: SkyCoord): + new_positions = self.predict(positions) + ra_residuals = (new_positions.ra - positions.ra).to(u.arcsec) + dec_residuals = (new_positions.dec - positions.dec).to(u.arcsec) + mean_ra = np.mean(ra_residuals) + mean_dec = np.mean(dec_residuals) + std_ra = np.std(ra_residuals) + std_dec = np.std(dec_residuals) + return PointingModelStats( + mean_ra_offset=mean_ra, + mean_dec_offset=mean_dec, + stddev_ra_offset=std_ra, + stddev_dec_offset=std_dec, + ) diff --git a/pyproject.toml b/pyproject.toml index 1163952..73a6fe8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ name = "mapcat" version = "0.2.2" requires-python = ">=3.10" dependencies = [ + "astropydantic", "sqlmodel", "sqlalchemy[asyncio]", "aiosqlite", diff --git a/tests/test_pointing.py b/tests/test_pointing.py index 0768218..44fdff7 100644 --- a/tests/test_pointing.py +++ b/tests/test_pointing.py @@ -2,15 +2,18 @@ Tests for the pointing residual models. """ -from astropy.units import deg +import numpy as np +from astropy import units as u +from astropy.coordinates import SkyCoord from sqlmodel import select from mapcat.database import DepthOneMapTable, PointingResidualTable from mapcat.pointing.const import ConstantPointingModel +from mapcat.pointing.poly import PolynomialPointingModel def test_add_retrieve_pointing(database_sessionmaker): - model = ConstantPointingModel(ra_offset=0.5 * deg, dec_offset=0.5 * deg) + model = ConstantPointingModel(ra_offset=0.5 * u.deg, dec_offset=0.5 * u.deg) with database_sessionmaker() as session: sample_map = DepthOneMapTable( @@ -42,5 +45,57 @@ def test_add_retrieve_pointing(database_sessionmaker): recovered_pointing = recovered_map.pointing_residual[0] residual_model = recovered_pointing.residual_model - assert residual_model.ra_offset > 0.4 * deg - assert residual_model.dec_offset > 0.4 * deg + assert residual_model.ra_offset > 0.4 * u.deg + assert residual_model.dec_offset > 0.4 * u.deg + + +def test_make_constant_pointing_model(): + model = ConstantPointingModel(ra_offset=0.5 * u.deg, dec_offset=0.5 * u.deg) + + assert model.ra_offset == 0.5 * u.deg + assert model.dec_offset == 0.5 * u.deg + + og_pos = SkyCoord(ra=10 * u.deg, dec=20 * u.deg) + offset_pos = SkyCoord(og_pos.ra + model.ra_offset, og_pos.dec + model.dec_offset) + new_pos = model.predict(offset_pos) + assert new_pos.ra == og_pos.ra + assert new_pos.dec == og_pos.dec + + +def test_make_polynomial_pointing_model(): + model = PolynomialPointingModel(poly_order=2) + ras = np.arange(0, 10, 1) * u.deg + decs = np.arange(0, 10, 1) * u.deg + offset = 1.0 * u.arcmin + slope = 0.1 * u.arcmin / u.deg + + offset_positions = SkyCoord( + ra=ras + offset + slope * ras, dec=decs + offset + slope * decs + ) + model.build_model( + measured_positions=offset_positions, expected_positions=SkyCoord(ras, decs) + ) + + assert model.ra_model_coefficients is not None + assert model.dec_model_coefficients is not None + + for i, offset_pos in enumerate(offset_positions): + predicted_pos = model.predict(offset_pos) + assert np.isclose( + predicted_pos.ra.to_value(u.arcmin), ras[i].to_value(u.arcmin), atol=0.1 + ) + assert np.isclose( + predicted_pos.dec.to_value(u.arcmin), decs[i].to_value(u.arcmin), atol=0.1 + ) + + predicted_pos = model.predict(offset_positions) + assert np.all( + np.isclose( + predicted_pos.ra.to_value(u.arcmin), ras.to_value(u.arcmin), atol=0.1 + ) + ) + assert np.all( + np.isclose( + predicted_pos.dec.to_value(u.arcmin), decs.to_value(u.arcmin), atol=0.1 + ) + )