-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathnumerical_convolution_base.py
More file actions
406 lines (337 loc) · 15.8 KB
/
numerical_convolution_base.py
File metadata and controls
406 lines (337 loc) · 15.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
# SPDX-FileCopyrightText: 2026 EasyScience contributors <https://github.com/easyscience>
# SPDX-License-Identifier: BSD-3-Clause
import warnings
# from dataclasses import dataclass
import numpy as np
import scipp as sc
from easyscience.variable import Parameter
from easydynamics.convolution.convolution_base import ConvolutionBase
from easydynamics.convolution.energy_grid import EnergyGrid
from easydynamics.sample_model.component_collection import ComponentCollection
from easydynamics.sample_model.components.model_component import ModelComponent
from easydynamics.utils.utils import Numeric
# Thresholds used when warning about component widths. They are
# illustrated in
# performance_tests/convolution/convolution_width_thresholds.ipynb

# Threshold for large widths compared to span - warn if width > 10% of span
LARGE_WIDTH_THRESHOLD = 0.1

# Threshold for small widths compared to bin spacing - warn if width < dx
SMALL_WIDTH_THRESHOLD = 1.0
class NumericalConvolutionBase(ConvolutionBase):
    """Base class for numerical convolutions of sample and resolution
    models.

    Provides methods to handle upsampling, extension, and detailed
    balance correction. This base class has no convolution
    functionality.
    """

    def __init__(
        self,
        energy: np.ndarray | sc.Variable,
        sample_components: ComponentCollection | ModelComponent,
        resolution_components: ComponentCollection | ModelComponent,
        energy_offset: Numeric | Parameter = 0.0,
        upsample_factor: Numeric | None = 5,
        extension_factor: Numeric | None = 0.2,
        temperature: Parameter | Numeric | None = None,
        temperature_unit: str | sc.Unit = 'K',
        energy_unit: str | sc.Unit = 'meV',
        normalize_detailed_balance: bool = True,
    ) -> None:
        """Initialize the NumericalConvolutionBase.

        Only the types of upsample_factor, extension_factor and
        normalize_detailed_balance are checked here; the range checks
        (upsample_factor > 1, extension_factor >= 0) are applied by the
        corresponding property setters.

        Args:
            energy (np.ndarray | sc.Variable): 1D array of energy values
                where the convolution is evaluated.
            sample_components (ComponentCollection | ModelComponent):
                The components to be convolved.
            resolution_components (ComponentCollection | ModelComponent):
                The resolution components to convolve with.
            energy_offset (Numeric | Parameter, default=0.0): An energy
                offset to apply to the energy values before convolution.
            upsample_factor (Numeric | None, default=5): The factor by
                which to upsample the input data before convolution. If
                None, the energy array is used as is and must be
                uniformly spaced.
            extension_factor (Numeric | None, default=0.2): The factor by
                which to extend the input data range before convolution.
                Must be a number unless upsample_factor is None.
            temperature (Parameter | Numeric | None, default=None): The
                temperature to use for detailed balance correction.
            temperature_unit (str | sc.Unit, default='K'): The unit of
                the temperature parameter.
            energy_unit (str | sc.Unit, default='meV'): The unit of the
                energy.
            normalize_detailed_balance (bool, default=True): Whether to
                normalize the detailed balance correction.

        Raises:
            TypeError: If temperature is not None, a number, or a
                Parameter, or if temperature_unit is not a string or
                sc.Unit, or if upsample_factor is not a number or None,
                or if extension_factor is not a number or None (or is
                None while upsample_factor is set), or if
                normalize_detailed_balance is not a bool.
            ValueError: If the energy grid cannot be created (see
                _create_energy_grid).
        """
        super().__init__(
            energy=energy,
            sample_components=sample_components,
            resolution_components=resolution_components,
            energy_unit=energy_unit,
            energy_offset=energy_offset,
        )

        if temperature is not None and not isinstance(temperature, (Numeric, Parameter)):
            raise TypeError('Temperature must be None, a number or a Parameter.')
        if not isinstance(temperature_unit, (str, sc.Unit)):
            raise TypeError('Temperature_unit must be a string or sc.Unit.')

        # Enforce the documented type contract up front. The setters
        # cannot be used here because each of them rebuilds the energy
        # grid, which needs both factors to be in place already.
        if upsample_factor is not None and not isinstance(upsample_factor, Numeric):
            raise TypeError('Upsample factor must be a numerical value or None.')
        if extension_factor is not None and not isinstance(extension_factor, Numeric):
            raise TypeError('Extension factor must be a number.')
        if not isinstance(normalize_detailed_balance, bool):
            raise TypeError('normalize_detailed_balance must be True or False.')

        self._temperature_unit = temperature_unit
        # The temperature setter inspects the current value, so it must
        # exist before the setter is used.
        self._temperature = None
        self.temperature = temperature

        self._normalize_detailed_balance = normalize_detailed_balance
        self._upsample_factor = upsample_factor
        self._extension_factor = extension_factor

        # Create a dense grid to improve accuracy.
        # When upsample_factor>1, we evaluate on this grid and
        # interpolate back to the original values at the end
        self._energy_grid = self._create_energy_grid()

    @ConvolutionBase.energy.setter
    def energy(self, energy: np.ndarray) -> None:
        """Set the energy array and recreate the dense grid.

        The value is handed to the energy setter of ConvolutionBase
        before the grid is rebuilt.

        Args:
            energy (np.ndarray): The new energy array.
        """
        ConvolutionBase.energy.fset(self, energy)
        # Recreate dense grid when energy is updated
        self._energy_grid = self._create_energy_grid()

    @property
    def upsample_factor(self) -> Numeric | None:
        """Get the upsample factor.

        Returns:
            Numeric | None: The upsample factor.
        """
        return self._upsample_factor

    @upsample_factor.setter
    def upsample_factor(self, factor: Numeric | None) -> None:
        """Set the upsample factor and recreate the dense grid.

        If the grid cannot be rebuilt with the new factor, the previous
        factor is restored before the error is re-raised, so the stored
        factor and the stored grid always belong together.

        Args:
            factor (Numeric | None): The new upsample factor.

        Raises:
            TypeError: If factor is not a number or None.
            ValueError: If factor is not greater than 1, or if the
                energy grid cannot be created with the new factor.
        """
        if factor is not None:
            if not isinstance(factor, Numeric):
                raise TypeError('Upsample factor must be a numerical value or None.')
            factor = float(factor)
            if factor <= 1.0:
                raise ValueError('Upsample factor must be greater than 1.')

        previous_factor = self._upsample_factor
        self._upsample_factor = factor
        try:
            # Recreate dense grid when upsample factor is updated
            self._energy_grid = self._create_energy_grid()
        except Exception:
            # Roll back so a failed update does not leave the new factor
            # paired with the old grid.
            self._upsample_factor = previous_factor
            raise

    @property
    def extension_factor(self) -> Numeric | None:
        """Get the extension factor.

        The extension factor determines how much the energy range is
        extended before convolution. Half of extension_factor times the
        original energy span is added on each side, so 0.2 widens the
        total range by 20% of the original span (10% on each side).

        Returns:
            Numeric | None: The extension factor.
        """
        return self._extension_factor

    @extension_factor.setter
    def extension_factor(self, factor: Numeric) -> None:
        """Set the extension factor and recreate the dense grid.

        The extension factor determines how much the energy range is
        extended before convolution. Half of extension_factor times the
        original energy span is added on each side, so 0.2 widens the
        total range by 20% of the original span (10% on each side).

        Args:
            factor (Numeric): The new extension factor.

        Raises:
            TypeError: If factor is not a number.
            ValueError: If factor is negative.
        """
        if not isinstance(factor, Numeric):
            raise TypeError('Extension factor must be a number.')
        if factor < 0.0:
            raise ValueError('Extension factor must be non-negative.')

        self._extension_factor = float(factor)
        # Recreate dense grid when extension factor is updated
        self._energy_grid = self._create_energy_grid()

    @property
    def temperature(self) -> Parameter | None:
        """Get the temperature.

        Returns:
            Parameter | None: The temperature parameter, or None if
                detailed balance correction is disabled.
        """
        return self._temperature

    @temperature.setter
    def temperature(self, temp: Parameter | Numeric | None) -> None:
        """Set the temperature.

        If None, disables detailed balance correction and removes the
        temperature parameter.

        Args:
            temp (Parameter | Numeric | None): The temperature to set.
                A number updates the value of the existing temperature
                parameter in place if there is one (keeping its unit);
                otherwise a new fixed parameter is created with the
                temperature unit given at construction. A Parameter
                replaces the current one.

        Raises:
            TypeError: If temp is not a Numeric, Parameter, or None.
        """
        if temp is None:
            self._temperature = None
        elif isinstance(temp, Numeric):
            if self._temperature is not None:
                self._temperature.value = float(temp)
            else:
                self._temperature = Parameter(
                    name='temperature',
                    value=float(temp),
                    unit=self._temperature_unit,
                    fixed=True,
                )
        elif isinstance(temp, Parameter):
            self._temperature = temp
        else:
            raise TypeError('Temperature must be None, a float or a Parameter.')

    @property
    def normalize_detailed_balance(self) -> bool:
        """Get whether to normalize the detailed balance factor.

        If True, the detailed balance factor is divided by temperature.

        Returns:
            bool: Whether to normalize the detailed balance factor.
        """
        return self._normalize_detailed_balance

    @normalize_detailed_balance.setter
    def normalize_detailed_balance(self, normalize: bool) -> None:
        """Set whether to normalize the detailed balance factor.

        If True, the detailed balance factor is divided by temperature.

        Args:
            normalize (bool): Whether to normalize the detailed balance
                factor.

        Raises:
            TypeError: If normalize is not a bool.
        """
        if not isinstance(normalize, bool):
            raise TypeError('normalize_detailed_balance must be True or False.')

        self._normalize_detailed_balance = normalize

    def _create_energy_grid(
        self,
    ) -> EnergyGrid:
        """Create a dense grid by upsampling and extending the energy
        array.

        If upsample_factor is None, no upsampling or extension is
        performed and the energy array itself is used, which therefore
        has to be uniformly spaced. Otherwise the range is widened by
        extension_factor / 2 times the original span on each side and
        sampled with round(len(energy) * upsample_factor) evenly spaced
        points.

        This dense grid is used for convolution to improve accuracy.

        Returns:
            EnergyGrid: The dense grid created by upsampling and
                extending energy.

        Raises:
            ValueError: If energy array is not uniformly spaced when
                upsample_factor is None, or if the energy array (or the
                resulting dense grid) has less than 2 points.
            TypeError: If upsample_factor is set but extension_factor is
                None.
        """
        energy_values = self.energy.values

        if self.upsample_factor is None:
            # The uniformity check below indexes the first difference,
            # so the length has to be validated first.
            if len(energy_values) < 2:
                raise ValueError('Energy array must have at least two points.')

            # Check if the array is uniformly spaced.
            energy_diff = np.diff(energy_values)
            is_uniform = np.allclose(energy_diff, energy_diff[0])
            if not is_uniform:
                raise ValueError(
                    'Input array `energy` must be uniformly spaced if upsample_factor is not given.'  # noqa: E501
                )
            energy_dense = energy_values
            energy_span_dense = energy_values.max() - energy_values.min()
        else:
            if self.extension_factor is None:
                raise TypeError('Extension factor must be a number.')

            # Create an extended and upsampled energy grid
            energy_min, energy_max = energy_values.min(), energy_values.max()
            energy_span_original = energy_max - energy_min
            # Half of the requested extension goes on each side.
            extra = self.extension_factor / 2 * energy_span_original
            extended_min = energy_min - extra
            extended_max = energy_max + extra
            num_points = round(len(energy_values) * self.upsample_factor)
            energy_dense = np.linspace(extended_min, extended_max, num_points)
            energy_span_dense = extended_max - extended_min

        if len(energy_dense) < 2:
            raise ValueError('Energy array must have at least two points.')
        energy_dense_step = energy_dense[1] - energy_dense[0]

        # Handle offset for even length of energy_dense in convolution.
        # The convolution of two arrays of length N is of length 2N-1.
        # When using 'same' mode, only the central N points are kept,
        # so the output has the same length as the input.
        # However, if N is even, the center falls between two points,
        # leading to a half-bin offset.
        # For example, if N=4, the convolution has length 7, and when we
        # select the 4 central points we either get
        # indices [2,3,4,5] or [1,2,3,4], both of which are offset by
        # 0.5*dx from the true center at index 3.5.
        energy_even_length_offset = -0.5 * energy_dense_step if len(energy_dense) % 2 == 0 else 0.0

        # Handle the case when energy_dense is not symmetric around 0.
        # The resolution is still centered around zero (or close to it),
        # so it needs to be evaluated there.
        if not np.isclose(energy_dense.mean(), 0.0):
            energy_dense_centered = np.linspace(
                -0.5 * energy_span_dense, 0.5 * energy_span_dense, len(energy_dense)
            )
        else:
            energy_dense_centered = energy_dense

        return EnergyGrid(
            energy_dense=energy_dense,
            energy_dense_centered=energy_dense_centered,
            energy_dense_step=energy_dense_step,
            energy_span_dense=energy_span_dense,
            energy_even_length_offset=energy_even_length_offset,
        )

    def _check_width_thresholds(
        self,
        model: ComponentCollection | ModelComponent,
        model_name: str,
    ) -> None:
        """Helper function to check and warn if components are wide
        compared to the span of the data, or narrow compared to the
        spacing.

        In both cases, the convolution accuracy may be compromised.
        Components without a `width` attribute are skipped. A
        UserWarning is issued when a width exceeds
        LARGE_WIDTH_THRESHOLD times the span of the dense grid, and when
        it is below SMALL_WIDTH_THRESHOLD times the step of the dense
        grid.

        Args:
            model (ComponentCollection | ModelComponent): The model to
                check
            model_name (str): A string indicating whether the model is a
                'sample model' or 'resolution model' for warning
                messages.
        """
        # Handle ComponentCollection or ModelComponent
        components = model.components if isinstance(model, ComponentCollection) else [model]

        span = self._energy_grid.energy_span_dense
        step = self._energy_grid.energy_dense_step

        for comp in components:
            if not hasattr(comp, 'width'):
                continue

            width = comp.width.value

            # The messages are assembled by implicit concatenation so
            # that source indentation never ends up inside the text.
            if width > LARGE_WIDTH_THRESHOLD * span:
                warnings.warn(
                    f"The width of the {model_name} component '{comp.unique_name}' "
                    f'({width}) is large compared to the span of the input '
                    f'array ({span}). '
                    'This may lead to inaccuracies in the convolution. '
                    'Increase extension_factor to improve accuracy.',
                    UserWarning,
                    stacklevel=3,
                )

            if width < SMALL_WIDTH_THRESHOLD * step:
                warnings.warn(
                    f"The width of the {model_name} component '{comp.unique_name}' "
                    f'({width}) is small compared to the spacing of the input '
                    f'array ({step}). '
                    'This may lead to inaccuracies in the convolution. '
                    'Increase upsample_factor to improve accuracy.',
                    UserWarning,
                    stacklevel=3,
                )

    def __repr__(self) -> str:
        """Return a string representation of the
        NumericalConvolutionBase.

        Returns:
            str: A string representation of the
                NumericalConvolutionBase.
        """
        return (
            f'{self.__class__.__name__}('
            f'energy=array of shape {self.energy.values.shape},\n '
            f'sample_components={repr(self.sample_components)}, \n'
            f'resolution_components={repr(self.resolution_components)},\n '
            f'energy_unit={self._energy_unit}, '
            f'upsample_factor={self.upsample_factor}, '
            f'extension_factor={self.extension_factor}, '
            f'temperature={self.temperature}, '
            f'normalize_detailed_balance={self.normalize_detailed_balance})'
        )