Skip to content

Commit 5b4f6f7

Browse files
authored
Merge pull request #32 from CausalInference/suggestions
Some optimizations
2 parents 4054533 + 400971c commit 5b4f6f7

9 files changed

Lines changed: 225 additions & 205 deletions

File tree

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
version = importlib.metadata.version("pySEQTarget")
1414
if not version:
15-
version = "0.11.0"
15+
version = "0.12.0"
1616
sys.path.insert(0, os.path.abspath("../"))
1717

1818
project = "pySEQTarget"

pySEQTarget/analysis/_risk_estimates.py

Lines changed: 89 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,105 @@
22
from scipy import stats
33

44

5+
def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None):
    """
    Compute Risk Difference and Risk Ratio from a comparison dataframe.

    Consolidates the repeated calculation logic: the drop/column-ordering
    steps previously duplicated across the bootstrap and non-bootstrap
    branches are now shared, and the RD/RR expressions are built once.

    Parameters
    ----------
    comp : pl.DataFrame
        Pairwise comparison frame with ``risk_x``/``risk_y`` columns and,
        when ``has_bootstrap`` is True, ``se_x``/``se_y`` columns.
    has_bootstrap : bool
        Whether bootstrap standard errors are available; if so, Wald-style
        confidence bounds are added.
    z : float, optional
        Normal quantile for the confidence level; required when
        ``has_bootstrap`` is True.
    group_cols : list[str], optional
        Subgroup columns to keep in front of the output columns.

    Returns
    -------
    tuple[pl.DataFrame, pl.DataFrame]
        ``(rd_comp, rr_comp)`` with intermediate risk/SE columns removed.
    """
    if group_cols is None:
        group_cols = []

    # Base effect-measure expressions, shared by both branches.
    rd = pl.col("risk_x") - pl.col("risk_y")
    rr = pl.col("risk_x") / pl.col("risk_y")

    rd_exprs = [rd.alias("Risk Difference")]
    rr_exprs = [rr.alias("Risk Ratio")]
    rd_labels = ["Risk Difference"]
    rr_labels = ["Risk Ratio"]
    drop_cols = ["risk_x", "risk_y"]

    if has_bootstrap:
        # NOTE(review): column labels say "95%" but `z` is derived from the
        # caller's bootstrap_CI, which may not be 0.95 — confirm intent.
        rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
        rd_exprs += [
            (rd - z * rd_se).alias("RD 95% LCI"),
            (rd + z * rd_se).alias("RD 95% UCI"),
        ]
        rd_labels += ["RD 95% LCI", "RD 95% UCI"]

        # Delta-method SE on the log scale for the risk ratio.
        rr_log_se = (
            (pl.col("se_x") / pl.col("risk_x")).pow(2)
            + (pl.col("se_y") / pl.col("risk_y")).pow(2)
        ).sqrt()
        rr_exprs += [
            (rr * (-z * rr_log_se).exp()).alias("RR 95% LCI"),
            (rr * (z * rr_log_se).exp()).alias("RR 95% UCI"),
        ]
        rr_labels += ["RR 95% LCI", "RR 95% UCI"]
        drop_cols += ["se_x", "se_y"]

    def _finalize(df, labels):
        # Drop intermediate columns and impose a stable output column order,
        # keeping only columns that actually exist in the frame.
        out = df.drop(drop_cols)
        col_order = group_cols + ["A_x", "A_y"] + labels
        return out.select([c for c in col_order if c in out.columns])

    rd_comp = _finalize(comp.with_columns(rd_exprs), rd_labels)
    rr_comp = _finalize(comp.with_columns(rr_exprs), rr_labels)

    return rd_comp, rr_comp
66+
67+
568
def _risk_estimates(self):
669
last_followup = self.km_data["followup"].max()
770
risk = self.km_data.filter(
871
(pl.col("followup") == last_followup) & (pl.col("estimate") == "risk")
972
)
1073

1174
group_cols = [self.subgroup_colname] if self.subgroup_colname else []
12-
rd_comparisons = []
13-
rr_comparisons = []
75+
has_bootstrap = self.bootstrap_nboot > 0
1476

15-
if self.bootstrap_nboot > 0:
77+
if has_bootstrap:
1678
alpha = 1 - self.bootstrap_CI
1779
z = stats.norm.ppf(1 - alpha / 2)
80+
else:
81+
z = None
82+
83+
# Pre-extract data for each treatment level once (avoid repeated filtering)
84+
risk_by_level = {}
85+
for tx in self.treatment_level:
86+
level_data = risk.filter(pl.col(self.treatment_col) == tx)
87+
risk_by_level[tx] = {
88+
"pred": level_data.select(group_cols + ["pred"]),
89+
}
90+
if has_bootstrap:
91+
risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"])
92+
93+
rd_comparisons = []
94+
rr_comparisons = []
1895

1996
for tx_x in self.treatment_level:
2097
for tx_y in self.treatment_level:
2198
if tx_x == tx_y:
2299
continue
23100

24-
risk_x = (
25-
risk.filter(pl.col(self.treatment_col) == tx_x)
26-
.select(group_cols + ["pred"])
27-
.rename({"pred": "risk_x"})
28-
)
29-
30-
risk_y = (
31-
risk.filter(pl.col(self.treatment_col) == tx_y)
32-
.select(group_cols + ["pred"])
33-
.rename({"pred": "risk_y"})
34-
)
101+
# Use pre-extracted data instead of filtering again
102+
risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"})
103+
risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"})
35104

36105
if group_cols:
37106
comp = risk_x.join(risk_y, on=group_cols, how="left")
@@ -42,18 +111,9 @@ def _risk_estimates(self):
42111
[pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
43112
)
44113

45-
if self.bootstrap_nboot > 0:
46-
se_x = (
47-
risk.filter(pl.col(self.treatment_col) == tx_x)
48-
.select(group_cols + ["SE"])
49-
.rename({"SE": "se_x"})
50-
)
51-
52-
se_y = (
53-
risk.filter(pl.col(self.treatment_col) == tx_y)
54-
.select(group_cols + ["SE"])
55-
.rename({"SE": "se_y"})
56-
)
114+
if has_bootstrap:
115+
se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"})
116+
se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"})
57117

58118
if group_cols:
59119
comp = comp.join(se_x, on=group_cols, how="left")
@@ -62,73 +122,9 @@ def _risk_estimates(self):
62122
comp = comp.join(se_x, how="cross")
63123
comp = comp.join(se_y, how="cross")
64124

65-
rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
66-
rd_comp = comp.with_columns(
67-
[
68-
(pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference"),
69-
(pl.col("risk_x") - pl.col("risk_y") - z * rd_se).alias(
70-
"RD 95% LCI"
71-
),
72-
(pl.col("risk_x") - pl.col("risk_y") + z * rd_se).alias(
73-
"RD 95% UCI"
74-
),
75-
]
76-
)
77-
rd_comp = rd_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
78-
col_order = group_cols + [
79-
"A_x",
80-
"A_y",
81-
"Risk Difference",
82-
"RD 95% LCI",
83-
"RD 95% UCI",
84-
]
85-
rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
86-
rd_comparisons.append(rd_comp)
87-
88-
rr_log_se = (
89-
(pl.col("se_x") / pl.col("risk_x")).pow(2)
90-
+ (pl.col("se_y") / pl.col("risk_y")).pow(2)
91-
).sqrt()
92-
rr_comp = comp.with_columns(
93-
[
94-
(pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio"),
95-
(
96-
(pl.col("risk_x") / pl.col("risk_y"))
97-
* (-z * rr_log_se).exp()
98-
).alias("RR 95% LCI"),
99-
(
100-
(pl.col("risk_x") / pl.col("risk_y"))
101-
* (z * rr_log_se).exp()
102-
).alias("RR 95% UCI"),
103-
]
104-
)
105-
rr_comp = rr_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
106-
col_order = group_cols + [
107-
"A_x",
108-
"A_y",
109-
"Risk Ratio",
110-
"RR 95% LCI",
111-
"RR 95% UCI",
112-
]
113-
rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
114-
rr_comparisons.append(rr_comp)
115-
116-
else:
117-
rd_comp = comp.with_columns(
118-
(pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference")
119-
)
120-
rd_comp = rd_comp.drop(["risk_x", "risk_y"])
121-
col_order = group_cols + ["A_x", "A_y", "Risk Difference"]
122-
rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
123-
rd_comparisons.append(rd_comp)
124-
125-
rr_comp = comp.with_columns(
126-
(pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio")
127-
)
128-
rr_comp = rr_comp.drop(["risk_x", "risk_y"])
129-
col_order = group_cols + ["A_x", "A_y", "Risk Ratio"]
130-
rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
131-
rr_comparisons.append(rr_comp)
125+
rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols)
126+
rd_comparisons.append(rd_comp)
127+
rr_comparisons.append(rr_comp)
132128

133129
risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame()
134130
risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame()

pySEQTarget/analysis/_survival_pred.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,20 @@ def _calculate_risk(self, data, idx=None, val=None):
4646
lci = a / 2
4747
uci = 1 - lci
4848

49+
# Pre-compute the followup range once (starts at 1, not 0)
50+
followup_range = list(range(1, self.followup_max + 1))
51+
4952
SDT = (
5053
data.with_columns(
51-
[
52-
(
53-
pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)
54-
).alias("TID")
55-
]
54+
[pl.concat_str([pl.col(self.id_col), pl.col("trial")]).alias("TID")]
5655
)
5756
.group_by("TID")
5857
.first()
5958
.drop(["followup", f"followup{self.indicator_squared}"])
60-
.with_columns([pl.lit(list(range(self.followup_max))).alias("followup")])
59+
.with_columns([pl.lit(followup_range).alias("followup")])
6160
.explode("followup")
6261
.with_columns(
63-
[
64-
(pl.col("followup") + 1).alias("followup"),
65-
(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
66-
]
62+
[(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
6763
)
6864
).sort([self.id_col, "trial", "followup"])
6965

pySEQTarget/expansion/_mapper.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,10 @@ def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.in
1313
.with_columns([pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")])
1414
.with_columns(
1515
[
16-
pl.struct(
17-
[
18-
pl.col(time_col),
19-
pl.col(time_col).max().over(id_col).alias("max_time"),
20-
]
21-
)
22-
.map_elements(
23-
lambda x: list(range(x[time_col], x["max_time"] + 1)),
24-
return_dtype=pl.List(pl.Int64),
25-
)
26-
.alias("period")
16+
pl.int_ranges(
17+
pl.col(time_col),
18+
pl.col(time_col).max().over(id_col) + 1,
19+
).alias("period")
2720
]
2821
)
2922
.explode("period")

pySEQTarget/helpers/_bootstrap.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ def _prepare_boot_data(self, data, boot_id):
3535

3636

3737
def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs):
38-
obj = copy.deepcopy(obj)
38+
# Shallow copy the object and only deep copy mutable state that changes per-bootstrap
39+
obj = copy.copy(obj)
40+
# Deep copy only the mutable attributes that get modified during fitting
41+
obj.outcome_model = []
42+
obj.numerator_model = copy.copy(obj.numerator_model) if hasattr(obj, 'numerator_model') and obj.numerator_model else []
43+
obj.denominator_model = copy.copy(obj.denominator_model) if hasattr(obj, 'denominator_model') and obj.denominator_model else []
44+
3945
obj._rng = (
4046
np.random.RandomState(seed + i) if seed is not None else np.random.RandomState()
4147
)
@@ -104,13 +110,19 @@ def wrapper(self, *args, **kwargs):
104110
self._rng = original_rng
105111
self.DT = self._offloader.load_dataframe(original_DT_ref)
106112
else:
107-
original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
108-
del original_DT
113+
# Keep original data in memory if offloading is disabled to avoid unnecessary I/O
114+
if self._offloader.enabled:
115+
original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
116+
del original_DT
117+
else:
118+
original_DT_ref = original_DT
119+
109120
for i in tqdm(range(nboot), desc="Bootstrapping..."):
110121
self._current_boot_idx = i + 1
111122
tmp = self._offloader.load_dataframe(original_DT_ref)
112123
self.DT = _prepare_boot_data(self, tmp, i)
113-
del tmp
124+
if self._offloader.enabled:
125+
del tmp
114126
self.bootstrap_nboot = 0
115127
boot_fit = method(self, *args, **kwargs)
116128
results.append(boot_fit)

pySEQTarget/helpers/_offloader.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import lru_cache
12
from pathlib import Path
23
from typing import Any, Optional, Union
34

@@ -12,6 +13,25 @@ def __init__(self, enabled: bool, dir: str, compression: int = 3):
1213
self.enabled = enabled
1314
self.dir = Path(dir)
1415
self.compression = compression
16+
# Create a cached loader bound to this instance
17+
self._init_cache()
18+
19+
def _init_cache(self):
    """Build the per-instance LRU-cached wrapper around ``_load_from_disk``.

    The wrapper is stored on the instance rather than applied as a method
    decorator, so each instance keeps its own bounded cache.
    """
    cached = lru_cache(maxsize=32)
    self._cached_load = cached(self._load_from_disk)
22+
23+
def __getstate__(self):
    """Prepare state for pickling - exclude the unpicklable cache.

    Returns a copy of ``__dict__`` without ``_cached_load``: the lru_cache
    wrapper holds a bound method and cannot be pickled.
    """
    state = self.__dict__.copy()
    # pop() with a default instead of del: stays robust if the cache
    # attribute was never created (e.g. partially constructed instance).
    state.pop("_cached_load", None)
    return state
29+
30+
def __setstate__(self, state):
31+
"""Restore state after unpickling - recreate the cache."""
32+
self.__dict__.update(state)
33+
# Recreate the cache after unpickling
34+
self._init_cache()
1535

1636
def save_model(
1737
self, model: Any, name: str, boot_idx: Optional[int] = None
@@ -29,11 +49,20 @@ def save_model(
2949

3050
return str(filepath)
3151

52+
def _load_from_disk(self, filepath: str) -> Any:
    """Deserialize and return the object stored at ``filepath``.

    This is the uncached primitive; ``_init_cache`` wraps it in an
    ``lru_cache`` so repeated loads of the same path are served from memory.
    """
    model = joblib.load(filepath)
    return model
55+
3256
def load_model(self, ref: Union[Any, str]) -> Any:
    """Load a model, using cache for repeated loads of the same file.

    ``ref`` is either an in-memory model (returned unchanged) or, when
    offloading is enabled, a filepath string previously produced by
    ``save_model``.
    """
    if self.enabled and isinstance(ref, str):
        # String refs are on-disk paths; route through the LRU cache.
        return self._cached_load(ref)
    return ref
62+
63+
def clear_cache(self) -> None:
    """Drop every entry from the model-loading cache.

    Call between bootstrap iterations if previously loaded models must
    not be reused.
    """
    self._cached_load.cache_clear()
3766

3867
def save_dataframe(self, df: pl.DataFrame, name: str) -> Union[pl.DataFrame, str]:
3968
if not self.enabled:

0 commit comments

Comments
 (0)