diff --git a/CHANGELOG.md b/CHANGELOG.md index 425fabb..661e5c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Unreleased +### Added + +- GridInterpolation transformation for xarray-esque rectilinear grids + ## 0.7.2 - 2026-06-25 ## 0.7.1 - 2026-06-25 diff --git a/README.md b/README.md index 4149a31..c2568d9 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,9 @@ All transforms are accessed under the `transformnd.transforms` subpackage. | `Scale` | | Multiply the input coordinates by constant scale factor | | `Reflection` | | Reflect coordinates about arbitrary planes | | `MapAxis` | | Rearrange axes of the input coordinates | +| `ProjectAxis` | | Insert or remove axes | | `Affine` | | Multiply augmented coordinates by an affine transformation matrix. Can represent all of the above transformations. Can be composed with matrix multiplication `aff2 @ aff1`. | +| `GridInterpolation` | | Apply an arbitrary callable (e.g. a `scipy.interpolate` interpolator) to each dimension. | | `ByDimension` | | Apply different transformations to subsets of the input coordinates' dimensions | | `MovingLeastSquares` | `movingleastsquares` | Landmark-based transformation. | | `ThinPlateSplines` | `thinplatesplines` | Landmark-based transformation. | diff --git a/src/transformnd/transforms/__init__.py b/src/transformnd/transforms/__init__.py index 37de554..e3f8500 100644 --- a/src/transformnd/transforms/__init__.py +++ b/src/transformnd/transforms/__init__.py @@ -10,12 +10,14 @@ from .vector_field import Coordinates, Displacements from .moving_least_squares import MovingLeastSquares from .thinplate import ThinPlateSplines +from .grid import GridInterpolation __all__ = [ "Affine", "Identity", "ProjectAxis", "Reflect", + "GridInterpolation", "Scale", "Translate", "MapAxis", diff --git a/src/transformnd/transforms/grid.py b/src/transformnd/transforms/grid.py new file mode 100644 index 0000000..99f82f7 --- /dev/null +++ b/src/transformnd/transforms/grid.py @@ -0,0 +1,47 @@ +from typing import Any, Protocol, Generic + +from array_api_compat import array_namespace +from transformnd.types import NDims, Spaces +from ..base import Transform +from ..types import ArrayT + + +class Interpolator(Protocol, Generic[ArrayT]): + def __call__(self, x: ArrayT) -> ArrayT: ... + + +class GridInterpolation(Transform[ArrayT]): + """Coordinate transformation which applies a callable to each dimension. + + Intended for use with instances of `scipy.interpolate` interpolators, + but any callable which takes and returns an array of floats would work. + """ + + def __init__( + self, + interpolators: list[Interpolator[ArrayT]], + *, + spaces: Spaces = Spaces(None, None), + ): + """ + Parameters + ---------- + interpolators + One callable per dimension, in order. + Each one should take and return an array of floats. + spaces + Source and target space identifiers + """ + self.interpolators = interpolators + nd = len(interpolators) + super().__init__(NDims(nd, nd), spaces=spaces) + + def apply(self, coords: Any) -> Any: + xp = array_namespace(coords) + coords = self._validate_coords(coords) + coords_t = xp.transpose(coords) + out_coords_t = xp.zeros_like(coords_t) + for in_col, out_col, interp in zip(coords_t, out_coords_t, self.interpolators): + out_col[:] = interp(in_col) + + return xp.transpose(out_coords_t) diff --git a/tests/transforms/test_grid.py b/tests/transforms/test_grid.py new file mode 100644 index 0000000..08baf08 --- /dev/null +++ b/tests/transforms/test_grid.py @@ -0,0 +1,8 @@ +from transformnd.transforms import GridInterpolation +import pytest + + +def test_simple(coords5x3): + t = GridInterpolation([lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]) + out = t.apply(coords5x3) + assert out == pytest.approx(coords5x3 * [2, 3, 4])