Skip to content

Commit c4a1397

Browse files
committed
Add save_scitiff method
1 parent de17e30 commit c4a1397

2 files changed

Lines changed: 93 additions & 8 deletions

File tree

src/easyimaging/measurement/measurement.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from scipp import DimensionError
1717
from scipp import UnitError
1818
from scitiff import load_scitiff
19+
from scitiff import save_scitiff
1920

2021
from ..utils import _is_notebook
2122
from ..utils import _to_edges
@@ -178,6 +179,33 @@ def from_tiff_stack(
178179
instance = cls(data_array=data_array, unique_name=unique_name, display_name=display_name)
179180
return instance
180181

182+
def save_scitiff(self, filename: str | Path) -> None:
183+
"""
184+
Save the measurement data to a SciTIFF file with its time-of-flight and physical coordinate information.
185+
Note that it is the rebinned data array that is saved if rebinning has been applied.
186+
187+
Parameters
188+
----------
189+
filename : str | Path
190+
Path to the output SciTIFF file.
191+
"""
192+
if not isinstance(filename, (str, Path)):
193+
raise TypeError('filename must be a string or Path object.')
194+
try:
195+
nan_mask = self._data_array.masks.pop('non_finite')
196+
if str(self._data_array.dtype)=='float64':
197+
warnings.warn(
198+
"The data array is of type float64, which is not directly supported by the SciTIFF format. "
199+
"It will be downcast to float32 when saving, which may result in loss of precision. "
200+
)
201+
save_scitiff(self._data_array.astype('float32'), filename)
202+
else:
203+
save_scitiff(self._data_array, filename)
204+
except Exception as e:
205+
raise RuntimeError(f"Failed to save SciTIFF file '{filename}': {e}") from e
206+
finally:
207+
self._data_array.masks['non_finite'] = nan_mask # Ensure mask is restored
208+
181209
@property
182210
def _data_array(self) -> sc.DataArray:
183211
"""

tests/unit_tests/measurement/test_measurement.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from copy import copy
23
from pathlib import Path
34
from typing import MutableSequence
@@ -21,7 +22,7 @@ def valid_data_array(self):
2122
tof = sc.arange('t', 0, 10, 1, unit='s')
2223
x = sc.arange('x', 0, 7, 1, unit='m')
2324
y = sc.arange('y', 0, 7, 1, unit='m')
24-
data = sc.ones(dims=['y', 'x', 't'], shape=[6, 6, 10])
25+
data = sc.ones(dims=['t', 'y', 'x'], shape=[10, 6, 6])
2526
return sc.DataArray(
2627
data=data,
2728
coords={
@@ -34,15 +35,9 @@ def valid_data_array(self):
3435
@pytest.fixture
3536
def valid_data_array_no_xy_coords(self):
3637
tof = sc.arange('t', 0, 10, 1, unit='s')
37-
data = sc.zeros(dims=('y', 'x', 't'), shape=(6, 6, 10))
38+
data = sc.zeros(dims=('t', 'y', 'x'), shape=(10, 6, 6))
3839
return sc.DataArray(data=data, coords={'tof': tof})
3940

40-
# @pytest.fixture(autouse=True)
41-
# def _reset_mpl_defaults(self):
42-
# matplotlib.rcdefaults()
43-
# matplotlib.use('Agg')
44-
# pp.backends.reset()
45-
4641
@pytest.fixture(autouse=True)
4742
def _close_figures(self):
4843
"""
@@ -456,6 +451,68 @@ def test_from_tiff_stack_invalid_y_positions(self, coord, error, expected_messag
456451
y_positions=coord,
457452
)
458453

454+
def test_save_scitiff_downcast(self, valid_data_array, tmp_path):
455+
# When
456+
measurement = Measurement(data_array=valid_data_array)
457+
# Then Expect
458+
with pytest.warns(UserWarning, match="The data array is of type float64, which is not directly supported by the SciTIFF format. It will be downcast to float32 when saving, which may result in loss of precision."): # noqa: E501
459+
measurement.save_scitiff(tmp_path / 'test.tiff')
460+
assert 'non_finite' in measurement._data_array.masks
461+
462+
def test_save_scitiff(self, valid_data_array, tmp_path):
463+
# When
464+
data_array = valid_data_array.astype('float32')
465+
measurement = Measurement(data_array=data_array)
466+
# Then Expect
467+
with warnings.catch_warnings():
468+
warnings.simplefilter("error") # Ensure no warnings are raised for valid data type
469+
measurement.save_scitiff(tmp_path / 'test.tiff')
470+
assert 'non_finite' in measurement._data_array.masks
471+
472+
def test_save_scitiff_path_object(self, valid_data_array, tmp_path):
473+
# When
474+
data_array = valid_data_array.astype('float32')
475+
measurement = Measurement(data_array=data_array)
476+
path = Path(tmp_path) / 'test.tiff'
477+
# Then Expect
478+
with warnings.catch_warnings():
479+
warnings.simplefilter("error") # Ensure no warnings are raised for valid data type
480+
measurement.save_scitiff(path)
481+
assert 'non_finite' in measurement._data_array.masks
482+
483+
def test_save_scitiff_invalid_filename_type(self, valid_data_array):
484+
# When
485+
measurement = Measurement(data_array=valid_data_array)
486+
# Then Expect
487+
with pytest.raises(TypeError, match='filename must be a string or Path object.'):
488+
measurement.save_scitiff(12345)
489+
assert 'non_finite' in measurement._data_array.masks
490+
491+
def test_save_scitiff_non_existent_directory(self, valid_data_array):
492+
# When
493+
measurement = Measurement(data_array=valid_data_array)
494+
# Then Expect
495+
with pytest.raises(RuntimeError, match="Failed to save SciTIFF file 'non_existent_directory/test.tiff'"):
496+
measurement.save_scitiff('non_existent_directory/test.tiff')
497+
assert 'non_finite' in measurement._data_array.masks
498+
499+
def test_save_scitiff_rebinned_data_array(self, valid_data_array, tmp_path):
500+
# When
501+
data_array = valid_data_array.astype('float32')
502+
measurement = Measurement(data_array=data_array)
503+
measurement.rebin(dimensions={'x': 2, 'y': 2})
504+
# Then
505+
measurement.save_scitiff(tmp_path / 'test.tiff')
506+
loaded_measurement = Measurement.from_scitiff(tmp_path / 'test.tiff')
507+
# Expect
508+
assert sc.identical(loaded_measurement._data_array.coords['x_pixels'], sc.arange('x', 0, 4, 1))
509+
assert sc.identical(loaded_measurement._data_array.coords['y_pixels'], sc.arange('y', 0, 4, 1))
510+
loaded_measurement._data_array.coords.pop('x_pixels')
511+
loaded_measurement._data_array.coords.pop('y_pixels')
512+
measurement._data_array.coords.pop('x_pixels')
513+
measurement._data_array.coords.pop('y_pixels')
514+
assert sc.identical(loaded_measurement._data_array, measurement._data_array)
515+
459516
@pytest.mark.parametrize('coordinate', ['x_positions', 'y_positions'], ids=['x_coordinate', 'y_coordinate'])
460517
def test_positions(self, valid_data_array, coordinate):
461518
# When

0 commit comments

Comments
 (0)