diff --git a/pyproject.toml b/pyproject.toml index 8f09dcc..46e4e64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ dev = [ "pytest>=9.0.2", "stablehash==0.3.0", + "types-pika-ts", + "types-psycopg2", ] [project.scripts] diff --git a/src/controller.py b/src/controller.py index 6d57f31..3dc3099 100644 --- a/src/controller.py +++ b/src/controller.py @@ -13,10 +13,8 @@ logging.basicConfig(format="%(levelname)s:%(asctime)s: %(message)s") logger = logging.getLogger(__name__) - -emap_db = db.starDB() -emap_db.init_query() -emap_db.connect() +logger.setLevel(settings.LOG_LEVEL) +# logger.addFilter(DedupeFilter(window_seconds=60)) class waveform_message: @@ -42,57 +40,78 @@ def reject_message(ch, delivery_tag, requeue): logger.warning("Attempting to not acknowledge a message on a closed channel.") -def waveform_callback(ch, method_frame, _header_frame, body): - data = json.loads(body) - try: - location_string = data["mappedLocationString"] - observation_timestamp = data["observationTime"] - source_variable_id = data["sourceVariableId"] - source_channel_id = data["sourceChannelId"] - sampling_rate = data["samplingRate"] - units = data["unit"] - waveform_data = data["numericValues"] - mapped_location_string = data["mappedLocationString"] - except IndexError as e: - reject_message(ch, method_frame.delivery_tag, False) - logger.error( - f"Waveform message {method_frame.delivery_tag} is missing required data {e}." - ) - return +class WaveformController: + def __init__(self): + self.emap_db = db.starDB() + self.emap_db.init_query() + self.emap_db.connect() + + def waveform_callback(self, ch, method_frame, _header_frame, body): + logger.debug("Message received of length %s", len(body)) + data = json.loads(body) + try: + location_string = data["mappedLocationString"] + observation_timestamp = data["observationTime"] + source_variable_id = data["sourceVariableId"] + source_channel_id = data["sourceChannelId"] + sampling_rate = data["samplingRate"] + units = data["unit"] + waveform_data = data["numericValues"] + mapped_location_string = data["mappedLocationString"] + logger.debug( + "Message is for loc %s, var %s, ch %s", + location_string, + source_variable_id, + source_channel_id, + ) + except KeyError as e: + reject_message(ch, method_frame.delivery_tag, False) + logger.error( + f"Waveform message {method_frame.delivery_tag} is missing required data {e}." + ) + return - observation_time = datetime.fromtimestamp(observation_timestamp, tz=timezone.utc) - lookup_success = True - try: - matched_mrn = emap_db.get_row(location_string, observation_time) - except ValueError: - lookup_success = False - logger.error( - "Ambiguous or non existent match for location %s, obs time %s", - location_string, - observation_time, - exc_info=True, + observation_time = datetime.fromtimestamp( + observation_timestamp, tz=timezone.utc ) - matched_mrn = ("unmatched_mrn", "unmatched_nhs", "unmatched_csn") - except ConnectionError: - logger.error("Database error, will try again", exc_info=True) - reject_message(ch, method_frame.delivery_tag, True) - return - - if writer.write_frame( - waveform_data, - source_variable_id, - source_channel_id, - observation_timestamp, - units, - sampling_rate, - mapped_location_string, - matched_mrn[2], - matched_mrn[0], - ): - if lookup_success: - ack_message(ch, method_frame.delivery_tag) - else: + lookup_success = True + try: + matched_mrn = self.emap_db.get_row(location_string, observation_time) + except ValueError: + lookup_success = False + logger.error( + "Ambiguous or non existent match for location %s, obs time %s", + location_string, + observation_time, + exc_info=True, + ) + matched_mrn = ("unmatched_mrn", "unmatched_nhs", "unmatched_csn", False) + except ConnectionError: + logger.error("Database error, will try again", exc_info=True) + reject_message(ch, method_frame.delivery_tag, True) + return + + (mrn, nhs_no, csn, opt_out) = matched_mrn + if opt_out: + logger.info("Research opt-out is set for mrn %s, not writing.", mrn) reject_message(ch, method_frame.delivery_tag, False) + return + + if writer.write_frame( + waveform_data, + source_variable_id, + source_channel_id, + observation_timestamp, + units, + sampling_rate, + mapped_location_string, + csn, + mrn, + ): + if lookup_success: + ack_message(ch, method_frame.delivery_tag) + else: + reject_message(ch, method_frame.delivery_tag, False) def receiver(): @@ -105,18 +124,27 @@ def receiver(): host=settings.RABBITMQ_HOST, port=settings.RABBITMQ_PORT, ) + logger.info("Connecting to RabbitMQ %s", connection_parameters) connection = pika.BlockingConnection(connection_parameters) channel = connection.channel() channel.basic_qos(prefetch_count=1) + controller = WaveformController() channel.basic_consume( queue=settings.RABBITMQ_QUEUE, auto_ack=False, - on_message_callback=waveform_callback, + on_message_callback=controller.waveform_callback, ) + logger.info("Connected to RabbitMQ, callback configured") try: channel.start_consuming() except KeyboardInterrupt: + logger.warning("Received keyboard interrupt, exiting.") channel.stop_consuming() + except Exception as e: + logger.error("Received other exception") + logger.error(e) + raise e + logger.info("Closing connection to RabbitMQ") connection.close() diff --git a/src/settings.py b/src/settings.py index e62761b..3bccd12 100644 --- a/src/settings.py +++ b/src/settings.py @@ -37,4 +37,6 @@ def get_from_env(env_var, *, default_value=None, setting_name=None, required=Fal get_from_env("HASHER_API_HOSTNAME") get_from_env("HASHER_API_PORT") +get_from_env("LOG_LEVEL", default_value="INFO") + get_from_env("INSTANCE_NAME", required=True) diff --git a/src/sql/mrn_based_on_bed_and_datetime.sql b/src/sql/mrn_based_on_bed_and_datetime.sql index 9f6ce02..7eccf5e 100644 --- a/src/sql/mrn_based_on_bed_and_datetime.sql +++ b/src/sql/mrn_based_on_bed_and_datetime.sql @@ -5,7 +5,8 @@ first entry being the most recent. SELECT mn.mrn as mrn, mn.nhs_number as nhs_number, - hv.encounter as csn + hv.encounter as csn, + mn.research_opt_out as research_opt_out FROM {schema_name}.mrn mn INNER JOIN {schema_name}.hospital_visit hv ON mn.mrn_id = hv.mrn_id diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..ed24ad3 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,23 @@ +import logging +import time + + +class DedupeFilter(logging.Filter): + """Suppress repeated identical log messages within a time window.""" + + def __init__(self, window_seconds=60, name=""): + super().__init__(name) + self.window_seconds = window_seconds + self._last_message = None + self._last_time = 0.0 + self._dedupe_count = 0 + + def filter(self, record): + msg = record.getMessage() + now = time.monotonic() + if msg == self._last_message and (now - self._last_time) < self.window_seconds: + self._dedupe_count += 1 + return False + self._last_message = msg + self._last_time = now + return True diff --git a/tests/test_controller.py b/tests/test_controller.py new file mode 100644 index 0000000..d4fd611 --- /dev/null +++ b/tests/test_controller.py @@ -0,0 +1,81 @@ +import json +from datetime import datetime +from unittest.mock import Mock + +import pytest + +from controller import WaveformController + + +@pytest.mark.parametrize( + "opt_out", + [True, False], +) +@pytest.mark.parametrize( + "db_connect_failure", + [True, False], +) +@pytest.mark.parametrize( + "bad_data", + [True, False], +) +def test_controller_callback(monkeypatch, opt_out, db_connect_failure, bad_data): + emap_db_mock = Mock() + if db_connect_failure: + emap_db_mock.get_row.side_effect = ConnectionError("mock database error") + else: + emap_db_mock.get_row.return_value = ("mrn", "nhsno", "csn", opt_out) + monkeypatch.setattr("controller.db.starDB", Mock(return_value=emap_db_mock)) + + write_frame_mock = Mock(return_value=True) + monkeypatch.setattr("controller.writer.write_frame", write_frame_mock) + + fake_data = { + "sourceLocationString": "foo", + "mappedLocationString": "loc", + "observationTime": datetime.now().timestamp(), + "sourceVariableId": "27", + "sourceChannelId": "1", + "samplingRate": 50, + "unit": "uV", + "numericValues": "[1,2,3]", + } + if bad_data: + # simulate a missing key + del fake_data["sourceChannelId"] + fake_data_str = json.dumps(fake_data) + controller = WaveformController() + + method_frame_mock = Mock() + delivery_tag = 12345 + method_frame_mock.delivery_tag = delivery_tag + channel_mock = Mock() + channel_mock.is_open = True + + controller.waveform_callback(channel_mock, method_frame_mock, None, fake_data_str) + + if not bad_data: + # we at least tried to query the DB + emap_db_mock.get_row.assert_called_once() + + if bad_data: + write_frame_mock.assert_not_called() + # db should not even have been queried if data was bad + emap_db_mock.get_row.assert_not_called() + channel_mock.basic_reject.assert_called_once_with(delivery_tag, False) + channel_mock.basic_ack.assert_not_called() + elif db_connect_failure: + # if the DB lookup failed, we should not write anything and requeue the message + write_frame_mock.assert_not_called() + channel_mock.basic_reject.assert_called_once_with(delivery_tag, True) + channel_mock.basic_ack.assert_not_called() + elif opt_out: + # patient has opted out, dump the message + write_frame_mock.assert_not_called() + channel_mock.basic_reject.assert_called_once_with(delivery_tag, False) + channel_mock.basic_ack.assert_not_called() + else: + # happy path + write_frame_mock.assert_called_once() + channel_mock.basic_reject.assert_not_called() + channel_mock.basic_ack.assert_called_once_with(delivery_tag) diff --git a/uv.lock b/uv.lock index 0c78cab..45339cb 100644 --- a/uv.lock +++ b/uv.lock @@ -2074,6 +2074,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 }, ] +[[package]] +name = "types-psycopg2" +version = "2.9.21.20251012" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9b/b3/2d09eaf35a084cffd329c584970a3fa07101ca465c13cad1576d7c392587/types_psycopg2-2.9.21.20251012.tar.gz", hash = "sha256:4cdafd38927da0cfde49804f39ab85afd9c6e9c492800e42f1f0c1a1b0312935", size = 26710 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/0c/05feaf8cb51159f2c0af04b871dab7e98a2f83a3622f5f216331d2dd924c/types_psycopg2-2.9.21.20251012-py3-none-any.whl", hash = "sha256:712bad5c423fe979e357edbf40a07ca40ef775d74043de72bd4544ca328cc57e", size = 24883 }, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -2133,6 +2142,7 @@ dependencies = [ dev = [ { name = "pytest" }, { name = "stablehash" }, + { name = "types-psycopg2" }, ] [package.metadata] @@ -2146,6 +2156,7 @@ requires-dist = [ { name = "requests", specifier = "==2.32.3" }, { name = "snakemake", specifier = "==9.14.5" }, { name = "stablehash", marker = "extra == 'dev'", specifier = "==0.3.0" }, + { name = "types-psycopg2", marker = "extra == 'dev'" }, ] [[package]]