Skip to content

Commit a74a17f

Browse files
committed
Make timebinwidth and nbins mutually exclusive option.
1 parent 9ea16bc commit a74a17f

4 files changed

Lines changed: 134 additions & 63 deletions

File tree

packages/essnmx/src/ess/nmx/_executable_helper.py

Lines changed: 59 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,53 @@ def _retrieve_field_value(
6868
return getattr(args, field_name)
6969

7070

71+
_EXCLUSIVE_FIELDS_GROUP = {
72+
'time_bin_width': 'Bin Edges Configuration',
73+
'nbins': 'Bin Edges Configuration',
74+
}
75+
76+
77+
def _add_argument_from_field(
78+
group: argparse._ArgumentGroup, field_name: str, field_info: FieldInfo
79+
):
80+
add_argument = partial(group.add_argument, f"--{field_name.replace('_', '-')}")
81+
82+
if not _validate_annotation(field_info.annotation):
83+
raise TypeError(f"Unsupported annotation type: {field_info.annotation}")
84+
85+
arg_type = _get_no_nonetype_args(field_info.annotation)
86+
if _is_appendable_type(arg_type):
87+
nargs = '+'
88+
arg_type = get_args(field_info.annotation)[0]
89+
else:
90+
nargs = None
91+
arg_type = arg_type
92+
93+
required = field_info.default is PydanticUndefined
94+
default = ... if required else field_info.default
95+
96+
if arg_type is bool:
97+
add_argument = partial(add_argument, action='store_true')
98+
elif isinstance(arg_type, type) and issubclass(arg_type, enum.StrEnum):
99+
add_argument = partial(
100+
add_argument,
101+
type=str,
102+
choices=[str(e) for e in arg_type],
103+
)
104+
default = default.name if isinstance(default, enum.StrEnum) else default
105+
elif get_origin(arg_type) is Literal:
106+
add_argument = partial(
107+
add_argument,
108+
type=str,
109+
choices=[str(lit) for lit in get_args(arg_type)],
110+
)
111+
else:
112+
add_argument = partial(add_argument, type=arg_type, nargs=nargs)
113+
114+
help_text = ' '.join([field_info.description or '', f"(default: {default})"])
115+
add_argument(default=default, required=required, help=help_text)
116+
117+
71118
def add_args_from_pydantic_model(
72119
*, model_cls: type[BaseModel], parser: argparse.ArgumentParser
73120
) -> argparse.ArgumentParser:
@@ -99,43 +146,22 @@ def add_args_from_pydantic_model(
99146
group = parser.add_argument_group(
100147
model_cls.model_config.get("title", model_cls.__name__)
101148
)
149+
exclusive_groups: dict[str, list[tuple[str, FieldInfo]]] = {}
102150
for field_name, field_info in model_cls.model_fields.items():
103-
add_argument = partial(group.add_argument, f"--{field_name.replace('_', '-')}")
104-
105-
if not _validate_annotation(field_info.annotation):
106-
raise TypeError(f"Unsupported annotation type: {field_info.annotation}")
107-
108-
arg_type = _get_no_nonetype_args(field_info.annotation)
109-
if _is_appendable_type(arg_type):
110-
nargs = '+'
111-
arg_type = get_args(field_info.annotation)[0]
112-
else:
113-
nargs = None
114-
arg_type = arg_type
115-
116-
required = field_info.default is PydanticUndefined
117-
default = ... if required else field_info.default
118-
119-
if arg_type is bool:
120-
add_argument = partial(add_argument, action='store_true')
121-
elif isinstance(arg_type, type) and issubclass(arg_type, enum.StrEnum):
122-
add_argument = partial(
123-
add_argument,
124-
type=str,
125-
choices=[str(e) for e in arg_type],
126-
)
127-
default = default.name if isinstance(default, enum.StrEnum) else default
128-
elif get_origin(arg_type) is Literal:
129-
add_argument = partial(
130-
add_argument,
131-
type=str,
132-
choices=[str(lit) for lit in get_args(arg_type)],
151+
if field_name in _EXCLUSIVE_FIELDS_GROUP:
152+
cur_grp = exclusive_groups.setdefault(
153+
_EXCLUSIVE_FIELDS_GROUP[field_name], []
133154
)
155+
cur_grp.append((field_name, field_info))
134156
else:
135-
add_argument = partial(add_argument, type=arg_type, nargs=nargs)
157+
_add_argument_from_field(
158+
group=group, field_name=field_name, field_info=field_info
159+
)
136160

137-
help_text = ' '.join([field_info.description or '', f"(default: {default})"])
138-
add_argument(default=default, required=required, help=help_text)
161+
for fields in exclusive_groups.values():
162+
exclusive_group = group.add_mutually_exclusive_group(required=False)
163+
for field_name, field_info in fields:
164+
_add_argument_from_field(exclusive_group, field_name, field_info)
139165

140166
return parser
141167

packages/essnmx/src/ess/nmx/configurations.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
33
import enum
44

5-
from pydantic import BaseModel, Field
5+
from pydantic import BaseModel, Field, model_validator
66

77
from .types import Compression
88

@@ -64,8 +64,37 @@ class TimeBinCoordinate(enum.StrEnum):
6464
time_of_flight = 'time_of_flight'
6565

6666

67+
class _NotSet: ...
68+
69+
70+
_notset = _NotSet()
71+
72+
6773
class WorkflowConfig(BaseModel):
6874
# Add title of the basemodel
75+
@model_validator(mode='after')
76+
def nbins_or_time_bin_width(self):
77+
if self.time_bin_width is not None and self.nbins is not None:
78+
raise ValueError(
79+
"Either `nbins` or `time_bin_width` should be set. "
80+
"They cannot be set at the same time. "
81+
"It is allowed not setting any of them. "
82+
"Then 3 [ms] of `time_bin_width` will be used."
83+
)
84+
return self
85+
86+
@model_validator(mode='after')
87+
def positive_time_bin_width(self):
88+
if self.time_bin_width is not None and self.time_bin_width <= 0:
89+
raise ValueError("`time_bin_width` should be a positive number.")
90+
return self
91+
92+
@model_validator(mode='after')
93+
def positive_nbins(self):
94+
if self.nbins is not None and self.nbins <= 0:
95+
raise ValueError("`nbins` should be a positive integer.")
96+
return self
97+
6998
model_config = {"title": "Workflow Configuration"}
7099
time_bin_coordinate: TimeBinCoordinate = Field(
71100
title="Time Bin Coordinate",
@@ -78,19 +107,17 @@ class WorkflowConfig(BaseModel):
78107
# Default is time of flight since
79108
# DIALS should expect the time of flight.
80109
)
81-
time_bin_width: int = Field(
110+
time_bin_width: int | None = Field(
82111
title="Time Bin Width",
83112
description="Width(Length) of each Time Bin in [time_bin_unit]. "
84-
"If `time_bin_width` and `nbins` are both given, "
85-
"`time_bin_width` will be preferred. "
86-
"Set it to `0` if you want to use `nbins` instead.",
87-
default=3,
113+
"If none of `time_bin_width` or `nbins` is given, "
114+
"3 [ms] of `time_bin_width` will be used.",
115+
default=None,
88116
)
89-
nbins: int = Field(
117+
nbins: int | None = Field(
90118
title="Number of Time Bins",
91-
description="Number of Time bins. "
92-
"If `bin_width` is given, `nbins` will be ignored.",
93-
default=50,
119+
description="Number of Time bins. ",
120+
default=None,
94121
)
95122
min_time_bin: int | None = Field(
96123
title="Minimum Time",
@@ -230,11 +257,3 @@ def to_command_arguments(
230257
)
231258
else:
232259
return arg_list
233-
234-
235-
def validate_time_bin_config(config: ReductionConfig) -> None:
236-
wfconfig = config.workflow
237-
if not (wfconfig.time_bin_width > 0 or (wfconfig.nbins > 0)):
238-
raise ValueError(
239-
"Either `time-bin-width` or `nbins` should be a positive number."
240-
)

packages/essnmx/src/ess/nmx/executables.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
ReductionConfig,
2323
TimeBinCoordinate,
2424
WorkflowConfig,
25-
validate_time_bin_config,
2625
)
2726
from .nexus import (
2827
_check_file,
@@ -169,14 +168,20 @@ def _build_time_bin_edges(
169168
)
170169

171170
# If either min/max were manually selected and bin width is set.
172-
if wf_config.time_bin_width > 0:
171+
if wf_config.nbins is None:
172+
if wf_config.time_bin_width is None:
173+
time_bin_width = sc.scalar(3, unit='ms').to(unit=wf_config.time_bin_unit)
174+
elif wf_config.time_bin_width is not None:
175+
time_bin_width = sc.scalar(
176+
wf_config.time_bin_width, unit=wf_config.time_bin_unit
177+
)
173178
# We do not return a scalar bin width since we histogram
174179
# detector panel individually.
175180
return sc.arange(
176181
dim=t_coord_name,
177182
start=min_t.to(unit=wf_config.time_bin_unit),
178183
stop=max_t.to(unit=wf_config.time_bin_unit),
179-
step=sc.scalar(wf_config.time_bin_width, unit=wf_config.time_bin_unit),
184+
step=time_bin_width,
180185
)
181186
else: # Number of bin edges are given but not the bin width.
182187
n_edges = wf_config.nbins + 1
@@ -222,8 +227,6 @@ def reduction(
222227
if not config.output.skip_file_output:
223228
_check_file(config.output.output_file, config.output.overwrite)
224229

225-
validate_time_bin_config(config=config)
226-
227230
display = _retrieve_display(logger, display)
228231
input_file_path = _retrieve_input_file(config.inputs.input_file).resolve()
229232
display(f"Input file: {input_file_path}")

packages/essnmx/tests/executable_test.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _check_non_default_config(testing_config: ReductionConfig) -> None:
6969
testing_model = testing_child.model_dump(mode='python')
7070
default_model = default_child.model_dump(mode='python')
7171
for key, testing_value in testing_model.items():
72-
if key == 'lookup_table_file_path':
72+
if key in ['lookup_table_file_path', 'nbins']:
7373
# This value may be None or default, so we skip the check.
7474
continue
7575
default_value = default_model[key]
@@ -91,7 +91,6 @@ def test_reduction_config() -> None:
9191
)
9292
workflow_options = WorkflowConfig(
9393
time_bin_width=5,
94-
nbins=100,
9594
min_time_bin=10,
9695
max_time_bin=100_000,
9796
time_bin_coordinate=TimeBinCoordinate.event_time_offset,
@@ -194,7 +193,7 @@ def test_executable_runs(small_nmx_nexus_path, tmp_path: pathlib.Path):
194193
_check_output_file(output_file, bin_width=bin_width)
195194

196195

197-
def test_executable_runs_nbins(small_nmx_nexus_path, tmp_path: pathlib.Path):
196+
def test_executable_runs_exclusive_mutual(small_nmx_nexus_path, tmp_path: pathlib.Path):
198197
"""Test that the executable runs and returns the expected output."""
199198
output_file = tmp_path / "output.h5"
200199
assert not output_file.exists()
@@ -206,7 +205,30 @@ def test_executable_runs_nbins(small_nmx_nexus_path, tmp_path: pathlib.Path):
206205
'--input-file',
207206
small_nmx_nexus_path,
208207
'--time-bin-width',
209-
'0',
208+
'3',
209+
'--nbins',
210+
str(nbins),
211+
'--output-file',
212+
output_file.as_posix(),
213+
)
214+
# Validate that all commands are strings and contain no unsafe characters
215+
result = subprocess.run( # noqa: S603 - We are not accepting arbitrary input here.
216+
commands, text=True, capture_output=True, check=False
217+
)
218+
assert result.returncode == 2 # Should fail with command syntax error.
219+
220+
221+
def test_executable_runs_nbins(small_nmx_nexus_path, tmp_path: pathlib.Path):
222+
"""Test that the executable runs and returns the expected output."""
223+
output_file = tmp_path / "output.h5"
224+
assert not output_file.exists()
225+
226+
nbins = 20 # Small number of bins for testing.
227+
# The output has 1280x1280 pixels per detector per time bin.
228+
commands = (
229+
'essnmx-reduce',
230+
'--input-file',
231+
small_nmx_nexus_path,
210232
'--nbins',
211233
str(nbins),
212234
'--output-file',
@@ -276,7 +298,7 @@ def test_reduction_only_time_bin_width(reduction_config: ReductionConfig) -> Non
276298

277299

278300
def test_reduction_only_number_of_time_bins(reduction_config: ReductionConfig) -> None:
279-
reduction_config.workflow.time_bin_width = 0
301+
reduction_config.workflow.time_bin_width = None
280302
reduction_config.workflow.nbins = 20
281303
with known_warnings():
282304
hist = _retrieve_one_hist(reduction(config=reduction_config))
@@ -300,7 +322,7 @@ def test_histogram_event_time_offset(reduction_config: ReductionConfig) -> None:
300322

301323

302324
def test_histogram_event_time_offset_nbins(reduction_config: ReductionConfig) -> None:
303-
reduction_config.workflow.time_bin_width = 0
325+
reduction_config.workflow.time_bin_width = None
304326
reduction_config.workflow.nbins = 20
305327
reduction_config.workflow.time_bin_coordinate = TimeBinCoordinate.event_time_offset
306328
with known_warnings():
@@ -417,7 +439,8 @@ def test_reduction_with_lut_file(
417439
# the number of bins changes and the histogram data sizes changes.
418440
# This test is only for checking if the look up table is used as expected or not
419441
# therefore using number of bins should be fine.
420-
reduction_config.workflow.time_bin_width = 0
442+
reduction_config.workflow.time_bin_width = None
443+
reduction_config.workflow.nbins = 20
421444
# Make sure the config uses no lookup table file initially.
422445
assert reduction_config.workflow.lookup_table_file_path is None
423446
with known_warnings():

0 commit comments

Comments
 (0)