diff --git a/pyhealth/calib/calibration/hb.py b/pyhealth/calib/calibration/hb.py index 8dcdcb1b1..6928a1076 100644 --- a/pyhealth/calib/calibration/hb.py +++ b/pyhealth/calib/calibration/hb.py @@ -24,16 +24,19 @@ def _nudge(matrix, delta): def _bin_points(scores, bin_edges): - assert bin_edges is not None, "Bins have not been defined" + if bin_edges is None: + raise ValueError("Bins have not been defined") scores = scores.squeeze() - assert np.size(scores.shape) < 2, "scores should be a 1D vector or singleton" + if np.size(scores.shape) >= 2: + raise ValueError(f"scores should be a 1D vector or singleton, got shape {scores.shape}") scores = np.reshape(scores, (scores.size, 1)) bin_edges = np.reshape(bin_edges, (1, bin_edges.size)) return np.sum(scores > bin_edges, axis=1) def _get_uniform_mass_bins(probs, n_bins): - assert probs.size >= n_bins, "Fewer points than bins" + if probs.size < n_bins: + raise ValueError(f"Fewer points than bins ({probs.size} < {n_bins})") probs_sorted = np.sort(probs) @@ -63,13 +66,19 @@ def __init__(self, n_bins=15): self.fitted = False def fit(self, y_score, y): - assert self.n_bins is not None, "Number of bins has to be specified" + if self.n_bins is None: + raise ValueError("Number of bins has to be specified") y_score = y_score.squeeze() y = y.squeeze() - assert y_score.size == y.size, "Check dimensions of input matrices" - assert ( - y.size >= self.n_bins - ), "Number of bins should be less than the number of calibration points" + if y_score.size != y.size: + raise ValueError( + f"y_score and y must have the same size, got {y_score.size} and {y.size}" + ) + if y.size < self.n_bins: + raise ValueError( + f"Number of bins should be less than the number of calibration points " + f"({self.n_bins} bins > {y.size} points)" + ) ### All required (hyper-)parameters have been passed correctly ### Uniform-mass binning/histogram binning code starts below @@ -104,7 +113,8 @@ def fit(self, y_score, y): return self def predict_proba(self, y_score): - assert self.fitted is True, "Call HB_binary.fit() first" + if not self.fitted: + raise ValueError("Call HB_binary.fit() first") y_score = y_score.squeeze() # delta-randomization @@ -222,7 +232,8 @@ def forward(self, normalization="sum", **kwargs) -> Dict[str, torch.Tensor]: ``loss``: Cross entropy loss with the new y_prob. :rtype: Dict[str, torch.Tensor] """ - assert normalization is None or normalization == "sum" + if normalization is not None and normalization != "sum": + raise ValueError(f"normalization must be None or 'sum', got {normalization!r}") ret = self.model(**kwargs) y_prob = ret["y_prob"].cpu().numpy() for k in range(self.num_classes): diff --git a/pyhealth/datasets/splitter.py b/pyhealth/datasets/splitter.py index 2dbc94186..94f43c818 100644 --- a/pyhealth/datasets/splitter.py +++ b/pyhealth/datasets/splitter.py @@ -125,7 +125,8 @@ def split_by_visit( `val_dataset.dataset`, and `test_dataset.dataset`. """ rng = np.random.default_rng(seed) - assert sum(ratios) == 1.0, "ratios must sum to 1.0" + if sum(ratios) != 1.0: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") index = np.arange(len(dataset)) rng.shuffle(index) train_index = index[: int(len(dataset) * ratios[0])] @@ -160,7 +161,8 @@ def split_by_patient( `val_dataset.dataset`, and `test_dataset.dataset`. """ rng = np.random.default_rng(seed) - assert sum(ratios) == 1.0, "ratios must sum to 1.0" + if sum(ratios) != 1.0: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") patient_indx = list(dataset.patient_to_index.keys()) num_patients = len(patient_indx) rng.shuffle(patient_indx) @@ -202,7 +204,8 @@ def split_by_sample( `val_dataset.dataset`, and `test_dataset.dataset`. """ rng = np.random.default_rng(seed) - assert sum(ratios) == 1.0, "ratios must sum to 1.0" + if sum(ratios) != 1.0: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") index = np.arange(len(dataset)) rng.shuffle(index) train_index = index[: int(len(dataset) * ratios[0])] @@ -246,8 +249,10 @@ def split_by_visit_conformal( `test_dataset.dataset`. """ rng = np.random.default_rng(seed) - assert len(ratios) == 4, "ratios must have 4 elements for train/val/cal/test" - assert sum(ratios) == 1.0, "ratios must sum to 1.0" + if len(ratios) != 4: + raise ValueError(f"ratios must have 4 elements for train/val/cal/test, got {len(ratios)}") + if sum(ratios) != 1.0: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") index = np.arange(len(dataset)) rng.shuffle(index) @@ -292,8 +297,10 @@ def split_by_patient_conformal( `test_dataset.dataset`. """ rng = np.random.default_rng(seed) - assert len(ratios) == 4, "ratios must have 4 elements for train/val/cal/test" - assert sum(ratios) == 1.0, "ratios must sum to 1.0" + if len(ratios) != 4: + raise ValueError(f"ratios must have 4 elements for train/val/cal/test, got {len(ratios)}") + if sum(ratios) != 1.0: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") patient_indx = list(dataset.patient_to_index.keys()) num_patients = len(patient_indx) @@ -359,11 +366,13 @@ def split_by_patient_conformal_tuh( Returns: ``(train_dataset, val_dataset, cal_dataset, test_dataset)`` """ - assert len(ratios) == 3, ( - "ratios must have exactly 3 elements (train/val/cal). " - "The test set is determined by the TUH eval partition." - ) - assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0" + if len(ratios) != 3: + raise ValueError( + f"ratios must have exactly 3 elements (train/val/cal). " + f"The test set is determined by the TUH eval partition. Got {len(ratios)} elements." + ) + if abs(sum(ratios) - 1.0) >= 1e-6: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") # Bucket patients by partition using dataset.patient_to_index (fast path). # TUH guarantees each patient is in exactly one partition, so inspecting the @@ -373,10 +382,11 @@ def split_by_patient_conformal_tuh( for pid, indices in dataset.patient_to_index.items(): first_sample = dataset[indices[0]] - assert "split" in first_sample, ( - f"Patient {pid}: sample missing 'split' field. " - "Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset." - ) + if "split" not in first_sample: + raise ValueError( + f"Patient {pid}: sample missing 'split' field. " + "Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset." + ) if first_sample["split"] == "train": train_patient_to_indices[pid] = list(indices) else: @@ -434,18 +444,21 @@ def split_by_sample_conformal_tuh( Returns: train_dataset, val_dataset, cal_dataset, test_dataset """ - assert len(ratios) == 3, ( - "ratios must have exactly 3 elements (train/val/cal). " - "The test set is determined by the dataset's own eval partition." - ) - assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0" + if len(ratios) != 3: + raise ValueError( + f"ratios must have exactly 3 elements (train/val/cal). " + f"The test set is determined by the dataset's own eval partition. Got {len(ratios)} elements." + ) + if abs(sum(ratios) - 1.0) >= 1e-6: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") # verify every sample has the required "split" field for i in range(len(dataset)): - assert "split" in dataset[i], ( - f"Sample {i} is missing the 'split' field. " - "Make sure you used EEGEventsTUEV or EEGAbnormalTUAB to build the dataset." - ) + if "split" not in dataset[i]: + raise ValueError( + f"Sample {i} is missing the 'split' field. " + "Make sure you used EEGEventsTUEV or EEGAbnormalTUAB to build the dataset." + ) train_pool: List[int] = [] test_list: List[int] = [] @@ -516,21 +529,24 @@ def split_by_patient_tuh( Returns: ``(train_dataset, val_dataset, test_dataset)`` """ - assert len(ratios) == 2, ( - "ratios must have exactly 2 elements (train/val). " - "The test set is determined by the TUH eval partition." - ) - assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0" + if len(ratios) != 2: + raise ValueError( + f"ratios must have exactly 2 elements (train/val). " + f"The test set is determined by the TUH eval partition. Got {len(ratios)} elements." + ) + if abs(sum(ratios) - 1.0) >= 1e-6: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") train_patient_to_indices: dict = {} test_list: List[int] = [] for pid, indices in dataset.patient_to_index.items(): first_sample = dataset[indices[0]] - assert "split" in first_sample, ( - f"Patient {pid}: sample missing 'split' field. " - "Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset." - ) + if "split" not in first_sample: + raise ValueError( + f"Patient {pid}: sample missing 'split' field. " + "Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset." + ) if first_sample["split"] == "train": train_patient_to_indices[pid] = list(indices) else: @@ -592,17 +608,20 @@ def split_by_sample_tuh( Returns: train_dataset, val_dataset, test_dataset """ - assert len(ratios) == 2, ( - "ratios must have exactly 2 elements (train/val). " - "The test set is determined by the dataset's own eval partition." - ) - assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0" + if len(ratios) != 2: + raise ValueError( + f"ratios must have exactly 2 elements (train/val). " + f"The test set is determined by the dataset's own eval partition. Got {len(ratios)} elements." + ) + if abs(sum(ratios) - 1.0) >= 1e-6: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") for i in range(len(dataset)): - assert "split" in dataset[i], ( - f"Sample {i} is missing the 'split' field. " - "Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset." - ) + if "split" not in dataset[i]: + raise ValueError( + f"Sample {i} is missing the 'split' field. " + "Make sure you used EEGEventsTUEV or EEGAbnormalTUAB to build the dataset." + ) train_pool: List[int] = [] test_list: List[int] = [] @@ -662,8 +681,10 @@ def split_by_sample_conformal( `test_dataset.dataset`. """ rng = np.random.default_rng(seed) - assert len(ratios) == 4, "ratios must have 4 elements for train/val/cal/test" - assert sum(ratios) == 1.0, "ratios must sum to 1.0" + if len(ratios) != 4: + raise ValueError(f"ratios must have 4 elements for train/val/cal/test, got {len(ratios)}") + if sum(ratios) != 1.0: + raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}") index = np.arange(len(dataset)) rng.shuffle(index) diff --git a/pyhealth/models/stagenet_mha.py b/pyhealth/models/stagenet_mha.py index 637f2fdc5..72f2dcd1c 100644 --- a/pyhealth/models/stagenet_mha.py +++ b/pyhealth/models/stagenet_mha.py @@ -149,13 +149,13 @@ def cumax(self, x, mode="l2r"): return x def step(self, inputs, c_last, h_last, interval, device): - x_in = inputs.to(device=device) + x_in = inputs # Integrate inter-visit time intervals - interval = interval.unsqueeze(-1).to(device=device) - x_out1 = self.kernel(torch.cat((x_in, interval), dim=-1)).to(device) + interval = interval.unsqueeze(-1) + x_out1 = self.kernel(torch.cat((x_in, interval), dim=-1)) x_out2 = self.recurrent_kernel( - torch.cat((h_last.to(device=device), interval), dim=-1) + torch.cat((h_last, interval), dim=-1) ) if self.dropconnect: @@ -163,19 +163,17 @@ def step(self, inputs, c_last, h_last, interval, device): x_out2 = self.nn_dropconnect_r(x_out2) x_out = x_out1 + x_out2 f_master_gate = self.cumax(x_out[:, : self.levels], "l2r") - f_master_gate = f_master_gate.unsqueeze(2).to(device=device) + f_master_gate = f_master_gate.unsqueeze(2) i_master_gate = self.cumax(x_out[:, self.levels : self.levels * 2], "r2l") i_master_gate = i_master_gate.unsqueeze(2) x_out = x_out[:, self.levels * 2 :] x_out = x_out.reshape(-1, self.levels * 4, self.chunk_size) - f_gate = self.sigmoid(x_out[:, : self.levels]).to(device=device) - i_gate = self.sigmoid(x_out[:, self.levels : self.levels * 2]).to( - device=device - ) + f_gate = self.sigmoid(x_out[:, : self.levels]) + i_gate = self.sigmoid(x_out[:, self.levels : self.levels * 2]) o_gate = self.sigmoid(x_out[:, self.levels * 2 : self.levels * 3]) - c_in = self.tanh(x_out[:, self.levels * 3 :]).to(device=device) - c_last = c_last.reshape(-1, self.levels, self.chunk_size).to(device=device) - overlap = (f_master_gate * i_master_gate).to(device=device) + c_in = self.tanh(x_out[:, self.levels * 3 :]) + c_last = c_last.reshape(-1, self.levels, self.chunk_size) + overlap = f_master_gate * i_master_gate c_out = ( overlap * (f_gate * c_last + i_gate * c_in) + (f_master_gate - overlap) * c_last diff --git a/pyhealth/tasks/benchmark_ehrshot.py b/pyhealth/tasks/benchmark_ehrshot.py index 1d72ec2f1..1a8b4375b 100644 --- a/pyhealth/tasks/benchmark_ehrshot.py +++ b/pyhealth/tasks/benchmark_ehrshot.py @@ -71,7 +71,8 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: def __call__(self, patient: Any) -> List[Dict[str, Any]]: samples = [] split = patient.get_events("splits") - assert len(split) == 1, "Only one split is allowed" + if len(split) != 1: + raise ValueError(f"Only one split is allowed, got {len(split)}") split = split[0].split labels = patient.get_events(self.task) for label in labels: diff --git a/pyhealth/tasks/covid19_cxr_classification.py b/pyhealth/tasks/covid19_cxr_classification.py index 0a7f526f5..ad6a97c23 100644 --- a/pyhealth/tasks/covid19_cxr_classification.py +++ b/pyhealth/tasks/covid19_cxr_classification.py @@ -42,11 +42,12 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: - "disease": The disease classification label Raises: - AssertionError: If the patient has more than one chest X-ray event. + ValueError: If the patient has more than one chest X-ray event. """ event = patient.get_events(event_type="covid19_cxr") # There should be only one event - assert len(event) == 1 + if len(event) != 1: + raise ValueError(f"Expected exactly 1 covid19_cxr event, got {len(event)}") event = event[0] image = event.path disease = event.label diff --git a/pyhealth/tasks/in_hospital_mortality_mimic4.py b/pyhealth/tasks/in_hospital_mortality_mimic4.py index 334239714..d573bacdd 100644 --- a/pyhealth/tasks/in_hospital_mortality_mimic4.py +++ b/pyhealth/tasks/in_hospital_mortality_mimic4.py @@ -63,7 +63,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: samples = [] demographics = patient.get_events(event_type="patients") - assert len(demographics) == 1 + if len(demographics) != 1: + raise ValueError(f"Expected exactly 1 demographics record, got {len(demographics)}") demographics = demographics[0] anchor_age = int(demographics.anchor_age) if anchor_age < 18: diff --git a/pyhealth/tasks/medical_transcriptions_classification.py b/pyhealth/tasks/medical_transcriptions_classification.py index 37a84fa21..436509880 100644 --- a/pyhealth/tasks/medical_transcriptions_classification.py +++ b/pyhealth/tasks/medical_transcriptions_classification.py @@ -44,7 +44,8 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]: """ event = patient.get_events(event_type="mtsamples") # There should be only one event - assert len(event) == 1 + if len(event) != 1: + raise ValueError(f"Expected exactly 1 mtsamples event, got {len(event)}") event = event[0] transcription_valid = isinstance(event.transcription, str) diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index 88ecc0109..d48ce60c0 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -71,7 +71,10 @@ def __call__(self, patient: Patient) -> List[Dict]: ValueError: If any `str` to `datetime` conversions fail. """ patients: List[Event] = patient.get_events(event_type="patients") - assert len(patients) == 1 + if len(patients) != 1: + raise ValueError( + f"Expected exactly 1 patient record, got {len(patients)}" + ) if self.exclude_minors: try: @@ -203,10 +206,13 @@ def __call__(self, patient: Patient) -> List[Dict]: Raises: ValueError: If any `str` to `datetime` conversions fail. - AssertionError: If any icd_version value in the diagnoses_icd or procedures_icd tables is not "9" or "10" + ValueError: If any icd_version value in the diagnoses_icd or procedures_icd tables is not "9" or "10" """ patients: List[Event] = patient.get_events(event_type="patients") - assert len(patients) == 1 + if len(patients) != 1: + raise ValueError( + f"Expected exactly 1 patient record, got {len(patients)}" + ) if self.exclude_minors and int(patients[0]["anchor_age"]) < 18: return [] @@ -224,7 +230,10 @@ def __call__(self, patient: Patient) -> List[Dict]: for event in patient.get_events( event_type="diagnoses_icd", filters=[filter] ): - assert event.icd_version in ("9", "10") + if event.icd_version not in ("9", "10"): + raise ValueError( + f"Unexpected icd_version {event.icd_version!r}, expected '9' or '10'" + ) diagnoses.append(f"{event.icd_version}_{event.icd_code}") if len(diagnoses) == 0: continue @@ -233,7 +242,10 @@ def __call__(self, patient: Patient) -> List[Dict]: for event in patient.get_events( event_type="procedures_icd", filters=[filter] ): - assert event.icd_version in ("9", "10") + if event.icd_version not in ("9", "10"): + raise ValueError( + f"Unexpected icd_version {event.icd_version!r}, expected '9' or '10'" + ) procedures.append(f"{event.icd_version}_{event.icd_code}") if len(procedures) == 0: continue @@ -483,7 +495,10 @@ def __call__(self, patient: Patient) -> List[Dict]: - 'readmission': binary label. """ patients: List[Event] = patient.get_events(event_type="person") - assert len(patients) == 1 + if len(patients) != 1: + raise ValueError( + f"Expected exactly 1 person record, got {len(patients)}" + ) if self.exclude_minors: year = int(patients[0].year_of_birth) diff --git a/pyhealth/tasks/sleep_staging.py b/pyhealth/tasks/sleep_staging.py index 911b672c8..0ac86ab74 100644 --- a/pyhealth/tasks/sleep_staging.py +++ b/pyhealth/tasks/sleep_staging.py @@ -42,7 +42,10 @@ def sleep_staging_isruc_fn(record, epoch_seconds=10, label_id=1): } """ SAMPLE_RATE = 200 - assert 30 % epoch_seconds == 0, "ISRUC is annotated every 30 seconds." + if 30 % epoch_seconds != 0: + raise ValueError( + f"epoch_seconds must be a factor of 30 (ISRUC is annotated every 30 seconds), got {epoch_seconds!r}" + ) _channels = [ "F3", "F4", @@ -63,11 +66,11 @@ def _find_channels(potential_channels): .replace("-A1", "") ) if new_c in _channels: - assert new_c not in keep, f"Unrecognized channels: {potential_channels}" + if new_c in keep: + raise ValueError(f"Duplicate channel mapping for {new_c!r}: {potential_channels}") keep[new_c] = c - assert len(keep) == len( - _channels - ), f"Unrecognized channels: {potential_channels}" + if len(keep) != len(_channels): + raise ValueError(f"Unrecognized channels: {potential_channels}") return {v: k for k, v in keep.items()} record = record[0] @@ -89,7 +92,8 @@ def _find_channels(potential_channels): header=None, )[0] ann = ann.map(["W", "N1", "N2", "N3", "Unknown", "R"].__getitem__) - assert "Unknown" not in ann.values, "bad annotations" + if "Unknown" in ann.values: + raise ValueError("bad annotations: found 'Unknown' stage labels") samples = [] sample_length = SAMPLE_RATE * epoch_seconds for i, epoch_label in enumerate(np.repeat(ann.values, 30 // epoch_seconds)):