Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/diffpy/utils/diffraction_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,27 @@ def _set_array_from_range(self, begin, end, step_size=None, n_steps=None):
array = np.linspace(begin, end, n_steps)
return array

def get_angle_index(self, angle):
count = 0
for i, target in enumerate(self.angles):
if angle == target:
def get_array_index(self, xtype, value):
Comment thread
yucongalicechen marked this conversation as resolved.
Outdated
"""
returns the index of a given value in the array associated with the specified xtype

Parameters
----------
xtype str
the xtype used to access the array
value float
the target value to search for

Returns
-------
the index of the value in the array
"""
if self.on_xtype(xtype) is None:
raise ValueError(_xtype_wmsg(xtype))
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I can remove this if I let on_xtype() to raise an error for invalid xtypes. Will do that in the other PR.

for i, target in enumerate(self.on_xtype(xtype)[0]):
if value == target:
return i
else:
count += 1
if count >= len(self.angles):
raise IndexError(f"WARNING: no angle {angle} found in angles list")
raise IndexError(f"WARNING: no matching value {value} found in the {xtype} array.")

def _set_xarrays(self, xarray, xtype):
self.all_arrays = np.empty(shape=(len(xarray), 4))
Expand Down
41 changes: 40 additions & 1 deletion tests/test_diffraction_objects.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import re
from pathlib import Path

import numpy as np
import pytest
from freezegun import freeze_time

from diffpy.utils.diffraction_objects import DiffractionObject
from diffpy.utils.diffraction_objects import XQUANTITIES, DiffractionObject
from diffpy.utils.transforms import wavelength_warning_emsg


Expand Down Expand Up @@ -211,6 +212,44 @@ def _test_valid_diffraction_objects(actual_diffraction_object, function, expecte
return np.allclose(actual_array, expected_array)


def test_get_angle_index():
Comment thread
yucongalicechen marked this conversation as resolved.
test = DiffractionObject(
wavelength=0.71, xarray=np.array([30, 60, 90]), yarray=np.array([1, 2, 3]), xtype="tth"
)
actual_index = test.get_array_index(xtype="tth", value=30)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need more cases. Here everything is integers and there is a match. What do we want to happen if the value lies between two other values? Return nearest and a warning? What if it lies outside the range of values? Return nearest and a warning? What if it is really far away? Have a threshold after which we raise and error?

Copy link
Copy Markdown
Contributor Author

@yucongalicechen yucongalicechen Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sbillinge I addressed this in the new commit. please review.

assert actual_index == 0


params_index_bad = [
# UC1: empty array
(
[0.71, np.array([]), np.array([]), "tth", "tth", 10],
[IndexError, "WARNING: no matching value 10 found in the tth array."],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this an error or a warning?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be an error. I've edited the error message.

),
# UC2: invalid xtype
(
[None, np.array([]), np.array([]), "tth", "invalid", 10],
[
ValueError,
f"WARNING: I don't know how to handle the xtype, 'invalid'. "
f"Please rerun specifying an xtype from {*XQUANTITIES, }",
],
),
# UC3: pre-defined array with non-matching value
(
[0.71, np.array([30, 60, 90]), np.array([1, 2, 3]), "tth", "q", 30],
[IndexError, "WARNING: no matching value 30 found in the q array."],
),
]


@pytest.mark.parametrize("inputs, expected", params_index_bad)
def test_get_angle_index_bad(inputs, expected):
test = DiffractionObject(wavelength=inputs[0], xarray=inputs[1], yarray=inputs[2], xtype=inputs[3])
with pytest.raises(expected[0], match=re.escape(expected[1])):
test.get_array_index(xtype=inputs[4], value=inputs[5])


def test_dump(tmp_path, mocker):
x, y = np.linspace(0, 5, 6), np.linspace(0, 5, 6)
directory = Path(tmp_path)
Expand Down