22from scipy import stats
33
44
5+ def _compute_rd_rr (comp , has_bootstrap , z = None , group_cols = None ):
6+ """
7+ Compute Risk Difference and Risk Ratio from a comparison dataframe.
8+ Consolidates the repeated calculation logic.
9+ """
10+ if group_cols is None :
11+ group_cols = []
12+
13+ if has_bootstrap :
14+ rd_se = (pl .col ("se_x" ).pow (2 ) + pl .col ("se_y" ).pow (2 )).sqrt ()
15+ rd_comp = comp .with_columns (
16+ [
17+ (pl .col ("risk_x" ) - pl .col ("risk_y" )).alias ("Risk Difference" ),
18+ (pl .col ("risk_x" ) - pl .col ("risk_y" ) - z * rd_se ).alias ("RD 95% LCI" ),
19+ (pl .col ("risk_x" ) - pl .col ("risk_y" ) + z * rd_se ).alias ("RD 95% UCI" ),
20+ ]
21+ )
22+ rd_comp = rd_comp .drop (["risk_x" , "risk_y" , "se_x" , "se_y" ])
23+ col_order = group_cols + [
24+ "A_x" ,
25+ "A_y" ,
26+ "Risk Difference" ,
27+ "RD 95% LCI" ,
28+ "RD 95% UCI" ,
29+ ]
30+ rd_comp = rd_comp .select ([c for c in col_order if c in rd_comp .columns ])
31+
32+ rr_log_se = (
33+ (pl .col ("se_x" ) / pl .col ("risk_x" )).pow (2 )
34+ + (pl .col ("se_y" ) / pl .col ("risk_y" )).pow (2 )
35+ ).sqrt ()
36+ rr_comp = comp .with_columns (
37+ [
38+ (pl .col ("risk_x" ) / pl .col ("risk_y" )).alias ("Risk Ratio" ),
39+ (
40+ (pl .col ("risk_x" ) / pl .col ("risk_y" )) * (- z * rr_log_se ).exp ()
41+ ).alias ("RR 95% LCI" ),
42+ (
43+ (pl .col ("risk_x" ) / pl .col ("risk_y" )) * (z * rr_log_se ).exp ()
44+ ).alias ("RR 95% UCI" ),
45+ ]
46+ )
47+ rr_comp = rr_comp .drop (["risk_x" , "risk_y" , "se_x" , "se_y" ])
48+ col_order = group_cols + ["A_x" , "A_y" , "Risk Ratio" , "RR 95% LCI" , "RR 95% UCI" ]
49+ rr_comp = rr_comp .select ([c for c in col_order if c in rr_comp .columns ])
50+ else :
51+ rd_comp = comp .with_columns (
52+ (pl .col ("risk_x" ) - pl .col ("risk_y" )).alias ("Risk Difference" )
53+ )
54+ rd_comp = rd_comp .drop (["risk_x" , "risk_y" ])
55+ col_order = group_cols + ["A_x" , "A_y" , "Risk Difference" ]
56+ rd_comp = rd_comp .select ([c for c in col_order if c in rd_comp .columns ])
57+
58+ rr_comp = comp .with_columns (
59+ (pl .col ("risk_x" ) / pl .col ("risk_y" )).alias ("Risk Ratio" )
60+ )
61+ rr_comp = rr_comp .drop (["risk_x" , "risk_y" ])
62+ col_order = group_cols + ["A_x" , "A_y" , "Risk Ratio" ]
63+ rr_comp = rr_comp .select ([c for c in col_order if c in rr_comp .columns ])
64+
65+ return rd_comp , rr_comp
66+
67+
568def _risk_estimates (self ):
669 last_followup = self .km_data ["followup" ].max ()
770 risk = self .km_data .filter (
871 (pl .col ("followup" ) == last_followup ) & (pl .col ("estimate" ) == "risk" )
972 )
1073
1174 group_cols = [self .subgroup_colname ] if self .subgroup_colname else []
12- rd_comparisons = []
13- rr_comparisons = []
75+ has_bootstrap = self .bootstrap_nboot > 0
1476
15- if self . bootstrap_nboot > 0 :
77+ if has_bootstrap :
1678 alpha = 1 - self .bootstrap_CI
1779 z = stats .norm .ppf (1 - alpha / 2 )
80+ else :
81+ z = None
82+
83+ # Pre-extract data for each treatment level once (avoid repeated filtering)
84+ risk_by_level = {}
85+ for tx in self .treatment_level :
86+ level_data = risk .filter (pl .col (self .treatment_col ) == tx )
87+ risk_by_level [tx ] = {
88+ "pred" : level_data .select (group_cols + ["pred" ]),
89+ }
90+ if has_bootstrap :
91+ risk_by_level [tx ]["SE" ] = level_data .select (group_cols + ["SE" ])
92+
93+ rd_comparisons = []
94+ rr_comparisons = []
1895
1996 for tx_x in self .treatment_level :
2097 for tx_y in self .treatment_level :
2198 if tx_x == tx_y :
2299 continue
23100
24- risk_x = (
25- risk .filter (pl .col (self .treatment_col ) == tx_x )
26- .select (group_cols + ["pred" ])
27- .rename ({"pred" : "risk_x" })
28- )
29-
30- risk_y = (
31- risk .filter (pl .col (self .treatment_col ) == tx_y )
32- .select (group_cols + ["pred" ])
33- .rename ({"pred" : "risk_y" })
34- )
101+ # Use pre-extracted data instead of filtering again
102+ risk_x = risk_by_level [tx_x ]["pred" ].rename ({"pred" : "risk_x" })
103+ risk_y = risk_by_level [tx_y ]["pred" ].rename ({"pred" : "risk_y" })
35104
36105 if group_cols :
37106 comp = risk_x .join (risk_y , on = group_cols , how = "left" )
@@ -42,18 +111,9 @@ def _risk_estimates(self):
42111 [pl .lit (tx_x ).alias ("A_x" ), pl .lit (tx_y ).alias ("A_y" )]
43112 )
44113
45- if self .bootstrap_nboot > 0 :
46- se_x = (
47- risk .filter (pl .col (self .treatment_col ) == tx_x )
48- .select (group_cols + ["SE" ])
49- .rename ({"SE" : "se_x" })
50- )
51-
52- se_y = (
53- risk .filter (pl .col (self .treatment_col ) == tx_y )
54- .select (group_cols + ["SE" ])
55- .rename ({"SE" : "se_y" })
56- )
114+ if has_bootstrap :
115+ se_x = risk_by_level [tx_x ]["SE" ].rename ({"SE" : "se_x" })
116+ se_y = risk_by_level [tx_y ]["SE" ].rename ({"SE" : "se_y" })
57117
58118 if group_cols :
59119 comp = comp .join (se_x , on = group_cols , how = "left" )
@@ -62,73 +122,9 @@ def _risk_estimates(self):
62122 comp = comp .join (se_x , how = "cross" )
63123 comp = comp .join (se_y , how = "cross" )
64124
65- rd_se = (pl .col ("se_x" ).pow (2 ) + pl .col ("se_y" ).pow (2 )).sqrt ()
66- rd_comp = comp .with_columns (
67- [
68- (pl .col ("risk_x" ) - pl .col ("risk_y" )).alias ("Risk Difference" ),
69- (pl .col ("risk_x" ) - pl .col ("risk_y" ) - z * rd_se ).alias (
70- "RD 95% LCI"
71- ),
72- (pl .col ("risk_x" ) - pl .col ("risk_y" ) + z * rd_se ).alias (
73- "RD 95% UCI"
74- ),
75- ]
76- )
77- rd_comp = rd_comp .drop (["risk_x" , "risk_y" , "se_x" , "se_y" ])
78- col_order = group_cols + [
79- "A_x" ,
80- "A_y" ,
81- "Risk Difference" ,
82- "RD 95% LCI" ,
83- "RD 95% UCI" ,
84- ]
85- rd_comp = rd_comp .select ([c for c in col_order if c in rd_comp .columns ])
86- rd_comparisons .append (rd_comp )
87-
88- rr_log_se = (
89- (pl .col ("se_x" ) / pl .col ("risk_x" )).pow (2 )
90- + (pl .col ("se_y" ) / pl .col ("risk_y" )).pow (2 )
91- ).sqrt ()
92- rr_comp = comp .with_columns (
93- [
94- (pl .col ("risk_x" ) / pl .col ("risk_y" )).alias ("Risk Ratio" ),
95- (
96- (pl .col ("risk_x" ) / pl .col ("risk_y" ))
97- * (- z * rr_log_se ).exp ()
98- ).alias ("RR 95% LCI" ),
99- (
100- (pl .col ("risk_x" ) / pl .col ("risk_y" ))
101- * (z * rr_log_se ).exp ()
102- ).alias ("RR 95% UCI" ),
103- ]
104- )
105- rr_comp = rr_comp .drop (["risk_x" , "risk_y" , "se_x" , "se_y" ])
106- col_order = group_cols + [
107- "A_x" ,
108- "A_y" ,
109- "Risk Ratio" ,
110- "RR 95% LCI" ,
111- "RR 95% UCI" ,
112- ]
113- rr_comp = rr_comp .select ([c for c in col_order if c in rr_comp .columns ])
114- rr_comparisons .append (rr_comp )
115-
116- else :
117- rd_comp = comp .with_columns (
118- (pl .col ("risk_x" ) - pl .col ("risk_y" )).alias ("Risk Difference" )
119- )
120- rd_comp = rd_comp .drop (["risk_x" , "risk_y" ])
121- col_order = group_cols + ["A_x" , "A_y" , "Risk Difference" ]
122- rd_comp = rd_comp .select ([c for c in col_order if c in rd_comp .columns ])
123- rd_comparisons .append (rd_comp )
124-
125- rr_comp = comp .with_columns (
126- (pl .col ("risk_x" ) / pl .col ("risk_y" )).alias ("Risk Ratio" )
127- )
128- rr_comp = rr_comp .drop (["risk_x" , "risk_y" ])
129- col_order = group_cols + ["A_x" , "A_y" , "Risk Ratio" ]
130- rr_comp = rr_comp .select ([c for c in col_order if c in rr_comp .columns ])
131- rr_comparisons .append (rr_comp )
125+ rd_comp , rr_comp = _compute_rd_rr (comp , has_bootstrap , z , group_cols )
126+ rd_comparisons .append (rd_comp )
127+ rr_comparisons .append (rr_comp )
132128
133129 risk_difference = pl .concat (rd_comparisons ) if rd_comparisons else pl .DataFrame ()
134130 risk_ratio = pl .concat (rr_comparisons ) if rr_comparisons else pl .DataFrame ()
0 commit comments