Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Unreleased

### Added

- GridInterpolation transformation for xarray-esque rectilinear grids

## 0.7.2 - 2026-06-25

## 0.7.1 - 2026-06-25
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ All transforms are accessed under the `transformnd.transforms` subpackage.
| `Scale` | | Multiply the input coordinates by constant scale factor |
| `Reflection` | | Reflect coordinates about arbitrary planes |
| `MapAxis` | | Rearrange axes of the input coordinates |
| `ProjectAxis` | | Insert or remove axes |
| `Affine` | | Multiply augmented coordinates by an affine transformation matrix. Can represent all of the above transformations. Can be composed with matrix multiplication `aff2 @ aff1`. |
| `GridInterpolation` | | Apply an arbitrary callable (e.g. a `scipy.interpolate` interpolator) to each dimension. |
| `ByDimension` | | Apply different transformations to subsets of the input coordinates' dimensions |
| `MovingLeastSquares` | `movingleastsquares` | Landmark-based transformation. |
| `ThinPlateSplines` | `thinplatesplines` | Landmark-based transformation. |
Expand Down
2 changes: 2 additions & 0 deletions src/transformnd/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from .vector_field import Coordinates, Displacements
from .moving_least_squares import MovingLeastSquares
from .thinplate import ThinPlateSplines
from .grid import GridInterpolation

__all__ = [
"Affine",
"Identity",
"ProjectAxis",
"Reflect",
"GridInterpolation",
"Scale",
"Translate",
"MapAxis",
Expand Down
47 changes: 47 additions & 0 deletions src/transformnd/transforms/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Any, Protocol, Generic

from array_api_compat import array_namespace
from transformnd.types import NDims, Spaces
from ..base import Transform
from ..types import ArrayT


class Interpolator(Protocol, Generic[ArrayT]):
def __call__(self, x: ArrayT) -> ArrayT: ...


class GridInterpolation(Transform[ArrayT]):
"""Coordinate transformation which applies a callable to each dimension.

Intended for use with instances of `scipy.interpolate` interpolators,
but any callable which takes and returns an array of floats would work.
"""

def __init__(
self,
interpolators: list[Interpolator[ArrayT]],
*,
spaces: Spaces = Spaces(None, None),
):
"""
Parameters
----------
interpolators
One callable per dimension, in order.
Each one should take and return an array of floats.
spaces
Source and target space identifiers
"""
self.interpolators = interpolators
nd = len(interpolators)
super().__init__(NDims(nd, nd), spaces=spaces)

def apply(self, coords: Any) -> Any:
xp = array_namespace(coords)
coords = self._validate_coords(coords)
coords_t = xp.transpose(coords)
out_coords_t = xp.zeros_like(coords_t)
for in_col, out_col, interp in zip(coords_t, out_coords_t, self.interpolators):
out_col[:] = interp(in_col)

return xp.transpose(out_coords_t)
8 changes: 8 additions & 0 deletions tests/transforms/test_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from transformnd.transforms import GridInterpolation
import pytest


def test_simple(coords5x3):
t = GridInterpolation([lambda x: x * 2, lambda x: x * 3, lambda x: x * 4])
out = t.apply(coords5x3)
assert out == pytest.approx(coords5x3 * [2, 3, 4])
Loading