From a08b340528bed25302cc511988ee188e9d7c266b Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Thu, 18 Jun 2026 16:52:26 +0100 Subject: [PATCH 1/2] Grid transformation --- src/transformnd/transforms/grid.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 src/transformnd/transforms/grid.py diff --git a/src/transformnd/transforms/grid.py b/src/transformnd/transforms/grid.py new file mode 100644 index 0000000..c449682 --- /dev/null +++ b/src/transformnd/transforms/grid.py @@ -0,0 +1,29 @@ +from typing import Any, Protocol, Generic + +from array_api_compat import array_namespace +from transformnd.types import NDims, Spaces +from ..base import Transform +from ..types import ArrayT + + +class Interpolator(Protocol, Generic[ArrayT]): + def __call__(self, x: ArrayT) -> ArrayT: ... + + +class GridInterpolation(Transform): + def __init__( + self, interpolators: list[Interpolator], *, spaces: Spaces = Spaces(None, None) + ): + self.interpolators = interpolators + nd = len(interpolators) + super().__init__(NDims(nd, nd), spaces=spaces) + + def apply(self, coords: Any) -> Any: + xp = array_namespace(coords) + coords = self._validate_coords(coords) + coords_t = xp.transpose(coords) + out_coords_t = xp.zeros_like(coords_t) + for in_col, out_col, interp in zip(coords_t, out_coords_t, self.interpolators): + out_col[:] = interp(in_col) + + return xp.transpose(out_coords_t) From 1cf7da6374185888a544caba22808cc3f8602cb1 Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Thu, 18 Jun 2026 17:09:58 +0100 Subject: [PATCH 2/2] grid interpolator --- CHANGELOG.md | 4 ++++ README.md | 2 ++ src/transformnd/transforms/__init__.py | 2 ++ src/transformnd/transforms/grid.py | 22 ++++++++++++++++++++-- tests/transforms/test_grid.py | 8 ++++++++ 5 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 tests/transforms/test_grid.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ba807fd..2990f2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Unreleased +### Added + +- GridInterpolation transformation for xarray-esque rectilinear grids + ## 0.6.0 - 2026-06-18 ### Fixed diff --git a/README.md b/README.md index 3a8767e..d224d6f 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,9 @@ All transforms are accessed under the `transformnd.transforms` subpackage. | `Scale` | | Multiply the input coordinates by constant scale factor | | `Reflection` | | Reflect coordinates about arbitrary planes | | `MapAxis` | | Rearrange axes of the input coordinates | +| `ProjectAxis` | | Insert or remove axes | | `Affine` | | Multiply augmented coordinates by an affine transformation matrix. Can represent all of the above transformations. Can be composed with matrix multiplication `aff2 @ aff1`. | +| `GridInterpolation` | | Apply an arbitrary callable (e.g. a `scipy.interpolate` interpolator) to each dimension. | | `ByDimension` | | Apply different transformations to subsets of the input coordinates' dimensions | | `MovingLeastSquares` | `movingleastsquares` | Landmark-based transformation. | | `ThinPlateSplines` | `thinplatesplines` | Landmark-based transformation. | diff --git a/src/transformnd/transforms/__init__.py b/src/transformnd/transforms/__init__.py index 79965de..bd9123b 100644 --- a/src/transformnd/transforms/__init__.py +++ b/src/transformnd/transforms/__init__.py @@ -10,6 +10,7 @@ from .vector_field import Coordinates, Displacements from .moving_least_squares import MovingLeastSquares from .thinplate import ThinPlateSplines +from .grid import GridInterpolation __all__ = [ "Affine", @@ -18,6 +19,7 @@ "Insert", "Remove", "Reflect", + "GridInterpolation", "Scale", "Translate", "MapAxis", diff --git a/src/transformnd/transforms/grid.py b/src/transformnd/transforms/grid.py index c449682..99f82f7 100644 --- a/src/transformnd/transforms/grid.py +++ b/src/transformnd/transforms/grid.py @@ -10,10 +10,28 @@ class Interpolator(Protocol, Generic[ArrayT]): def __call__(self, x: ArrayT) -> ArrayT: ... -class GridInterpolation(Transform): +class GridInterpolation(Transform[ArrayT]): + """Coordinate transformation which applies a callable to each dimension. + + Intended for use with instances of `scipy.interpolate` interpolators, + but any callable which takes and returns an array of floats would work. + """ + def __init__( - self, interpolators: list[Interpolator], *, spaces: Spaces = Spaces(None, None) + self, + interpolators: list[Interpolator[ArrayT]], + *, + spaces: Spaces = Spaces(None, None), ): + """ + Parameters + ---------- + interpolators + One callable per dimension, in order. + Each one should take and return an array of floats. + spaces + Source and target space identifiers + """ self.interpolators = interpolators nd = len(interpolators) super().__init__(NDims(nd, nd), spaces=spaces) diff --git a/tests/transforms/test_grid.py b/tests/transforms/test_grid.py new file mode 100644 index 0000000..08baf08 --- /dev/null +++ b/tests/transforms/test_grid.py @@ -0,0 +1,8 @@ +from transformnd.transforms import GridInterpolation +import pytest + + +def test_simple(coords5x3): + t = GridInterpolation([lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]) + out = t.apply(coords5x3) + assert out == pytest.approx(coords5x3 * [2, 3, 4])