Skip to content

Commit dca04d8

Browse files
committed
fix tests
1 parent 719e54a commit dca04d8

4 files changed

Lines changed: 35 additions & 43 deletions

File tree

notebooks/transformations_demo.ipynb

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
"metadata": {},
88
"outputs": [],
99
"source": [
10-
"import zarr\n",
10+
"from pathlib import Path\n",
11+
"\n",
1112
"import numpy as np\n",
13+
"import zarr\n",
14+
"from ome_zarr_models._v06.collection import Collection\n",
15+
"from ome_zarr_models._v06.image import Image\n",
1216
"\n",
1317
"import ngff_transformations.graph\n",
1418
"import ngff_transformations.transform as ngt\n",
15-
"\n",
16-
"from pathlib import Path\n",
17-
"from ome_zarr_models._v06.image import Image\n",
18-
"from ome_zarr_models._v06.collection import Collection\n",
19-
"from ngff_transformations.graph import transform_graph_to_networkx, draw_graph"
19+
"from ngff_transformations.graph import draw_graph, transform_graph_to_networkx"
2020
]
2121
},
2222
{
@@ -44,9 +44,7 @@
4444
}
4545
],
4646
"source": [
47-
"EXAMPLE_PATH = (\n",
48-
" Path(\"../data/ngff-rfc5-coordinate-transformation-examples\")\n",
49-
")\n",
47+
"EXAMPLE_PATH = Path(\"../data/ngff-rfc5-coordinate-transformation-examples\")\n",
5048
"\n",
5149
"\n",
5250
"def get_all_zarrs(directory: Path) -> list[Path]:\n",
@@ -69,7 +67,7 @@
6967
"for zarr_path in get_all_zarrs(EXAMPLE_PATH):\n",
7068
" relative_path = zarr_path.relative_to(EXAMPLE_PATH)\n",
7169
"\n",
72-
" if 'organ' not in str(relative_path):\n",
70+
" if \"organ\" not in str(relative_path):\n",
7371
" continue\n",
7472
"\n",
7573
" group: Collection | Image\n",
@@ -78,7 +76,7 @@
7876
" group = Collection.from_zarr(zarr.open_group(zarr_path, mode=\"r\"))\n",
7977
" else:\n",
8078
" group = Image.from_zarr(zarr.open_group(zarr_path, mode=\"r\"))\n",
81-
" except Exception as e:\n",
79+
" except Exception:\n",
8280
" # raise e\n",
8381
" # print(str(e))\n",
8482
" # continue\n",
@@ -97,7 +95,11 @@
9795
"id": "72648a07-c8d1-4078-80ef-2b7fb4758451",
9896
"metadata": {},
9997
"outputs": [],
100-
"source": "transformation_path, (src_coord_system, tgt_coord_system) = ngff_transformations.graph.find_walks_in_graph(nx_graph, 'VOI-01.ome.zarr/1', None, 'overview.ome.zarr', 'anatomical')"
98+
"source": [
99+
"transformation_path, (src_coord_system, tgt_coord_system) = ngff_transformations.graph.find_walks_in_graph(\n",
100+
" nx_graph, \"VOI-01.ome.zarr/1\", None, \"overview.ome.zarr\", \"anatomical\"\n",
101+
")"
102+
]
101103
},
102104
{
103105
"cell_type": "code",
@@ -137,7 +139,7 @@
137139
"metadata": {},
138140
"outputs": [],
139141
"source": [
140-
"transformed_data = ngt.transform_with_sequence3D(data, ['y', 'x', 'z', 'c'], transformation_path, ['y', 'x', 'z', 'c'])"
142+
"transformed_data = ngt.transform_with_sequence3D(data, [\"y\", \"x\", \"z\", \"c\"], transformation_path, [\"y\", \"x\", \"z\", \"c\"])"
141143
]
142144
},
143145
{

src/ngff_transformations/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
__version__ = version("ngff-transformations")
44

55
from ngff_transformations.graph import (
6-
find_walks_in_graph,
76
draw_graph,
7+
find_walks_in_graph,
88
get_relative_path,
99
transform_graph_to_networkx,
1010
)

src/ngff_transformations/graph.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1-
from typing import Any
1+
import logging
22

33
import matplotlib.pyplot as plt
44
import networkx as nx
55
from ome_zarr_models._utils import TransformGraph
66
from ome_zarr_models._v06.coordinate_transforms import (
77
CoordinateSystemIdentifier,
8-
Transform,
98
Sequence,
9+
Transform,
1010
)
11-
from pydantic import ValidationError
12-
import logging
1311
from ome_zarr_models._v06.coordinate_transforms import (
1412
Sequence as SequenceTransformation,
1513
)
14+
from pydantic import ValidationError
1615

1716

1817
def transform_graph_to_networkx(tgraph: TransformGraph) -> nx.DiGraph:
@@ -125,9 +124,7 @@ def _get_name_of_subgraph(
125124
f"Ambiguous coordinate system name '{cs_identifier}' found in both root and subgraph '{path_name}'. Use full identifier."
126125
)
127126
if cs_identifier not in nodes and cs_path_name not in nodes:
128-
raise ValueError(
129-
f"Coordinate system '{cs_identifier}' not found in graph nodes."
130-
)
127+
raise ValueError(f"Coordinate system '{cs_identifier}' not found in graph nodes.")
131128
if cs_path_name in nodes:
132129
return cs_path_name
133130
return cs_identifier
@@ -159,9 +156,7 @@ def _add_transform_and_inverse_transformation_edges(
159156
pass
160157

161158

162-
def draw_graph(
163-
g: nx.DiGraph, figsize: tuple[int, int] = (12, 8), with_edge_labels: bool = True
164-
) -> None:
159+
def draw_graph(g: nx.DiGraph, figsize: tuple[int, int] = (12, 8), with_edge_labels: bool = True) -> None:
165160
"""
166161
Draw a NetworkX graph showing all nodes and edges with their names.
167162
@@ -219,9 +214,7 @@ def draw_graph(
219214
plt.show()
220215

221216

222-
def get_relative_path(
223-
graph: nx.DiGraph, source_coordinate_system: str, target_coordinate_system: str
224-
) -> list[str]:
217+
def get_relative_path(graph: nx.DiGraph, source_coordinate_system: str, target_coordinate_system: str) -> list[str]:
225218
cost_key = "cost"
226219
"""
227220
Get the relative path from one node to another in the transformation graph.
@@ -270,14 +263,10 @@ def create_sequence_transformation_from_graph_walk(
270263
edge_transformation = graph.get_edge_data(source, target)["transformation"]
271264
transformations.append(edge_transformation)
272265

273-
return Sequence(
274-
input=walk[0], output=walk[-1], transformations=tuple(transformations)
275-
)
266+
return Sequence(input=walk[0], output=walk[-1], transformations=tuple(transformations))
276267

277268

278-
def get_node(
279-
path: str | None = None, name: str | None = None
280-
) -> str | CoordinateSystemIdentifier:
269+
def get_node(path: str | None = None, name: str | None = None) -> str | CoordinateSystemIdentifier:
281270
if path is None and name is None:
282271
raise ValueError("Both path and name of the coordinate system cannot be None")
283272
if path is None:
@@ -287,17 +276,13 @@ def get_node(
287276
return CoordinateSystemIdentifier(path=path, name=name)
288277

289278

290-
def find_walks_in_graph(
291-
graph, src_path, src_name, tgt_path, tgt_name
292-
) -> list[str | CoordinateSystemIdentifier]:
279+
def find_walks_in_graph(graph, src_path, src_name, tgt_path, tgt_name) -> list[str | CoordinateSystemIdentifier]:
293280
src_node = get_node(src_path, src_name)
294281
tgt_node = get_node(tgt_path, tgt_name)
295282

296283
graph_walk = list(nx.all_shortest_paths(graph, src_node, tgt_node))
297284
if not graph_walk:
298-
raise ValueError(
299-
f"No path found from {src_node} to {tgt_node} in the transformation graph."
300-
)
285+
raise ValueError(f"No path found from {src_node} to {tgt_node} in the transformation graph.")
301286
if len(graph_walk) > 1:
302287
logging.warning(
303288
f"Multiple paths found from {src_node} to {tgt_node} in the transformation graph. Using the first one."

tests/test_graph.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88
from ome_zarr_models._v06.image import Image
99

1010
from ngff_transformations.graph import (
11-
find_walks_in_graph,
1211
get_relative_path,
1312
transform_graph_to_networkx,
13+
create_sequence_transformation_from_graph_walk,
1414
)
1515

16-
EXAMPLE_PATH = Path(__file__).parent.parent / "data" / "ngff-rfc5-coordinate-transformation-examples"
16+
EXAMPLE_PATH = (
17+
Path(__file__).parent.parent
18+
/ "data"
19+
/ "ngff-rfc5-coordinate-transformation-examples"
20+
)
1721

1822

1923
def get_test_zarr_paths(data_dir: Path = EXAMPLE_PATH) -> list[Path]:
@@ -67,7 +71,8 @@ def test_graph(zarr_path: Path):
6771

6872
example_edge = list(nx_graph.edges)[0]
6973
path = get_relative_path(nx_graph, example_edge[0], example_edge[1])
70-
sequence_transformation = find_walks_in_graph(
71-
graph=nx_graph, path)
74+
sequence_transformation = create_sequence_transformation_from_graph_walk(
75+
graph=nx_graph, walk=path
76+
)
7277

7378
assert isinstance(sequence_transformation, Sequence)

0 commit comments

Comments
 (0)