Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
with:
enable-cache: true
- run: uv sync --all-extras
- run: uv run --group tutorial marimo check --strict --ignore-scripts examples/*.py
- run: uv run --group examples marimo check --strict --ignore-scripts examples/*.py

lint:
runs-on: ubuntu-latest
Expand Down
3 changes: 3 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ build:
# - just doc ${READTHEDOCS_OUTPUT}/html
- mkdir -p ${READTHEDOCS_OUTPUT}/html
- uv run --group doc pdoc --output-directory ${READTHEDOCS_OUTPUT}/html --no-include-undocumented --docformat numpy --search transformnd
- mkdir -p ${READTHEDOCS_OUTPUT}/html/examples
- uv run --group examples marimo export html examples/tutorial.py -o {{docdir}}/examples/tutorial.html
- uv run --group examples marimo export html examples/image.py -o {{docdir}}/examples/image.html
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ and implements these adapters for a few common types.

See the [tutorial here](https://github.com/clbarnes/transformnd/blob/main/examples/tutorial.py).
It is a [marimo](https://marimo.io) notebook.
Open it with `uv run --group tutorial marimo edit examples/tutorial.py`.
Open it with `uv run --group examples marimo edit examples/tutorial.py`.

## Implemented transforms

Expand Down
170 changes: 170 additions & 0 deletions examples/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import marimo

__generated_with = "0.23.4"
app = marimo.App(width="medium")

with app.setup:
import marimo as mo


@app.cell(hide_code=True)
def _():
mo.md(r"""
# Image transformation with `transformnd`

`transformnd` transforms coordinates, not images, but coordinate transformations can be used to transform images.
Your output (transformed) and source images both have pixels with an `xy` coordinate in their respective image spaces,
and image transformation is simply a case of finding which source pixel to use for each output pixel.

Here we take a 2-channel fluorescence microscopy image of some cells in 3 dimensions, use scaling information to map those pixels into a real-world space, and then map the pixels of our viewport into the the same space.
""")
return


@app.cell
def _():
from skimage.data import cells3d

# ZCYX, (0.29um, membrane/nuclei channels, 0.26um, 0.26um)
CELLS_SPACE = "cells"

# CZYX, (membrane/nuclei, um, um, um)
WORLD_SPACE = "world"

# YXC image with RGB channels
VIEWPORT_SPACE = "viewport"

cells = cells3d()
cells = cells.astype("float64")
cells -= cells.min()
cells /= cells.max()

print(f"{cells.shape=}")
print(f"{cells.dtype=}")
print(f"{cells.min()=}")
print(f"{cells.max()=}")
return CELLS_SPACE, VIEWPORT_SPACE, WORLD_SPACE, cells


@app.cell
def _(CELLS_SPACE, VIEWPORT_SPACE, WORLD_SPACE):
import transformnd as tnd
from transformnd.transforms import ProjectAxis, MapAxis, Scale

# Aligned at world origin.
# This would be stored alongside the data.
cells_to_world = tnd.base.TransformSequence(
[
# Move the color axis to the first position
MapAxis([1, 0, 2, 3]),
# Scale the space axes
Scale([1, 0.29, 0.26, 0.26]),
],
spaces=tnd.Spaces(CELLS_SPACE, WORLD_SPACE),
)
print(cells_to_world)

# Aligned at world origin.
# This would be chosen by the viewing application.
viewport_to_world = tnd.base.TransformSequence(
[
# Create a Z axis
ProjectAxis(created={0}, source_ndim=3),
# Move the color axis to the first position
MapAxis([3, 0, 1, 2]),
# Choose a spatial sampling frequency (here 0.2um isotropic)
Scale([1, 0.2, 0.2, 0.2]),
],
spaces=tnd.Spaces(VIEWPORT_SPACE, WORLD_SPACE),
)
print(viewport_to_world)
return cells_to_world, viewport_to_world


@app.cell(hide_code=True)
def _():
mo.md(r"""
Both images know how to transform their array indices into the real world.

We can invert one of those transforms to get a transformation between viewport-space and cell-space.
We can also have a separate transformation to control moving the viewport (useful if we had an interactive viewer).
""")
return


@app.cell
def _(cells_to_world, viewport_to_world):
from transformnd.transforms import Translate

# Shift the viewport within the data, in world measurements.
# This would be controlled by the user as they peruse the data.
viewport_offset = Translate([0, 35 * 0.29, 64 * 0.26, 0.0])

viewport_to_cells = viewport_to_world | viewport_offset | ~cells_to_world
return (viewport_to_cells,)


@app.cell(hide_code=True)
def _():
mo.md(r"""
Here we want to get all of the coordinates of our viewport, across all channels, in the shape needed by `transformnd` (number of coordinates x dimensionality of coordinates).

We then transform that to get the positions of those coordinates within the cells image.
""")
return


@app.cell
def _(viewport_to_cells):
import numpy as np

# 2D YXC image
viewport_shape = (128, 256, 3)

indices = [np.arange(s, dtype=float) for s in viewport_shape]
grids = np.meshgrid(*indices, indexing="ij")

# Y, X, C, coords
vp_coords_3d = np.stack(grids, -1)
print(f"{vp_coords_3d.shape=}")

# Y*X*C, coords
vp_coords = vp_coords_3d.reshape((-1, len(viewport_shape)))
print(f"{vp_coords.shape=}")

# Z*C*Y*X, coords
cells_coords = viewport_to_cells.apply(vp_coords)
print(f"{cells_coords.shape=}")
return cells_coords, viewport_shape


@app.cell(hide_code=True)
def _():
mo.md(r"""
`scipy.ndimage.map_coordinates` is where the magic happens; looking up our coordinates in the cells image to get the intensities. There's a dask version too!
""")
return


@app.cell
def _(cells, cells_coords, viewport_shape):
from scipy.ndimage import map_coordinates

# transformnd uses `NxD` coordinate arrays; map_coordinates uses `DxN`
cells_vals = map_coordinates(cells, cells_coords.T).T
print(f"{cells_vals.shape=}")
viewport = cells_vals.reshape(viewport_shape)
print(f"{viewport.shape=}")
return (viewport,)


@app.cell
def _(viewport):
from matplotlib import pyplot as plt

plt.imshow(viewport)
return


if __name__ == "__main__":
app.run()
13 changes: 9 additions & 4 deletions examples/tutorial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import marimo

__generated_with = "0.23.1"
__generated_with = "0.23.4"
app = marimo.App()


Expand Down Expand Up @@ -218,11 +218,16 @@ def _(Scale, Translate):
bc = Translate([0.5, 1], spaces=Spaces("b", "c"))
bd = Translate([1, 0.5], spaces=Spaces("b", "d"))

g.add_transforms([ab, bc, bd])
g.add_transform(ab)
g.add_transform(ab.invert())
g.add_transform(bc)
g.add_transform(bc.invert())
g.add_transform(bd)
g.add_transform(bd.invert())

print("Transform sequence from a to c:", g.get_sequence("a", "c"))
print("Transform sequence from d to a:", g.get_sequence("d", "a"))
return
return (Spaces,)


@app.cell
Expand All @@ -245,7 +250,7 @@ def _(mo):


@app.cell
def _(np, Spaces):
def _(Spaces, np):
from transformnd import Transform, NDims

class IsotropicScale2d(Transform):
Expand Down
11 changes: 9 additions & 2 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ doc docdir='doc/html':
--docformat numpy \
--search \
transformnd
mkdir -p {{docdir}}/examples
uv run --group examples marimo export html examples/tutorial.py -o {{docdir}}/examples/tutorial.html
uv run --group examples marimo export html examples/image.py -o {{docdir}}/examples/image.html

# Run linters and type checkers.
lint:
uv run --group lint ruff check src tests examples bench
uv run --group lint mypy src tests bench
uv run --group lint ruff format --check src tests examples bench
uv run --group tutorial marimo check --strict --ignore-scripts examples/*.py
uv run --group examples marimo check --strict --ignore-scripts examples/*.py
uv run --group lint pydoclint src

# Auto-fix format and lints where possible.
Expand All @@ -33,7 +36,11 @@ format:
test:
uv run --all-groups --all-extras pytest -v

examples:
example-edit example:
uv run --group examples marimo edit examples/{{example}}.py

example example:
uv run --group examples marimo run examples/{{example}}.py

# Run benchmarks.
bench:
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ dev = [
{include-group = "test"},
{include-group = "lint"},
{include-group = "doc"},
{include-group = "tutorial"},
{include-group = "examples"},
{include-group = "bench"},
]
test = [
Expand All @@ -80,10 +80,12 @@ lint = [
doc = [
"pdoc>=16.0.0",
]
tutorial = [
examples = [
"marimo>=0.9",
"matplotlib>=3.10.8",
"pandas>=3.0.2",
"pooch>=1.9.0",
"scikit-image>=0.26.0",
]
bench = [
"pytest-benchmark>=5.2.3",
Expand Down
6 changes: 6 additions & 0 deletions src/transformnd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""
.. include:: ../../README.md

You can find some examples here:

- [Tutorial](./examples/tutorial.html)
- [Image transformation](./examples/image.html)

"""

from .base import Transform, TransformSequence, TransformWrapper
Expand Down
8 changes: 8 additions & 0 deletions src/transformnd/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,14 @@ def __init__(
if not ts:
raise ValueError("Empty transform sequence")

for idx, (t1, t2) in enumerate(pairwise(ts)):
if t1.ndims.target != t2.ndims.source:
raise ValueError(
"Incompatible dimensionality: "
f"transform {idx}'s target is {t1.ndims.target}D "
f"and the next source is {t2.ndims.source}D"
)

spaces = Spaces(ts[0].spaces.source, ts[-1].spaces.target)
ndims = NDims(ts[0].ndims.source, ts[-1].ndims.target)

Expand Down
4 changes: 1 addition & 3 deletions src/transformnd/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .simple import Identity, Scale, Translate
from .map_axis import MapAxis
from .bijection import Bijection
from .project_axis import ProjectAxis, Insert, Remove
from .project_axis import ProjectAxis
from .by_dimension import ByDimension, SubTransform
from .vector_field import Coordinates, Displacements
from .moving_least_squares import MovingLeastSquares
Expand All @@ -15,8 +15,6 @@
"Affine",
"Identity",
"ProjectAxis",
"Insert",
"Remove",
"Reflect",
"Scale",
"Translate",
Expand Down
Loading
Loading