Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions src/aind_data_schema/core/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,56 @@ def validate_time_constraints(self):

return self

def _check_acquisition_modalities(self, dd_modalities):
    """Warn when the modalities of acquisition.data_streams differ from data_description.

    The modalities of every data stream are pooled into one set and compared
    for exact equality with ``dd_modalities``. Nothing happens when the
    acquisition is missing/falsy or has no data streams.

    Args:
        dd_modalities: Set of modalities taken from data_description; set
            difference is applied to it.
    """
    acquisition = self.acquisition
    if not acquisition or not acquisition.data_streams:
        return

    # Pool the modalities declared by each individual data stream.
    stream_modalities = set().union(*(stream.modalities for stream in acquisition.data_streams))
    if stream_modalities == dd_modalities:
        return

    # NOTE(review): a reviewer pointed out that a modality such as "behavior"
    # will not necessarily show up in any data stream, in which case this
    # exact-match comparison warns. How to handle that is still an open question.
    only_in_acquisition = stream_modalities - dd_modalities
    only_in_data_description = dd_modalities - stream_modalities

    details = []
    if only_in_acquisition:
        details.append(f"in acquisition but not data_description: {only_in_acquisition}")
    if only_in_data_description:
        details.append(f"in data_description but not acquisition: {only_in_data_description}")

    warnings.warn(
        f"Modality mismatch between acquisition.data_streams and data_description. {'; '.join(details)}"
    )

def _check_instrument_modalities(self, dd_modalities):
    """Warn when the instrument lacks a modality that data_description lists.

    The instrument may declare additional modalities; only modalities present
    in ``dd_modalities`` but absent from ``instrument.modalities`` trigger the
    warning. Nothing happens when the instrument is missing/falsy or has no
    ``modalities`` attribute.

    Args:
        dd_modalities: Set of modalities taken from data_description.
    """
    instrument = self.instrument
    if not instrument or not hasattr(instrument, "modalities"):
        return

    absent = dd_modalities - set(instrument.modalities)
    if not absent:
        return

    message = (
        f"Instrument modalities are not a superset of data_description modalities. "
        f"Missing from instrument: {absent}"
    )
    warnings.warn(message)

def _check_quality_control_modalities(self, dd_modalities):
    """Warn when quality_control lists a modality that data_description does not.

    Nothing happens when quality_control is missing/falsy or has no
    ``modalities`` attribute.

    Args:
        dd_modalities: Set of modalities taken from data_description.
    """
    quality_control = self.quality_control
    if not quality_control or not hasattr(quality_control, "modalities"):
        return

    unexpected = set(quality_control.modalities) - dd_modalities
    if not unexpected:
        return

    message = (
        f"QualityControl metric modalities {unexpected} are not present"
        f" in data_description modalities {dd_modalities}"
    )
    warnings.warn(message)

@model_validator(mode="after")
def validate_modality_consistency(self):
    """Compare the modalities of the other core files against data_description.

    The data_description modalities are collected into a set and handed to the
    acquisition, instrument and quality_control checks, in that order. The
    checks are skipped entirely when data_description is missing/falsy or has
    no ``modalities`` attribute.

    Returns:
        The model instance itself, unchanged.
    """
    description = self.data_description
    if description and hasattr(description, "modalities"):
        expected = set(description.modalities)
        checks = (
            self._check_acquisition_modalities,
            self._check_instrument_modalities,
            self._check_quality_control_modalities,
        )
        for run_check in checks:
            run_check(expected)
    return self

@model_validator(mode="after")
def validate_data_description_name_time_consistency(self):
"""Validate that the creation_time from data_description.name is on or after midnight
Expand Down
4 changes: 2 additions & 2 deletions src/aind_data_schema/utils/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ def _validate_time_constraint(field_value, time_validation, start_time, end_time
f"Field '{field_name}' with value {field_value} must be between {start_time} and {end_time}"
)
elif time_validation == TimeValidation.AFTER:
if comparable_field_value <= start_time:
if comparable_field_value < start_time:
raise ValueError(f"Field '{field_name}' with value {field_value} must be after {start_time}")
elif time_validation == TimeValidation.BEFORE:
if comparable_field_value >= end_time:
if comparable_field_value > end_time:
raise ValueError(f"Field '{field_name}' with value {field_value} must be before {end_time}")


Expand Down
118 changes: 118 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@

from examples.data_description import d as data_description
from examples.subject import s as subject
from examples.exaspim_acquisition import acq as exaspim_acq
from examples.exaspim_instrument import inst as exaspim_inst
from examples.exaspim_quality_control import quality_control as exaspim_qc

ephys_assembly = EphysAssembly(
probes=[EphysProbe(probe_model="Neuropixels 1.0", name="Probe A")],
Expand Down Expand Up @@ -946,6 +949,121 @@ def test_validate_subject_details_if_not_specimen(self):
)
self.assertIn("Acquisition.subject_details are required for in vivo experiments", str(context.exception))

def test_validate_modality_consistency(self):
    """Check which modality-consistency warnings Metadata emits for each combination of core files."""
    import warnings as warnings_module

    def build(**core_files):
        # Every case shares name/location/subject; only the other core files vary.
        return Metadata(name="Test", location="loc", subject=subject, **core_files)

    def messages_when_recording(**core_files):
        # Record every warning raised during construction and return the message texts.
        with warnings_module.catch_warnings(record=True) as recorded:
            warnings_module.simplefilter("always")
            build(**core_files)
        return [str(item.message) for item in recorded]

    def messages_when_warning_required(**core_files):
        # Fail unless at least one UserWarning is raised, then return all recorded message texts.
        with self.assertWarns(UserWarning) as ctx:
            build(**core_files)
        return [str(item.message) for item in ctx.warnings]

    dd_spim = data_description.model_copy(update={"modalities": [Modality.SPIM]})
    dd_ecephys = data_description.model_copy(update={"modalities": [Modality.ECEPHYS]})
    dd_spim_ecephys = data_description.model_copy(update={"modalities": [Modality.SPIM, Modality.ECEPHYS]})
    inst_no_cal = exaspim_inst.model_copy(update={"calibrations": None})

    # Case 1: acquisition modalities match data_description - no warning
    messages = messages_when_recording(
        data_description=dd_spim,
        instrument=inst_no_cal,
        acquisition=exaspim_acq,
    )
    self.assertEqual([], [m for m in messages if "Modality mismatch" in m])

    # Case 2: acquisition has modality not in data_description - warns
    messages = messages_when_warning_required(
        data_description=dd_ecephys,
        instrument=inst_no_cal,
        acquisition=exaspim_acq,
    )
    self.assertTrue(any("Modality mismatch" in m for m in messages))
    self.assertTrue(any("in acquisition but not data_description" in m for m in messages))

    # Case 3: data_description has modality not in acquisition - warns
    messages = messages_when_warning_required(
        data_description=dd_spim_ecephys,
        instrument=inst_no_cal,
        acquisition=exaspim_acq,
    )
    self.assertTrue(any("Modality mismatch" in m for m in messages))
    self.assertTrue(any("in data_description but not acquisition" in m for m in messages))

    # Case 4: instrument modalities are a superset of data_description - no warning
    messages = messages_when_recording(data_description=dd_spim, instrument=exaspim_inst)
    self.assertEqual([], [m for m in messages if "superset" in m])

    # Case 5: instrument modalities missing a modality from data_description - warns
    messages = messages_when_warning_required(data_description=dd_spim_ecephys, instrument=exaspim_inst)
    self.assertTrue(any("superset" in m for m in messages))

    # Case 6: QC modalities are a subset of data_description - no warning
    messages = messages_when_recording(data_description=dd_spim, quality_control=exaspim_qc)
    self.assertEqual([], [m for m in messages if "QualityControl" in m])

    # Case 7: QC has modality not in data_description - warns
    messages = messages_when_warning_required(data_description=dd_ecephys, quality_control=exaspim_qc)
    self.assertTrue(any("QualityControl" in m for m in messages))

    # Case 8: no data_description - no modality warnings
    keywords = ["Modality mismatch", "superset", "QualityControl metric"]
    messages = messages_when_recording()
    self.assertEqual([], [m for m in messages if any(kw in m for kw in keywords)])


# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading