Skip to content

Commit 354d3cf

Browse files
committed
PR Comments
1 parent bd281c5 commit 354d3cf

2 files changed

Lines changed: 76 additions & 48 deletions

File tree

src/easyimaging/measurement/measurement.py

Lines changed: 64 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -160,19 +160,37 @@ def from_tiff_stack(
160160
raise RuntimeError(f"Failed to rename dimensions for TIFF stack file '{filename}': {e}") from e
161161

162162
time_of_flights = cls._validate_provided_coord(
163-
data_array, time_of_flights, 'time_of_flight', 't', 'frames in the TIFF stack', 'time', 's'
163+
data_array = data_array,
164+
coord = time_of_flights,
165+
coord_name = 'time_of_flight',
166+
dim = 't',
167+
length_context = 'frames in the TIFF stack',
168+
expected_dim_string = 'time',
169+
expected_unit = 's'
164170
)
165171
data_array.coords['tof'] = time_of_flights
166172

167173
if x_positions is not None:
168174
x_positions = cls._validate_provided_coord(
169-
data_array, x_positions, 'x_positions', 'x', 'pixels in the x dimension', 'length', 'm'
175+
data_array = data_array,
176+
coord = x_positions,
177+
coord_name = 'x_positions',
178+
dim = 'x',
179+
length_context = 'pixels in the x dimension',
180+
expected_dim_string = 'length',
181+
expected_unit = 'm'
170182
)
171183
data_array.coords['x'] = x_positions
172184

173185
if y_positions is not None:
174186
y_positions = cls._validate_provided_coord(
175-
data_array, y_positions, 'y_positions', 'y', 'pixels in the y dimension', 'length', 'm'
187+
data_array = data_array,
188+
coord = y_positions,
189+
coord_name = 'y_positions',
190+
dim = 'y',
191+
length_context = 'pixels in the y dimension',
192+
expected_dim_string = 'length',
193+
expected_unit = 'm'
176194
)
177195
data_array.coords['y'] = y_positions
178196

@@ -260,13 +278,13 @@ def x_positions(self, value: sc.Variable | np.ndarray) -> None:
260278
'Please use the set_physical_coord_positions method.'
261279
)
262280
value = self._validate_provided_coord(
263-
self._data_array,
264-
value,
265-
'x_positions',
266-
'x',
267-
'pixels in the x dimension',
268-
'length',
269-
'm',
281+
data_array=self._data_array,
282+
coord=value,
283+
coord_name='x_positions',
284+
dim='x',
285+
length_context='pixels in the x dimension',
286+
expected_dim_string='length',
287+
expected_unit='m',
270288
)
271289
self._data_array.coords['x'] = value
272290

@@ -296,13 +314,13 @@ def y_positions(self, value: sc.Variable | np.ndarray) -> None:
296314
'Please use the set_physical_coord_positions method.'
297315
)
298316
value = self._validate_provided_coord(
299-
self._data_array,
300-
value,
301-
'y_positions',
302-
'y',
303-
'pixels in the y dimension',
304-
'length',
305-
'm',
317+
data_array=self._data_array,
318+
coord=value,
319+
coord_name='y_positions',
320+
dim='y',
321+
length_context='pixels in the y dimension',
322+
expected_dim_string='length',
323+
expected_unit='m',
306324
)
307325
self._data_array.coords['y'] = value
308326

@@ -322,22 +340,22 @@ def set_physical_coord_positions(
322340
If a numpy array is provided, the unit is assumed to be meters.
323341
"""
324342
x_positions = self._validate_provided_coord(
325-
self._data_array,
326-
x_positions,
327-
'x_positions',
328-
'x',
329-
'pixels in the x dimension',
330-
'length',
331-
'm',
343+
data_array=self._data_array,
344+
coord=x_positions,
345+
coord_name='x_positions',
346+
dim='x',
347+
length_context='pixels in the x dimension',
348+
expected_dim_string='length',
349+
expected_unit='m',
332350
)
333351
y_positions = self._validate_provided_coord(
334-
self._data_array,
335-
y_positions,
336-
'y_positions',
337-
'y',
338-
'pixels in the y dimension',
339-
'length',
340-
'm',
352+
data_array=self._data_array,
353+
coord=y_positions,
354+
coord_name='y_positions',
355+
dim='y',
356+
length_context='pixels in the y dimension',
357+
expected_dim_string='length',
358+
expected_unit='m',
341359
)
342360
self._data_array.coords['x'] = x_positions
343361
self._data_array.coords['y'] = y_positions
@@ -373,13 +391,13 @@ def time_of_flights(self, value: sc.Variable | np.ndarray) -> None:
373391
If a numpy array is provided, the unit is assumed to be seconds.
374392
"""
375393
value = self._validate_provided_coord(
376-
self._data_array,
377-
value,
378-
'time_of_flights',
379-
't',
380-
'frames in the measurement',
381-
'time',
382-
's',
394+
data_array=self._data_array,
395+
coord=value,
396+
coord_name='time_of_flights',
397+
dim='t',
398+
length_context='frames in the measurement',
399+
expected_dim_string='time',
400+
expected_unit='s',
383401
)
384402
if any(value.to(unit='s') < sc.scalar(0, unit='s')):
385403
raise ValueError('time_of_flight values must be non-negative.')
@@ -468,7 +486,11 @@ def plot(self, time_of_flight: int | sc.Variable | None = None, **kwargs) -> Non
468486
title_suffix = ' (averaged over TOF)'
469487
elif isinstance(time_of_flight, int):
470488
title_suffix = f' at TOF index {time_of_flight}'
471-
elif isinstance(time_of_flight, sc.Variable):
489+
elif isinstance(time_of_flight, sc.Variable) and not time_of_flight.sizes:
490+
try:
491+
time_of_flight.to(unit='s')
492+
except UnitError:
493+
raise UnitError("time_of_flight variable must have a unit of time such as 's'") from None
472494
title_suffix = f' at TOF={time_of_flight.value} {time_of_flight.unit}'
473495
else:
474496
raise TypeError('time_of_flight must be an integer, scipp scalar, or None.')
@@ -484,11 +506,7 @@ def plot(self, time_of_flight: int | sc.Variable | None = None, **kwargs) -> Non
484506
elif isinstance(time_of_flight, int):
485507
plot = self._data_array['t', time_of_flight].plot(**plot_kwargs_defaults)
486508
elif isinstance(time_of_flight, sc.Variable):
487-
try:
488-
plot = self._data_array['tof', time_of_flight].plot(**plot_kwargs_defaults)
489-
except UnitError:
490-
raise UnitError("time_of_flight variable must have a unit of time such as 's'") from None
491-
509+
plot = self._data_array['tof', time_of_flight].plot(**plot_kwargs_defaults)
492510
if _is_notebook():
493511
return plot
494512
else:
@@ -708,7 +726,7 @@ def spectrum(self, roi: RectROI | str | None = None) -> sc.DataArray:
708726
if roi in self.regions_of_interest:
709727
roi = self.regions_of_interest[roi]
710728
else:
711-
raise KeyError(f"ROI with unique name '{roi}' not found in the measurement's list of ROIs.")
729+
raise KeyError(f"ROI with unique name '{roi}' not found in the measurement's list of ROIs: [{', '.join(item.unique_name for item in self.regions_of_interest)}].") # noqa: E501
712730
if roi is None:
713731
spectrum_data = self._data_array.mean(dim=['x', 'y'])
714732
elif self._has_physical_coords and roi._has_physical_coords:
@@ -781,7 +799,7 @@ def _validate_provided_coord(
781799
expected_dim_string: str,
782800
expected_unit: str,
783801
) -> sc.Variable:
784-
if not isinstance(coord, (sc.Variable, np.ndarray)):
802+
if not (isinstance(coord, sc.Variable) and coord.sizes) and not isinstance(coord, np.ndarray):
785803
raise TypeError(f'{coord_name} must be a scipp Variable or a numpy ndarray.')
786804
if len(coord) not in (data_array.sizes[dim], data_array.sizes[dim] + 1):
787805
raise ValueError(f'Length of {coord_name} array does not match the number of {length_context}.')

tests/unit_tests/measurement/test_measurement.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,13 @@ def test_from_tiff_stack_invalid_path(self, path, error):
383383
ValueError,
384384
'Length of time_of_flight array does not match the number of frames in the TIFF stack.',
385385
), # noqa: E501
386+
(sc.scalar(5.0, unit='s'), TypeError, 'time_of_flight must be a scipp Variable or a numpy ndarray.'),
386387
(sc.arange('t', 0, 240, 1, unit='m'), sc.UnitError, "time_of_flight must have a unit of time, such as 's'"),
387388
],
388389
ids=[
389390
'invalid_type',
390391
'wrong_length',
392+
'scalar',
391393
'invalid_unit',
392394
],
393395
)
@@ -1001,12 +1003,20 @@ def test_plot_pixels(self, valid_data_array, plot_setup):
10011003
assert fig.view.colormapper.vmax == 1.1
10021004
assert fig.view.colormapper.vmin == 0.0
10031005

1004-
def test_plot_invalid_time_of_flight_type(self, valid_data_array):
1006+
@pytest.mark.parametrize(
1007+
'time_of_flight',
1008+
[
1009+
('not_a_valid_type',),
1010+
(sc.array(dims=['tof'], values=[5.0, 10.0], unit='s'),),
1011+
],
1012+
ids=['invalid_type', 'array_input'],
1013+
)
1014+
def test_plot_invalid_time_of_flight_type(self, valid_data_array, time_of_flight):
10051015
# When
10061016
measurement = Measurement(data_array=valid_data_array)
10071017
# Then Expect
10081018
with pytest.raises(TypeError, match='time_of_flight must be an integer, scipp scalar, or None.'):
1009-
measurement.plot(time_of_flight='not_a_valid_type')
1019+
measurement.plot(time_of_flight=time_of_flight)
10101020

10111021
def test_plot_invalid_time_of_flight_unit(self, valid_data_array):
10121022
# When

0 commit comments

Comments
 (0)