diff --git a/CHANGELOG.md b/CHANGELOG.md index b88b889..3de8efe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,19 @@ ## Unreleased +### Fixed + +- Rectangular affines now have the correct input and output dimensions + +### Added + +- ProjectAxis transformation for adding and dropping axes + +### Changed + +- Expose all transforms under `.transforms`, even if they are optional +- BREAKING: Rename `GeometryAdapter` to `ShapelyAdapter` + ## 0.5.0 - 2026-06-17 ### Added diff --git a/README.md b/README.md index 073c975..3a8767e 100644 --- a/README.md +++ b/README.md @@ -44,10 +44,10 @@ All transforms are accessed under the `transformnd.transforms` subpackage. | `MapAxis` | | Rearrange axes of the input coordinates | | `Affine` | | Multiply augmented coordinates by an affine transformation matrix. Can represent all of the above transformations. Can be composed with matrix multiplication `aff2 @ aff1`. | | `ByDimension` | | Apply different transformations to subsets of the input coordinates' dimensions | -| `moving_least_squares.MovingLeastSquares` | `movingleastsquares` | Landmark-based transformation. | -| `thin_plate_splines.ThinPlateSplines` | `thinplatesplines` | Landmark-based transformation. | -| `vector_field.Coordinates` | `vectorfield` for in-memory, `vectorfield-dask` for chunked | Look up output coordinates in a vector field indexed by the input coordinates | -| `vector_field.Displacements` | `vectorfield`, `vectorfield-dask` for chunked | Look up translations in a vector field indexed by the input coordinates, and add them to input coordinates | +| `MovingLeastSquares` | `movingleastsquares` | Landmark-based transformation. | +| `ThinPlateSplines` | `thinplatesplines` | Landmark-based transformation. | +| `Coordinates` | `vectorfield` for in-memory, `vectorfield-dask` for chunked | Look up output coordinates in a vector field indexed by the input coordinates | +| `Displacements` | `vectorfield`, `vectorfield-dask` for chunked | Look up translations in a vector field indexed by the input coordinates, and add them to input coordinates | Arbitrary transforms can be composed into a `TransformSequence` with `transform1 | transform2`. A graph of transforms between defined spaces can be traversed using the `TransformGraph`. @@ -55,13 +55,12 @@ A graph of transforms between defined spaces can be traversed using the `Transfo ## Implemented adapters - Numpy arrays of shape `(..., D, ...)` (`transformnd.adapters.ReshapeAdapter`) -- `meshio.Mesh` (`transformnd.adapters.meshio.MeshAdapter`) - `pandas.DataFrame` (`transformnd.adapters.pandas.PandasAdapter`) - Takes a subset of columns as a coordinate array - `polars.DataFrame` (`transformnd.adapters.polars.PolarsAdapter`) - Similar to the pandas adapter - Currently, only scalar columns are supported (e.g. not a single struct column with fields `x`, `y`, `z`) -- Geometries from `shapely` (`transformnd.adapters.shapely.GeometryAdapter`) +- Geometries from `shapely` (`transformnd.adapters.shapely.ShapelyAdapter`) - Objects composed of transformable attributes (`transformnd.adapters.AttrAdapter`). ## Additional transforms and adapters diff --git a/justfile b/justfile index 546fc77..20ee26e 100644 --- a/justfile +++ b/justfile @@ -47,5 +47,8 @@ bump level: git commit -m "Bump to v$(uv version --short)" git tag -a "v$(uv version --short)" -m "$(changelog entry latest)" +pre-commit: + uv run --group dev prek run --all-files + repl: uv run --all-groups --all-extras --with ipython ipython diff --git a/src/transformnd/adapters/__init__.py b/src/transformnd/adapters/__init__.py index 621b871..bf2c3ed 100644 --- a/src/transformnd/adapters/__init__.py +++ b/src/transformnd/adapters/__init__.py @@ -12,7 +12,6 @@ See `.pandas.DataFrameAdapter` for an example of creating an adapter for an external type. - """ from .base import ( @@ -23,6 +22,9 @@ ReshapeAdapter, SimpleAdapter, ) +from .pandas import PandasAdapter +from .polars import PolarsAdapter +from .shapely import ShapelyAdapter __all__ = [ "BaseAdapter", @@ -31,4 +33,7 @@ "FnAdapter", "AttrAdapter", "ReshapeAdapter", + "PandasAdapter", + "PolarsAdapter", + "ShapelyAdapter", ] diff --git a/src/transformnd/adapters/pandas.py b/src/transformnd/adapters/pandas.py index f219bb1..f8a3b5a 100644 --- a/src/transformnd/adapters/pandas.py +++ b/src/transformnd/adapters/pandas.py @@ -1,15 +1,18 @@ """Adapt pandas DataFrames for transformation.""" from collections.abc import Hashable +from typing import TYPE_CHECKING -import pandas as pd import numpy as np from ..base import Transform from .base import BaseAdapter +if TYPE_CHECKING: + import pandas as pd -class PandasAdapter(BaseAdapter[pd.DataFrame, np.ndarray]): + +class PandasAdapter(BaseAdapter["pd.DataFrame", np.ndarray]): def __init__(self, columns: list[Hashable]): """Adapt transformation for coordinates stored in a pandas DataFrame. @@ -21,8 +24,8 @@ def __init__(self, columns: list[Hashable]): self.columns = columns def apply( - self, transform: Transform, df: pd.DataFrame, in_place: bool = False - ) -> pd.DataFrame: + self, transform: Transform, df: "pd.DataFrame", in_place: bool = False + ) -> "pd.DataFrame": """Transform the dataframe, optionally in-place. Parameters diff --git a/src/transformnd/adapters/polars.py b/src/transformnd/adapters/polars.py index 759fdaa..3ccbb65 100644 --- a/src/transformnd/adapters/polars.py +++ b/src/transformnd/adapters/polars.py @@ -1,13 +1,16 @@ """Adapt polars DataFrames for transformation.""" -import polars as pl +from typing import TYPE_CHECKING import numpy as np from ..base import Transform from .base import BaseAdapter +if TYPE_CHECKING: + import polars as pl -class PolarsAdapter(BaseAdapter[pl.DataFrame, np.ndarray]): + +class PolarsAdapter(BaseAdapter["pl.DataFrame", np.ndarray]): def __init__(self, columns: list[str]): """Adapt transformation for coordinates stored in a polars DataFrame. @@ -19,8 +22,8 @@ def __init__(self, columns: list[str]): self.columns = columns def apply( - self, transform: Transform, df: pl.DataFrame, in_place: bool = False - ) -> pl.DataFrame: + self, transform: Transform, df: "pl.DataFrame", in_place: bool = False + ) -> "pl.DataFrame": """Transform the dataframe, optionally in-place. Parameters diff --git a/src/transformnd/adapters/shapely.py b/src/transformnd/adapters/shapely.py index e0e5cf4..46cb137 100644 --- a/src/transformnd/adapters/shapely.py +++ b/src/transformnd/adapters/shapely.py @@ -1,21 +1,24 @@ import logging +from typing import TYPE_CHECKING import numpy as np -import shapely -from shapely.geometry.base import BaseGeometry -from shapely.coords import CoordinateSequence from ..base import Transform, ArrayT from .base import BaseAdapter +if TYPE_CHECKING: + from shapely.geometry.base import BaseGeometry + from shapely.coords import CoordinateSequence + + logger = logging.getLogger(__name__) -def as_numpy(coords: CoordinateSequence) -> np.ndarray: +def as_numpy(coords: "CoordinateSequence") -> np.ndarray: return np.asarray(coords) -class GeometryAdapter(BaseAdapter[BaseGeometry, ArrayT]): +class ShapelyAdapter(BaseAdapter["BaseGeometry", ArrayT]): """Transform shapely geometries. As well as the generic `apply()`, @@ -27,7 +30,7 @@ class GeometryAdapter(BaseAdapter[BaseGeometry, ArrayT]): N.B. shapely geometries' coordinates are in `XY(Z)` order """ - def apply[T: BaseGeometry]( + def apply[T: "BaseGeometry"]( self, transform: Transform, obj: T, @@ -51,6 +54,7 @@ def apply[T: BaseGeometry]( T An object of the same type as the input. """ + import shapely def fn(coords: np.ndarray) -> np.ndarray: c = coords.copy() diff --git a/src/transformnd/transforms/__init__.py b/src/transformnd/transforms/__init__.py index d3c7232..79965de 100644 --- a/src/transformnd/transforms/__init__.py +++ b/src/transformnd/transforms/__init__.py @@ -5,15 +5,27 @@ from .simple import Identity, Scale, Translate from .map_axis import MapAxis from .bijection import Bijection -from .by_dimension import ByDimension +from .project_axis import ProjectAxis, Insert, Remove +from .by_dimension import ByDimension, SubTransform +from .vector_field import Coordinates, Displacements +from .moving_least_squares import MovingLeastSquares +from .thinplate import ThinPlateSplines __all__ = [ "Affine", "Identity", + "ProjectAxis", + "Insert", + "Remove", "Reflect", "Scale", "Translate", "MapAxis", "Bijection", "ByDimension", + "SubTransform", + "Coordinates", + "Displacements", + "MovingLeastSquares", + "ThinPlateSplines", ] diff --git a/src/transformnd/transforms/affine.py b/src/transformnd/transforms/affine.py index 4fcd603..1933599 100644 --- a/src/transformnd/transforms/affine.py +++ b/src/transformnd/transforms/affine.py @@ -42,7 +42,7 @@ def __init__( ---------- matrix Affine transformation matrix, - i.e. a 2D array-like with shape `(Di + 1, Do + 1)`, + i.e. a 2D array-like with shape `(Do + 1, Di + 1)`, where the bottom row is all 0s except in the rightmost column, which is 1. spaces Optional source and target spaces @@ -64,7 +64,7 @@ def __init__( f"Transformation matrix is not affine (expected bottom row {expected}, got {bottom_row})." ) - super().__init__(NDims(m.shape[0] - 1, m.shape[1] - 1), spaces=spaces) + super().__init__(NDims(m.shape[1] - 1, m.shape[0] - 1), spaces=spaces) self.matrix: np.ndarray = m diff --git a/src/transformnd/transforms/by_dimension.py b/src/transformnd/transforms/by_dimension.py index 83730ba..10240c9 100644 --- a/src/transformnd/transforms/by_dimension.py +++ b/src/transformnd/transforms/by_dimension.py @@ -8,7 +8,10 @@ class SubTransform[ArrayT]: - """Transformation to apply to subsets of the input dimensions and which output dimensions they calculate.""" + """Component of the `ByDimension` transformation. + + Transformation to apply to subsets of the input dimensions and which output dimensions they calculate. + """ def __init__( self, @@ -16,6 +19,24 @@ def __init__( input_axes: list[int], output_axes: list[int] | None = None, ): + """ + Parameters + ---------- + transform + Transformation to apply to the subset of axes. + input_axes + Which axes to apply the transformation to, in order. + The length must match the input dimensionality of `transform`. + output_axes + Which axes to apply the transformation to, in order. + The length must match the input dimensionality of `transform`. + If None, re-use the input axes. + + Raises + ------ + ValueError + `transform`'s dimensionality does not match the input/output axes. + """ self.input_axes = input_axes if output_axes is None: diff --git a/src/transformnd/transforms/moving_least_squares.py b/src/transformnd/transforms/moving_least_squares.py index c930fcc..456c766 100644 --- a/src/transformnd/transforms/moving_least_squares.py +++ b/src/transformnd/transforms/moving_least_squares.py @@ -7,7 +7,6 @@ from array_api_compat import array_namespace import numpy as np from typing import Self -from molesq.transform import Transformer as _Transformer from ..base import Transform from ..types import NDims, Spaces @@ -18,6 +17,8 @@ class MovingLeastSquares(Transform[np.ndarray]): """Moving least squares transformation. Deform based on a matched pairs of source and target control points; see + + REQUIRES: `movingleastsquares` extra. """ def __init__( @@ -39,9 +40,11 @@ def __init__( spaces Optional source and target spaces """ + from molesq.transform import Transformer + s = as_floats(source_control_points) t = as_floats(target_control_points) - self._transformer = _Transformer(s, t) + self._transformer = Transformer(s, t) super().__init__( NDims( s.shape[1], diff --git a/src/transformnd/transforms/project_axis.py b/src/transformnd/transforms/project_axis.py new file mode 100644 index 0000000..b946c85 --- /dev/null +++ b/src/transformnd/transforms/project_axis.py @@ -0,0 +1,173 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from copy import copy +from typing import Self, Sequence + +import numpy as np +from array_api_compat import array_namespace +from transformnd.transforms import Affine +from transformnd.types import NDims, Spaces +from dataclasses import dataclass +from ..base import Transform +from ..types import ArrayT + + +@dataclass(frozen=True, eq=True) +class BaseOperation(ABC): + idx: int + """Which axis to apply the operation to.""" + + def __post_init__(self): + if self.idx < 0: + raise ValueError("insert/remove idx must be positive") + + @abstractmethod + def check(self, ndim: int) -> int: ... + + @abstractmethod + def invert(self) -> BaseOperation: ... + + +@dataclass(frozen=True, eq=True) +class Insert(BaseOperation): + """Component of the `ProjectAxis` transform which inserts a new axis.""" + + def check(self, ndim: int) -> int: + if self.idx > ndim or self.idx <= -ndim: + raise ValueError( + f"Index {self.idx} is out of range for dimensionality {ndim}" + ) + return ndim + 1 + + def invert(self) -> Remove: + return Remove(self.idx) + + +@dataclass(frozen=True, eq=True) +class Remove(BaseOperation): + """Component of the `ProjectAxis` transform which removes an existing axis.""" + + def check(self, ndim: int) -> int: + if self.idx >= ndim or self.idx <= -ndim: + raise ValueError( + f"Index {self.idx} is out of range for dimensionality {ndim}" + ) + return ndim - 1 + + def invert(self) -> Insert: + if self.idx == -1: + raise ValueError("Removal of the -1th axis is not invertible") + return Insert(self.idx) + + +Operation = Insert | Remove +"""Insert or remove an axis.""" + + +class ProjectAxis(Transform): + """Transform for adding and removing axes. + + WARNING: inverting this transformation may be lossy. + """ + + def __init__( + self, + operations: Sequence[Operation], + source_ndim: int | None = None, + target_ndim: int | None = None, + *, + spaces: Spaces = Spaces(None, None), + ): + """Create a transform for adding and dropping axes. + + At least one of source_ndim and target_ndim must be given. + + Parameters + ---------- + operations + Sequence of operations to apply. + source_ndim + If omitted, can be inferred from `target_ndim`. + target_ndim + If omitted, can be inferred from `source_ndim`. + spaces + Identifiers for source and target spaces, by default Spaces(None, None) + + Raises + ------ + ValueError + Operations are inconsistent with given dimensionality, + or insufficient dimensionality information was given. + """ + self.operations = [] + self._has_inserts = False + + if source_ndim is not None: + nd = source_ndim + for op in operations: + nd = op.check(nd) + if target_ndim is None: + target_ndim = nd + elif target_ndim != nd: + raise ValueError("Operations do not match expected target ndim") + + elif target_ndim is not None: + nd = target_ndim + for op in reversed(operations): + nd = op.invert().check(nd) + if source_ndim is None: + source_ndim = nd + elif source_ndim != nd: + raise ValueError("Operations do not match expected source ndim") + + else: + raise ValueError("At least one of source_ndim or target_ndim must be given") + + idxs: list[int | None] = list(range(source_ndim)) + for op in operations: + if isinstance(op, Insert): + self._has_inserts = True + idxs.insert(op.idx, None) + elif isinstance(op, Remove): + idxs.pop(op.idx) + self.operations.append(op) + self._idxs = idxs + + super().__init__(NDims(source_ndim, target_ndim), spaces=spaces) + + def apply(self, coords: ArrayT) -> ArrayT: + coords = self._validate_coords(coords) + if self._has_inserts: + xp = array_namespace(coords) + out = xp.zeros_like(coords, shape=(xp.shape(coords)[0], self.ndims.target)) + for idx, orig_idx in enumerate(self._idxs): + if orig_idx is not None: + out[:, idx] = coords[:, orig_idx] # type:ignore + + else: + out = coords[:, self._idxs] # type:ignore + return out + + def is_identity(self) -> bool: + orig: list[int | None] = list(range(self.ndims.source)) + dims = copy(orig) + for op in self.operations: + if isinstance(op, Insert): + dims.insert(op.idx, None) + elif isinstance(op, Remove): + dims.pop(op.idx) + + return dims == orig + + def to_affine(self) -> Affine | None: + m = np.eye(self.ndims.source) + out_m = self.apply(m) + return Affine.from_linear_map(out_m.T) + + def invert(self) -> Self | None: + return type(self)( + [op.invert() for op in reversed(self.operations)], + source_ndim=self.ndims.target, + target_ndim=self.ndims.source, + spaces=self.spaces.invert(), + ) diff --git a/src/transformnd/transforms/thinplate.py b/src/transformnd/transforms/thinplate.py index 848419f..9a558a8 100644 --- a/src/transformnd/transforms/thinplate.py +++ b/src/transformnd/transforms/thinplate.py @@ -6,7 +6,6 @@ import logging -import morphops as mops import numpy as np from ..base import Transform @@ -20,6 +19,8 @@ class ThinPlateSplines(Transform[np.ndarray]): """Thin plate splines transforms. Deform based on matched pairs of control points. + + REQUIRES: `thinplatesplines` extra. """ def __init__( @@ -48,6 +49,8 @@ def __init__( ValueError Invalid control points. """ + import morphops + self.source_control_points = as_floats(source_control_points) self.target_control_points = as_floats(target_control_points) @@ -59,7 +62,7 @@ def __init__( ndim = self.source_control_points.shape[1] - self.W, self.A = mops.tps_coefs( + self.W, self.A = morphops.tps_coefs( self.source_control_points, self.target_control_points, ) @@ -73,8 +76,10 @@ def invert(self) -> Transform[np.ndarray] | None: ) def apply(self, coords: np.ndarray) -> np.ndarray: + import morphops + coords = self._validate_coords(coords) - U = mops.K_matrix(coords, self.source_control_points) - P = mops.P_matrix(coords) + U = morphops.K_matrix(coords, self.source_control_points) + P = morphops.P_matrix(coords) # The warped pts are the affine part + the non-uniform part return P @ self.A + U @ self.W diff --git a/src/transformnd/transforms/vector_field.py b/src/transformnd/transforms/vector_field.py index 8d24884..c880f8a 100644 --- a/src/transformnd/transforms/vector_field.py +++ b/src/transformnd/transforms/vector_field.py @@ -10,8 +10,6 @@ from ..base import Transform, ArrayT from ..util import set_scipy_array_api, as_floats -set_scipy_array_api() - __all__ = ["Coordinates", "Displacements"] @@ -95,6 +93,8 @@ def _get_vectors_inner_scipy(self, index_coords_t: ArrayT) -> ArrayT: set_scipy_array_api() xp = array_namespace(index_coords_t) + + # make columnar output array so that each dimension can be written contiguously out = xp.zeros_like( self.vector_field, shape=(self.ndims.target, xp.shape(index_coords_t)[1]) ) @@ -133,6 +133,9 @@ class Coordinates(BaseVectorField[ArrayT]): the output coordinate is `vector_field[a, b, c, :]. Input coordinates outside the vector field return NaN. + + REQUIRES: `vectorfield` extra for in-memory, + or `vectorfield-dask` extra for lazy chunked vector fields. """ def __init__( @@ -182,6 +185,9 @@ class Displacements(BaseVectorField[ArrayT]): the output coordinate is `(a, b, c) + vector_field[a, b, c, :]. Input coordinates outside the vector field return NaN. + + REQUIRES: `vectorfield` extra for in-memory, + or `vectorfield-dask` extra for lazy chunked vector fields. """ def __init__( diff --git a/tests/adapters/test_shapely.py b/tests/adapters/test_shapely.py index b7b8e32..e8f29d5 100644 --- a/tests/adapters/test_shapely.py +++ b/tests/adapters/test_shapely.py @@ -9,7 +9,7 @@ GeometryCollection, ) -from transformnd.adapters.shapely import GeometryAdapter +from transformnd.adapters.shapely import ShapelyAdapter from transformnd.transforms import Scale import pytest @@ -141,7 +141,7 @@ ], ) def test_geom(original, expected): - adapter = GeometryAdapter() + adapter = ShapelyAdapter() transform = Scale([2, 3]) out = adapter.apply(transform, original) diff --git a/tests/conftest.py b/tests/conftest.py index afef26f..1bc49af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ def make_coords(shape): - return np.arange(np.prod(shape)).reshape(shape) + return np.arange(np.prod(shape), dtype=float).reshape(shape) @pytest.fixture diff --git a/tests/transforms/test_affine.py b/tests/transforms/test_affine.py index 094c698..92b115b 100644 --- a/tests/transforms/test_affine.py +++ b/tests/transforms/test_affine.py @@ -17,27 +17,25 @@ def test_identity(): @pytest.mark.parametrize(["ndim"], [[d] for d in range(1, 6)]) def test_translation(ndim, rng): - t = 1 + t_arr = np.arange(ndim) + 1 coords = rng.random((5, ndim)) - 0.5 - t_arr = [t] * ndim trans_arr = Affine.translation(t_arr) - assert np.allclose(trans_arr.apply(coords), coords + t) - assert np.allclose((~trans_arr).apply(coords), coords - t) + assert np.allclose(trans_arr.apply(coords), coords + t_arr) + assert np.allclose((~trans_arr).apply(coords), coords - t_arr) @pytest.mark.parametrize(["ndim"], [[d] for d in range(2, 6)]) def test_scaling(ndim, rng): - s = 2 - s_arr = [s] * ndim + s_arr = np.arange(ndim) + 2 coords = rng.random((5, ndim)) - 0.5 trans = Affine.scaling(s_arr) - assert np.allclose(trans.apply(coords), coords * s) - assert np.allclose((~trans).apply(coords), coords / s) + assert np.allclose(trans.apply(coords), coords * s_arr) + assert np.allclose((~trans).apply(coords), coords / s_arr) trans_arr = Affine.scaling(s_arr) - assert np.allclose(trans_arr.apply(coords), coords * s) + assert np.allclose(trans_arr.apply(coords), coords * s_arr) def test_rotation2(): @@ -136,7 +134,7 @@ def test_inversion(rng): assert inv_aff.apply(aff.apply(coords)) == pytest.approx(coords) -def test_upprojection(): +def test_downprojection(): lin_map = as_floats( [ [1, 0, 0], @@ -144,10 +142,15 @@ def test_upprojection(): ] ) t = Affine.from_linear_map(lin_map) - assert t.ndims.target > t.ndims.source + assert t.ndims.source == 3 + assert t.ndims.target == 2 + coords = as_floats([[1, 2, 3], [4, 5, 6]]) + out = t.apply(coords) + assert out == pytest.approx(as_floats([[1, 2], [4, 5]])) -def test_downprojection(): + +def test_upprojection(): lin_map = as_floats( [ [1, 0], @@ -156,7 +159,28 @@ def test_downprojection(): ] ) t = Affine.from_linear_map(lin_map) - assert t.ndims.source > t.ndims.target + assert t.ndims.source == 2 + assert t.ndims.target == 3 + + coords = as_floats([[1, 2], [3, 4], [5, 6]]) + out = t.apply(coords) + assert out == pytest.approx(as_floats([[1, 2, 0], [3, 4, 0], [5, 6, 0]])) + + +def test_transpose_commutation(): + mx_dim = 10 + + rng = np.random.default_rng(1991) + for _ in range(100): + lhs_shape = rng.integers(1, mx_dim, 2, endpoint=True) + lhs = rng.random(tuple(lhs_shape)) + rhs_shape = (lhs_shape[1], rng.integers(mx_dim, endpoint=True)) + rhs = rng.random(rhs_shape) + + lr = lhs @ rhs + rtlt_t = (rhs.T @ lhs.T).T + + assert lr == pytest.approx(rtlt_t) # def test_reflection(): diff --git a/tests/transforms/test_project_axis.py b/tests/transforms/test_project_axis.py new file mode 100644 index 0000000..685e230 --- /dev/null +++ b/tests/transforms/test_project_axis.py @@ -0,0 +1,49 @@ +from random import Random +from transformnd.transforms import ProjectAxis +from transformnd.transforms.project_axis import Insert, Remove, Operation +from transformnd.util import as_floats +import pytest + + +def test_insert(): + t = ProjectAxis([Insert(0)], 2, 3) + + out = t.apply(as_floats([[1, 2], [3, 4]])) + assert out == pytest.approx(as_floats([[0, 1, 2], [0, 3, 4]])) + + +def test_remove(): + t = ProjectAxis([Remove(0)], 3, 2) + out = t.apply(as_floats([[1, 2, 3], [4, 5, 6]])) + assert out == pytest.approx(as_floats([[2, 3], [5, 6]])) + + +def test_invert(): + t = ProjectAxis([Remove(0), Insert(2)], 3, 3) + ti = t.invert() + assert ti is not None + assert ti.operations == [Remove(2), Insert(0)] + + +def random_ops(ndim: int, n_ops: int, seed=1991) -> ProjectAxis: + rng = Random(seed) + ops = [] + source_ndim = ndim + for _ in range(n_ops): + op: Operation + if ndim == 0 or rng.randint(0, 1): + op = Insert(rng.randint(0, ndim)) + else: + op = Remove(rng.randint(0, ndim - 1)) + ndim = op.check(ndim) + ops.append(op) + + return ProjectAxis(ops, source_ndim) + + +def test_to_affine(coords5x3): + t = random_ops(coords5x3.shape[-1], 3) + aff = t.to_affine() + assert aff is not None + assert aff.ndims == t.ndims + assert t.apply(coords5x3) == pytest.approx(aff.apply(coords5x3))