Skip to content

Commit f24666a

Browse files
committed
Add function to transform from physical space to voxel space
1 parent eaa6f77 commit f24666a

3 files changed

Lines changed: 211 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ dependencies = [
1717
"numpy",
1818
"rich-argparse",
1919
"nibabel",
20-
"pandas"]
20+
"pandas",
21+
"scipy",
22+
]
2123

2224
[project.optional-dependencies]
2325
show = [

src/mritk/data/orientation.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,71 @@
66

77

88
import numpy as np
9+
910
from .base import MRIData
1011

1112

13+
def physical_to_voxel_indices(physical_coordinates: np.ndarray, affine: np.ndarray, round_coords: bool = True) -> np.ndarray:
14+
"""Transform physical coordinates to voxel indices using an affine matrix.
15+
16+
This function maps coordinates from physical space (e.g., FEM degrees of freedom)
17+
back to the image voxel space by applying the inverse of the provided affine
18+
transformation matrix.
19+
20+
Args:
21+
physical_coordinates: (N, 3) array of coordinates in physical space (world coordinates).
22+
affine: (4, 4) affine transformation matrix mapping voxel indices to physical space.
23+
round_coords: If True, rounds the resulting voxel coordinates to the nearest
24+
integer and casts them to `int`. If False, returns floating-point voxel coordinates.
25+
Defaults to True.
26+
27+
Returns:
28+
(N, 3) array of voxel indices (or coordinates).
29+
"""
30+
# Note: Assumes apply_affine is available in the scope or imported
31+
img_space_coords = apply_affine(np.linalg.inv(affine), physical_coordinates)
32+
if round_coords:
33+
return np.rint(img_space_coords).astype(int)
34+
return img_space_coords
35+
36+
37+
def find_nearest_valid_voxels(query_indices: np.ndarray, mask: np.ndarray, k: int) -> np.ndarray:
38+
"""Find the nearest valid voxels in a mask for a set of query indices.
39+
40+
Uses a KDTree to find the `k` nearest neighbors for each point in `query_indices`
41+
where the neighbors are restricted to positions where `mask` is True.
42+
43+
Args:
44+
query_indices: (N, 3) array of voxel indices (or coordinates) to find neighbors for.
45+
mask: Boolean array of shape (X, Y, Z). Neighbors will only be selected from
46+
coordinates where this mask is True.
47+
k: The number of nearest neighbors to find for each query point.
48+
49+
Returns:
50+
Array of nearest neighbor indices.
51+
- If k=1: Returns shape (3, 1, N) containing the coordinates of the single nearest neighbor.
52+
- If k>1: Returns shape (3, k, N) containing the coordinates of the k nearest neighbors.
53+
54+
Raises:
55+
ValueError: If the provided mask contains no valid (True) entries.
56+
"""
57+
import scipy.spatial
58+
59+
valid_inds = np.argwhere(mask)
60+
if len(valid_inds) == 0:
61+
raise ValueError("No valid indices found in mask.")
62+
63+
tree = scipy.spatial.KDTree(valid_inds)
64+
_, indices = tree.query(query_indices, k=k)
65+
66+
# Transpose to match the expected output shape (3, k, N) or (3, 1, N)
67+
dof_neighbours = valid_inds[indices].T
68+
69+
if k == 1:
70+
dof_neighbours = dof_neighbours[:, np.newaxis, :]
71+
return dof_neighbours
72+
73+
1274
def apply_affine(T: np.ndarray, X: np.ndarray) -> np.ndarray:
1375
"""Apply a homogeneous affine transformation matrix to a set of points.
1476

test/test_data_orientation.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import numpy as np
2+
import pytest
3+
import mritk.data.orientation
4+
5+
6+
def test_apply_affine_identity():
7+
"""Test that applying an identity matrix returns the original points."""
8+
points = np.array([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])
9+
identity_affine = np.eye(4)
10+
11+
result = mritk.data.orientation.apply_affine(identity_affine, points)
12+
13+
np.testing.assert_array_equal(result, points)
14+
15+
16+
def test_apply_affine_translation():
17+
"""Test translation logic: x' = x + t."""
18+
points = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
19+
affine = np.eye(4)
20+
# Set translation vector [Tx, Ty, Tz]
21+
translation = np.array([5.0, 10.0, -5.0])
22+
affine[:3, 3] = translation
23+
24+
expected = points + translation
25+
result = mritk.data.orientation.apply_affine(affine, points)
26+
27+
np.testing.assert_array_almost_equal(result, expected)
28+
29+
30+
def test_apply_affine_scaling():
31+
"""Test scaling logic: x' = s * x."""
32+
points = np.array([[1.0, 2.0, 3.0]])
33+
# Scale x*2, y*0.5, z*-1
34+
affine = np.diag([2.0, 0.5, -1.0, 1.0])
35+
36+
expected = np.array([[2.0, 1.0, -3.0]])
37+
result = mritk.data.orientation.apply_affine(affine, points)
38+
39+
np.testing.assert_array_almost_equal(result, expected)
40+
41+
42+
def test_physical_to_voxel_indices_basic_translation():
43+
"""
44+
Test transforming world coordinates back to voxel coordinates.
45+
Scenario: The affine translates voxel space by +10.
46+
Therefore, a world coordinate of 10 should map back to voxel 0.
47+
"""
48+
# World coordinates (DOFs)
49+
dof_coords = np.array([[10.0, 10.0, 10.0], [11.0, 11.0, 11.0]])
50+
51+
# Affine that adds 10 to everything
52+
affine = np.eye(4)
53+
affine[:3, 3] = [10.0, 10.0, 10.0]
54+
55+
# We expect the function to apply the INVERSE of (+10) -> (-10)
56+
# 10 - 10 = 0
57+
# 11 - 10 = 1
58+
expected = np.array([[0, 0, 0], [1, 1, 1]])
59+
60+
result = mritk.data.orientation.physical_to_voxel_indices(dof_coords, affine, round_coords=True)
61+
62+
np.testing.assert_array_equal(result, expected)
63+
assert result.dtype == int
64+
65+
66+
def test_physical_to_voxel_indices_no_rounding():
67+
"""Test that floating point results are returned when rint=False."""
68+
dof_coords = np.array([[10.5, 10.5, 10.5]])
69+
affine = np.eye(4) # Identity
70+
71+
result = mritk.data.orientation.physical_to_voxel_indices(dof_coords, affine, round_coords=False)
72+
73+
np.testing.assert_array_almost_equal(result, dof_coords)
74+
assert np.issubdtype(result.dtype, np.floating)
75+
76+
77+
def test_physical_to_voxel_indices_rounding_behavior():
78+
"""Test that rint rounds correctly."""
79+
# 10.1 -> 10, 10.9 -> 11
80+
dof_coords = np.array([[10.1, 10.1, 10.1], [10.9, 10.9, 10.9]])
81+
affine = np.eye(4)
82+
83+
expected = np.array([[10, 10, 10], [11, 11, 11]])
84+
85+
result = mritk.data.orientation.physical_to_voxel_indices(dof_coords, affine, round_coords=True)
86+
np.testing.assert_array_equal(result, expected)
87+
88+
89+
def test_find_nearest_valid_voxels_1_neighbor():
90+
"""Test finding the single closest point in a 2D mask."""
91+
# Define a mask with valid pixels only at (0,0) and (5,5)
92+
mask = np.zeros((6, 6), dtype=bool)
93+
mask[0, 0] = True
94+
mask[5, 5] = True
95+
96+
# Point A is close to (0,0), Point B is close to (5,5)
97+
dof_inds = np.array([[0.1, 0.1], [4.9, 4.9]])
98+
99+
# Function output shape is (ndim, N_neighbors, N_points)
100+
result = mritk.data.orientation.find_nearest_valid_voxels(dof_inds, mask, k=1)
101+
102+
# Verify shape: (2 dims, 1 neighbor, 2 query points)
103+
assert result.shape == (2, 1, 2)
104+
105+
# First point (0.1, 0.1) -> Neighbor should be (0, 0)
106+
np.testing.assert_array_equal(result[:, 0, 0], [0, 0])
107+
# Second point (4.9, 4.9) -> Neighbor should be (5, 5)
108+
np.testing.assert_array_equal(result[:, 0, 1], [5, 5])
109+
110+
111+
def test_find_nearest_valid_voxels_N_neighbors():
112+
"""Test finding multiple neighbors (N=2) in 3D."""
113+
mask = np.zeros((10, 10, 10), dtype=bool)
114+
# Two valid points close to each other
115+
mask[1, 1, 1] = True
116+
mask[1, 1, 2] = True
117+
# One valid point far away
118+
mask[9, 9, 9] = True
119+
120+
# Query point right next to the cluster at (1,1,1)
121+
dof_inds = np.array([[1.0, 1.0, 1.1]])
122+
123+
result = mritk.data.orientation.find_nearest_valid_voxels(dof_inds, mask, k=2)
124+
125+
# Shape should be (3 dims, 2 neighbors, 1 point)
126+
assert result.shape == (3, 2, 1)
127+
128+
# Get the neighbors for the first (and only) query point
129+
neighbors = result[:, :, 0].T # Transpose to get list of coords: shape (2, 3)
130+
131+
# We expect (1,1,1) and (1,1,2) to be the neighbors.
132+
# KDTree returns sorted by distance.
133+
# Distance to (1,1,1) is 0.1
134+
# Distance to (1,1,2) is 0.9
135+
# So (1,1,1) should be first.
136+
np.testing.assert_array_equal(neighbors[0], [1, 1, 1])
137+
np.testing.assert_array_equal(neighbors[1], [1, 1, 2])
138+
139+
140+
def test_find_nearest_valid_voxels_empty_mask_error():
141+
"""Test behavior when no valid points exist (should raise ValueError from KDTree)."""
142+
mask = np.zeros((5, 5), dtype=bool)
143+
dof_inds = np.array([[1, 1]])
144+
145+
with pytest.raises(ValueError):
146+
mritk.data.orientation.find_nearest_valid_voxels(dof_inds, mask, k=1)

0 commit comments

Comments
 (0)