Skip to content

Commit 7d94a72

Browse files
authored
Merge pull request #34 from CausalInference/2026-02-devel
Make main estimate and bootstrapping deterministic for a seed
2 parents 11d07e5 + e15942c commit 7d94a72

9 files changed

Lines changed: 144 additions & 29 deletions

File tree

.github/workflows/python-app.yml

Lines changed: 16 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,3 @@
1-
# This workflow will install Python dependencies, run tests and lint with a single version of Python
2-
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
31
name: Test pySEQTarget
42

53
on:
@@ -13,39 +11,34 @@ permissions:
1311

1412
jobs:
1513
test:
16-
runs-on: ubuntu-latest
14+
runs-on: macos-26
1715
strategy:
1816
matrix:
1917
python-version: ["3.11", "3.12", "3.13", "3.14"]
20-
18+
2119
steps:
2220
- uses: actions/checkout@v6
23-
21+
22+
- name: Install uv
23+
uses: astral-sh/setup-uv@v7
24+
2425
- name: Set up Python ${{ matrix.python-version }}
25-
uses: actions/setup-python@v6
26-
with:
27-
python-version: ${{ matrix.python-version }}
28-
26+
run: uv python install ${{ matrix.python-version }}
27+
2928
- name: Install dependencies
3029
run: |
31-
python -m pip install --upgrade pip
32-
pip install flake8 pytest pytest-cov
33-
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
34-
35-
- name: Install pySEQTarget package
36-
run: |
37-
pip install -e .
38-
30+
uv venv --python ${{ matrix.python-version }}
31+
uv pip install flake8 pytest pytest-cov
32+
uv pip install -e .
33+
3934
- name: Lint with flake8
4035
run: |
41-
# stop the build if there are Python syntax errors or undefined names
42-
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
43-
# exit-zero treats all errors as warnings
44-
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
45-
36+
uv run flake8 . --exclude=.venv --count --select=E9,F63,F7,F82 --show-source --statistics
37+
uv run flake8 . --exclude=.venv --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
38+
4639
- name: Test with pytest
4740
run: |
48-
pytest tests/ -v --cov=pySEQTarget --cov-report=xml
41+
uv run pytest tests/ -v --cov=pySEQTarget --cov-report=xml
4942
5043
- name: Upload coverage reports to Codecov
5144
uses: codecov/codecov-action@v5

docs/conf.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,14 +12,14 @@
1212

1313
version = importlib.metadata.version("pySEQTarget")
1414
if not version:
15-
version = "0.12.1"
15+
version = "0.12.2"
1616
sys.path.insert(0, os.path.abspath("../"))
1717

1818
project = "pySEQTarget"
1919
copyright = (
20-
f"{date.today().year}, Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernan"
20+
f"{date.today().year}, Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernán"
2121
)
22-
author = "Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernan"
22+
author = "Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernán"
2323
release = version
2424

2525
# -- General configuration ---------------------------------------------------

pySEQTarget/SEQuential.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ def bootstrap(self, **kwargs) -> None:
183183
setattr(self, key, value)
184184
else:
185185
raise ValueError(f"Unknown argument: {key}")
186-
UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list()
186+
UIDs = self.DT.select(pl.col(self.id_col)).unique().sort(self.id_col).to_series().to_list()
187187
NIDs = len(UIDs)
188188

189189
self._boot_samples = []

pySEQTarget/analysis/_hazard.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@ def _calculate_hazard(self):
2424

2525

2626
def _calculate_hazard_single(self, data, idx=None, val=None):
27+
if self.seed is not None:
28+
self._rng = np.random.RandomState(self.seed)
2729
full_log_hr = _hazard_handler(self, data, idx, 0, self._rng)
2830

2931
if full_log_hr is None or np.isnan(full_log_hr):
@@ -33,6 +35,8 @@ def _calculate_hazard_single(self, data, idx=None, val=None):
3335
boot_log_hrs = []
3436

3537
for boot_idx in range(len(self._boot_samples)):
38+
if self.seed is not None:
39+
self._rng = np.random.RandomState(self.seed + boot_idx + 1)
3640
id_counts = self._boot_samples[boot_idx]
3741

3842
boot_data_list = []
@@ -83,6 +87,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
8387
data.select(keep_cols)
8488
.group_by([self.id_col, "trial"])
8589
.first()
90+
.sort([self.id_col, "trial"])
8691
.with_columns([pl.lit(list(range(self.followup_max + 1))).alias("followup")])
8792
.explode("followup")
8893
.with_columns(

pySEQTarget/analysis/_survival_pred.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ def _calculate_risk(self, data, idx=None, val=None):
5555
)
5656
.group_by("TID")
5757
.first()
58+
.sort("TID")
5859
.drop(["followup", f"followup{self.indicator_squared}"])
5960
.with_columns([pl.lit(followup_range).alias("followup")])
6061
.explode("followup")

pySEQTarget/expansion/_selection.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ def _random_selection(self):
2020
== self.treatment_level[0]
2121
)
2222
.unique("trialID")
23+
.sort("trialID")
2324
.get_column("trialID")
2425
.to_list()
2526
)

pySEQTarget/helpers/_bootstrap.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,10 @@ def wrapper(self, *args, **kwargs):
7777
results = []
7878
original_DT = self.DT
7979

80+
seed = getattr(self, "seed", None)
81+
if seed is not None:
82+
self._rng = np.random.RandomState(seed)
83+
8084
self._current_boot_idx = None
8185
full = method(self, *args, **kwargs)
8286
results.append(full)
@@ -127,6 +131,8 @@ def wrapper(self, *args, **kwargs):
127131

128132
for i in tqdm(range(nboot), desc="Bootstrapping..."):
129133
self._current_boot_idx = i + 1
134+
if seed is not None:
135+
self._rng = np.random.RandomState(seed + i)
130136
tmp = self._offloader.load_dataframe(original_DT_ref)
131137
self.DT = _prepare_boot_data(self, tmp, i)
132138
if self._offloader.enabled:

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "pySEQTarget"
7-
version = "0.12.1"
7+
version = "0.12.2"
88
description = "Sequentially Nested Target Trial Emulation"
99
readme = "README.md"
1010
license = {text = "MIT"}
@@ -56,7 +56,7 @@ Repository = "https://github.com/CausalInference/pySEQTarget"
5656
"Ryan O'Dea (ORCID)" = "https://orcid.org/0009-0000-0103-9546"
5757
"Alejandro Szmulewicz (ORCID)" = "https://orcid.org/0000-0002-2664-802X"
5858
"Tom Palmer (ORCID)" = "https://orcid.org/0000-0003-4655-4511"
59-
"Miguel Hernan (ORCID)" = "https://orcid.org/0000-0003-1619-8456"
59+
"Miguel Hernán (ORCID)" = "https://orcid.org/0000-0003-1619-8456"
6060
"University of Bristol (ROR)" = "https://ror.org/0524sp257"
6161
"Harvard University (ROR)" = "https://ror.org/03vek6s52"
6262

tests/test_reproducibility.py

Lines changed: 109 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,109 @@
1+
import os
2+
3+
import numpy as np
4+
import pytest
5+
6+
from pySEQTarget import SEQopts, SEQuential
7+
from pySEQTarget.data import load_data
8+
9+
10+
def _make_seq(seed, **extra_opts):
11+
data = load_data("SEQdata")
12+
return SEQuential(
13+
data,
14+
id_col="ID",
15+
time_col="time",
16+
eligible_col="eligible",
17+
treatment_col="tx_init",
18+
outcome_col="outcome",
19+
time_varying_cols=["N", "L", "P"],
20+
fixed_cols=["sex"],
21+
method="ITT",
22+
parameters=SEQopts(seed=seed, **extra_opts),
23+
)
24+
25+
26+
def test_hazard_reproducible_with_seed():
27+
results = []
28+
for _ in range(2):
29+
s = _make_seq(seed=42, hazard_estimate=True)
30+
s.expand()
31+
s.fit()
32+
s.hazard()
33+
results.append(s.hazard_ratio)
34+
35+
assert results[0]["Hazard ratio"][0] == results[1]["Hazard ratio"][0]
36+
37+
38+
def test_hazard_bootstrap_se_reproducible_with_seed():
39+
results = []
40+
for _ in range(2):
41+
s = _make_seq(seed=42, hazard_estimate=True, bootstrap_nboot=3)
42+
s.expand()
43+
s.bootstrap()
44+
s.fit()
45+
s.hazard()
46+
results.append(s.hazard_ratio)
47+
48+
assert results[0]["Hazard ratio"][0] == results[1]["Hazard ratio"][0]
49+
assert results[0]["LCI"][0] == results[1]["LCI"][0]
50+
assert results[0]["UCI"][0] == results[1]["UCI"][0]
51+
52+
53+
@pytest.mark.skipif(
54+
os.getenv("CI") == "true", reason="Bootstrap reproducibility test hangs in CI"
55+
)
56+
def test_hazard_bootstrap_percentile_reproducible_with_seed():
57+
results = []
58+
for _ in range(2):
59+
s = _make_seq(
60+
seed=42,
61+
hazard_estimate=True,
62+
bootstrap_nboot=3,
63+
bootstrap_CI_method="percentile",
64+
)
65+
s.expand()
66+
s.bootstrap()
67+
s.fit()
68+
s.hazard()
69+
results.append(s.hazard_ratio)
70+
71+
assert results[0]["Hazard ratio"][0] == results[1]["Hazard ratio"][0]
72+
assert results[0]["LCI"][0] == results[1]["LCI"][0]
73+
assert results[0]["UCI"][0] == results[1]["UCI"][0]
74+
75+
76+
@pytest.mark.skipif(
77+
os.getenv("CI") == "true", reason="Reproducibility test hangs in CI"
78+
)
79+
def test_survival_reproducible_with_seed():
80+
results = []
81+
for _ in range(2):
82+
s = _make_seq(seed=42, km_curves=True)
83+
s.expand()
84+
s.fit()
85+
s.survival()
86+
results.append(s.km_data)
87+
88+
np.testing.assert_allclose(
89+
results[0]["pred"].to_numpy(), results[1]["pred"].to_numpy(), atol=1e-14
90+
)
91+
92+
93+
@pytest.mark.skipif(
94+
os.getenv("CI") == "true", reason="Bootstrap reproducibility test hangs in CI"
95+
)
96+
def test_survival_bootstrap_reproducible_with_seed():
97+
results = []
98+
for _ in range(2):
99+
s = _make_seq(seed=42, km_curves=True, bootstrap_nboot=3)
100+
s.expand()
101+
s.bootstrap()
102+
s.fit()
103+
s.survival()
104+
results.append(s.km_data)
105+
106+
for col in ["pred", "SE", "LCI", "UCI"]:
107+
np.testing.assert_allclose(
108+
results[0][col].to_numpy(), results[1][col].to_numpy(), atol=1e-14
109+
)

0 commit comments

Comments
 (0)