Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions pyhealth/calib/calibration/hb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,19 @@ def _nudge(matrix, delta):


def _bin_points(scores, bin_edges):
assert bin_edges is not None, "Bins have not been defined"
if bin_edges is None:
raise ValueError("Bins have not been defined")
scores = scores.squeeze()
assert np.size(scores.shape) < 2, "scores should be a 1D vector or singleton"
if np.size(scores.shape) >= 2:
raise ValueError(f"scores should be a 1D vector or singleton, got shape {scores.shape}")
scores = np.reshape(scores, (scores.size, 1))
bin_edges = np.reshape(bin_edges, (1, bin_edges.size))
return np.sum(scores > bin_edges, axis=1)


def _get_uniform_mass_bins(probs, n_bins):
assert probs.size >= n_bins, "Fewer points than bins"
if probs.size < n_bins:
raise ValueError(f"Fewer points than bins ({probs.size} < {n_bins})")

probs_sorted = np.sort(probs)

Expand Down Expand Up @@ -63,13 +66,19 @@ def __init__(self, n_bins=15):
self.fitted = False

def fit(self, y_score, y):
assert self.n_bins is not None, "Number of bins has to be specified"
if self.n_bins is None:
raise ValueError("Number of bins has to be specified")
y_score = y_score.squeeze()
y = y.squeeze()
assert y_score.size == y.size, "Check dimensions of input matrices"
assert (
y.size >= self.n_bins
), "Number of bins should be less than the number of calibration points"
if y_score.size != y.size:
raise ValueError(
f"y_score and y must have the same size, got {y_score.size} and {y.size}"
)
if y.size < self.n_bins:
raise ValueError(
f"Number of bins should be less than the number of calibration points "
f"({self.n_bins} bins > {y.size} points)"
)

### All required (hyper-)parameters have been passed correctly
### Uniform-mass binning/histogram binning code starts below
Expand Down Expand Up @@ -104,7 +113,8 @@ def fit(self, y_score, y):
return self

def predict_proba(self, y_score):
assert self.fitted is True, "Call HB_binary.fit() first"
if not self.fitted:
raise ValueError("Call HB_binary.fit() first")
y_score = y_score.squeeze()

# delta-randomization
Expand Down Expand Up @@ -222,7 +232,8 @@ def forward(self, normalization="sum", **kwargs) -> Dict[str, torch.Tensor]:
``loss``: Cross entropy loss with the new y_prob.
:rtype: Dict[str, torch.Tensor]
"""
assert normalization is None or normalization == "sum"
if normalization is not None and normalization != "sum":
raise ValueError(f"normalization must be None or 'sum', got {normalization!r}")
ret = self.model(**kwargs)
y_prob = ret["y_prob"].cpu().numpy()
for k in range(self.num_classes):
Expand Down
111 changes: 66 additions & 45 deletions pyhealth/datasets/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def split_by_visit(
`val_dataset.dataset`, and `test_dataset.dataset`.
"""
rng = np.random.default_rng(seed)
assert sum(ratios) == 1.0, "ratios must sum to 1.0"
if sum(ratios) != 1.0:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")
index = np.arange(len(dataset))
rng.shuffle(index)
train_index = index[: int(len(dataset) * ratios[0])]
Expand Down Expand Up @@ -160,7 +161,8 @@ def split_by_patient(
`val_dataset.dataset`, and `test_dataset.dataset`.
"""
rng = np.random.default_rng(seed)
assert sum(ratios) == 1.0, "ratios must sum to 1.0"
if sum(ratios) != 1.0:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")
patient_indx = list(dataset.patient_to_index.keys())
num_patients = len(patient_indx)
rng.shuffle(patient_indx)
Expand Down Expand Up @@ -202,7 +204,8 @@ def split_by_sample(
`val_dataset.dataset`, and `test_dataset.dataset`.
"""
rng = np.random.default_rng(seed)
assert sum(ratios) == 1.0, "ratios must sum to 1.0"
if sum(ratios) != 1.0:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")
index = np.arange(len(dataset))
rng.shuffle(index)
train_index = index[: int(len(dataset) * ratios[0])]
Expand Down Expand Up @@ -246,8 +249,10 @@ def split_by_visit_conformal(
`test_dataset.dataset`.
"""
rng = np.random.default_rng(seed)
assert len(ratios) == 4, "ratios must have 4 elements for train/val/cal/test"
assert sum(ratios) == 1.0, "ratios must sum to 1.0"
if len(ratios) != 4:
raise ValueError(f"ratios must have 4 elements for train/val/cal/test, got {len(ratios)}")
if sum(ratios) != 1.0:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")

index = np.arange(len(dataset))
rng.shuffle(index)
Expand Down Expand Up @@ -292,8 +297,10 @@ def split_by_patient_conformal(
`test_dataset.dataset`.
"""
rng = np.random.default_rng(seed)
assert len(ratios) == 4, "ratios must have 4 elements for train/val/cal/test"
assert sum(ratios) == 1.0, "ratios must sum to 1.0"
if len(ratios) != 4:
raise ValueError(f"ratios must have 4 elements for train/val/cal/test, got {len(ratios)}")
if sum(ratios) != 1.0:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")

patient_indx = list(dataset.patient_to_index.keys())
num_patients = len(patient_indx)
Expand Down Expand Up @@ -359,11 +366,13 @@ def split_by_patient_conformal_tuh(
Returns:
``(train_dataset, val_dataset, cal_dataset, test_dataset)``
"""
assert len(ratios) == 3, (
"ratios must have exactly 3 elements (train/val/cal). "
"The test set is determined by the TUH eval partition."
)
assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"
if len(ratios) != 3:
raise ValueError(
f"ratios must have exactly 3 elements (train/val/cal). "
f"The test set is determined by the TUH eval partition. Got {len(ratios)} elements."
)
if abs(sum(ratios) - 1.0) >= 1e-6:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")

# Bucket patients by partition using dataset.patient_to_index (fast path).
# TUH guarantees each patient is in exactly one partition, so inspecting the
Expand All @@ -373,10 +382,11 @@ def split_by_patient_conformal_tuh(

for pid, indices in dataset.patient_to_index.items():
first_sample = dataset[indices[0]]
assert "split" in first_sample, (
f"Patient {pid}: sample missing 'split' field. "
"Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
)
if "split" not in first_sample:
raise ValueError(
f"Patient {pid}: sample missing 'split' field. "
"Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
)
if first_sample["split"] == "train":
train_patient_to_indices[pid] = list(indices)
else:
Expand Down Expand Up @@ -434,18 +444,21 @@ def split_by_sample_conformal_tuh(
Returns:
train_dataset, val_dataset, cal_dataset, test_dataset
"""
assert len(ratios) == 3, (
"ratios must have exactly 3 elements (train/val/cal). "
"The test set is determined by the dataset's own eval partition."
)
assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"
if len(ratios) != 3:
raise ValueError(
f"ratios must have exactly 3 elements (train/val/cal). "
f"The test set is determined by the dataset's own eval partition. Got {len(ratios)} elements."
)
if abs(sum(ratios) - 1.0) >= 1e-6:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")

# verify every sample has the required "split" field
for i in range(len(dataset)):
assert "split" in dataset[i], (
f"Sample {i} is missing the 'split' field. "
"Make sure you used EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
)
if "split" not in dataset[i]:
raise ValueError(
f"Sample {i} is missing the 'split' field. "
"Make sure you used EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
)

train_pool: List[int] = []
test_list: List[int] = []
Expand Down Expand Up @@ -516,21 +529,24 @@ def split_by_patient_tuh(
Returns:
``(train_dataset, val_dataset, test_dataset)``
"""
assert len(ratios) == 2, (
"ratios must have exactly 2 elements (train/val). "
"The test set is determined by the TUH eval partition."
)
assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"
if len(ratios) != 2:
raise ValueError(
f"ratios must have exactly 2 elements (train/val). "
f"The test set is determined by the TUH eval partition. Got {len(ratios)} elements."
)
if abs(sum(ratios) - 1.0) >= 1e-6:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")

train_patient_to_indices: dict = {}
test_list: List[int] = []

for pid, indices in dataset.patient_to_index.items():
first_sample = dataset[indices[0]]
assert "split" in first_sample, (
f"Patient {pid}: sample missing 'split' field. "
"Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
)
if "split" not in first_sample:
raise ValueError(
f"Patient {pid}: sample missing 'split' field. "
"Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
)
if first_sample["split"] == "train":
train_patient_to_indices[pid] = list(indices)
else:
Expand Down Expand Up @@ -592,17 +608,20 @@ def split_by_sample_tuh(
Returns:
train_dataset, val_dataset, test_dataset
"""
assert len(ratios) == 2, (
"ratios must have exactly 2 elements (train/val). "
"The test set is determined by the dataset's own eval partition."
)
assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"
if len(ratios) != 2:
raise ValueError(
f"ratios must have exactly 2 elements (train/val). "
f"The test set is determined by the dataset's own eval partition. Got {len(ratios)} elements."
)
if abs(sum(ratios) - 1.0) >= 1e-6:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")

for i in range(len(dataset)):
assert "split" in dataset[i], (
f"Sample {i} is missing the 'split' field. "
"Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
)
if "split" not in dataset[i]:
raise ValueError(
f"Sample {i} is missing the 'split' field. "
"Make sure you used EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
)

train_pool: List[int] = []
test_list: List[int] = []
Expand Down Expand Up @@ -662,8 +681,10 @@ def split_by_sample_conformal(
`test_dataset.dataset`.
"""
rng = np.random.default_rng(seed)
assert len(ratios) == 4, "ratios must have 4 elements for train/val/cal/test"
assert sum(ratios) == 1.0, "ratios must sum to 1.0"
if len(ratios) != 4:
raise ValueError(f"ratios must have 4 elements for train/val/cal/test, got {len(ratios)}")
if sum(ratios) != 1.0:
raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)!r}")

index = np.arange(len(dataset))
rng.shuffle(index)
Expand Down
22 changes: 10 additions & 12 deletions pyhealth/models/stagenet_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,33 +149,31 @@ def cumax(self, x, mode="l2r"):
return x

def step(self, inputs, c_last, h_last, interval, device):
x_in = inputs.to(device=device)
x_in = inputs

# Integrate inter-visit time intervals
interval = interval.unsqueeze(-1).to(device=device)
x_out1 = self.kernel(torch.cat((x_in, interval), dim=-1)).to(device)
interval = interval.unsqueeze(-1)
x_out1 = self.kernel(torch.cat((x_in, interval), dim=-1))
x_out2 = self.recurrent_kernel(
torch.cat((h_last.to(device=device), interval), dim=-1)
torch.cat((h_last, interval), dim=-1)
)

if self.dropconnect:
x_out1 = self.nn_dropconnect(x_out1)
x_out2 = self.nn_dropconnect_r(x_out2)
x_out = x_out1 + x_out2
f_master_gate = self.cumax(x_out[:, : self.levels], "l2r")
f_master_gate = f_master_gate.unsqueeze(2).to(device=device)
f_master_gate = f_master_gate.unsqueeze(2)
i_master_gate = self.cumax(x_out[:, self.levels : self.levels * 2], "r2l")
i_master_gate = i_master_gate.unsqueeze(2)
x_out = x_out[:, self.levels * 2 :]
x_out = x_out.reshape(-1, self.levels * 4, self.chunk_size)
f_gate = self.sigmoid(x_out[:, : self.levels]).to(device=device)
i_gate = self.sigmoid(x_out[:, self.levels : self.levels * 2]).to(
device=device
)
f_gate = self.sigmoid(x_out[:, : self.levels])
i_gate = self.sigmoid(x_out[:, self.levels : self.levels * 2])
o_gate = self.sigmoid(x_out[:, self.levels * 2 : self.levels * 3])
c_in = self.tanh(x_out[:, self.levels * 3 :]).to(device=device)
c_last = c_last.reshape(-1, self.levels, self.chunk_size).to(device=device)
overlap = (f_master_gate * i_master_gate).to(device=device)
c_in = self.tanh(x_out[:, self.levels * 3 :])
c_last = c_last.reshape(-1, self.levels, self.chunk_size)
overlap = f_master_gate * i_master_gate
c_out = (
overlap * (f_gate * c_last + i_gate * c_in)
+ (f_master_gate - overlap) * c_last
Expand Down
3 changes: 2 additions & 1 deletion pyhealth/tasks/benchmark_ehrshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
def __call__(self, patient: Any) -> List[Dict[str, Any]]:
samples = []
split = patient.get_events("splits")
assert len(split) == 1, "Only one split is allowed"
if len(split) != 1:
raise ValueError(f"Only one split is allowed, got {len(split)}")
split = split[0].split
labels = patient.get_events(self.task)
for label in labels:
Expand Down
5 changes: 3 additions & 2 deletions pyhealth/tasks/covid19_cxr_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
- "disease": The disease classification label

Raises:
AssertionError: If the patient has more than one chest X-ray event.
ValueError: If the patient has more than one chest X-ray event.
"""
event = patient.get_events(event_type="covid19_cxr")
# There should be only one event
assert len(event) == 1
if len(event) != 1:
raise ValueError(f"Expected exactly 1 covid19_cxr event, got {len(event)}")
event = event[0]
image = event.path
disease = event.label
Expand Down
3 changes: 2 additions & 1 deletion pyhealth/tasks/in_hospital_mortality_mimic4.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
samples = []

demographics = patient.get_events(event_type="patients")
assert len(demographics) == 1
if len(demographics) != 1:
raise ValueError(f"Expected exactly 1 demographics record, got {len(demographics)}")
demographics = demographics[0]
anchor_age = int(demographics.anchor_age)
if anchor_age < 18:
Expand Down
3 changes: 2 additions & 1 deletion pyhealth/tasks/medical_transcriptions_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
"""
event = patient.get_events(event_type="mtsamples")
# There should be only one event
assert len(event) == 1
if len(event) != 1:
raise ValueError(f"Expected exactly 1 mtsamples event, got {len(event)}")
event = event[0]

transcription_valid = isinstance(event.transcription, str)
Expand Down
Loading
Loading