diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index b219d00..72c0145 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -1,5 +1,3 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python name: Test pySEQTarget on: @@ -13,39 +11,34 @@ permissions: jobs: test: - runs-on: ubuntu-latest + runs-on: macos-26 strategy: matrix: python-version: ["3.11", "3.12", "3.13", "3.14"] - + steps: - uses: actions/checkout@v6 - + + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v6 - with: - python-version: ${{ matrix.python-version }} - + run: uv python install ${{ matrix.python-version }} + - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install flake8 pytest pytest-cov - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - - name: Install pySEQTarget package - run: | - pip install -e . - + uv venv --python ${{ matrix.python-version }} + uv pip install flake8 pytest pytest-cov + uv pip install -e . + - name: Lint with flake8 run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - + uv run flake8 . --exclude=.venv --count --select=E9,F63,F7,F82 --show-source --statistics + uv run flake8 . --exclude=.venv --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest run: | - pytest tests/ -v --cov=pySEQTarget --cov-report=xml + uv run pytest tests/ -v --cov=pySEQTarget --cov-report=xml - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 diff --git a/docs/conf.py b/docs/conf.py index 8f384a7..c366c95 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,14 +12,14 @@ version = importlib.metadata.version("pySEQTarget") if not version: - version = "0.12.1" + version = "0.12.2" sys.path.insert(0, os.path.abspath("../")) project = "pySEQTarget" copyright = ( - f"{date.today().year}, Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernan" + f"{date.today().year}, Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernán" ) -author = "Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernan" +author = "Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernán" release = version # -- General configuration --------------------------------------------------- diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 8d8602d..fdf1602 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -183,7 +183,7 @@ def bootstrap(self, **kwargs) -> None: setattr(self, key, value) else: raise ValueError(f"Unknown argument: {key}") - UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list() + UIDs = self.DT.select(pl.col(self.id_col)).unique().sort(self.id_col).to_series().to_list() NIDs = len(UIDs) self._boot_samples = [] diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 3240447..4f39a6c 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -24,6 +24,8 @@ def _calculate_hazard(self): def _calculate_hazard_single(self, data, idx=None, val=None): + if self.seed is not None: + self._rng = np.random.RandomState(self.seed) full_log_hr = _hazard_handler(self, data, idx, 0, self._rng) if full_log_hr is None or np.isnan(full_log_hr): @@ -33,6 +35,8 @@ def _calculate_hazard_single(self, data, idx=None, val=None): boot_log_hrs = [] for boot_idx in range(len(self._boot_samples)): + if self.seed is not None: + self._rng = np.random.RandomState(self.seed + boot_idx + 1) id_counts = self._boot_samples[boot_idx] boot_data_list = [] @@ -83,6 +87,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng): data.select(keep_cols) .group_by([self.id_col, "trial"]) .first() + .sort([self.id_col, "trial"]) .with_columns([pl.lit(list(range(self.followup_max + 1))).alias("followup")]) .explode("followup") .with_columns( diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index b314ed4..6e6579c 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -55,6 +55,7 @@ def _calculate_risk(self, data, idx=None, val=None): ) .group_by("TID") .first() + .sort("TID") .drop(["followup", f"followup{self.indicator_squared}"]) .with_columns([pl.lit(followup_range).alias("followup")]) .explode("followup") diff --git a/pySEQTarget/expansion/_selection.py b/pySEQTarget/expansion/_selection.py index c7a03d0..63b7361 100644 --- a/pySEQTarget/expansion/_selection.py +++ b/pySEQTarget/expansion/_selection.py @@ -20,6 +20,7 @@ def _random_selection(self): == self.treatment_level[0] ) .unique("trialID") + .sort("trialID") .get_column("trialID") .to_list() ) diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 08becbc..b828176 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -77,6 +77,10 @@ def wrapper(self, *args, **kwargs): results = [] original_DT = self.DT + seed = getattr(self, "seed", None) + if seed is not None: + self._rng = np.random.RandomState(seed) + self._current_boot_idx = None full = method(self, *args, **kwargs) results.append(full) @@ -127,6 +131,8 @@ def wrapper(self, *args, **kwargs): for i in tqdm(range(nboot), desc="Bootstrapping..."): self._current_boot_idx = i + 1 + if seed is not None: + self._rng = np.random.RandomState(seed + i) tmp = self._offloader.load_dataframe(original_DT_ref) self.DT = _prepare_boot_data(self, tmp, i) if self._offloader.enabled: diff --git a/pyproject.toml b/pyproject.toml index 2c35aed..99330f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.12.1" +version = "0.12.2" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} @@ -56,7 +56,7 @@ Repository = "https://github.com/CausalInference/pySEQTarget" "Ryan O'Dea (ORCID)" = "https://orcid.org/0009-0000-0103-9546" "Alejandro Szmulewicz (ORCID)" = "https://orcid.org/0000-0002-2664-802X" "Tom Palmer (ORCID)" = "https://orcid.org/0000-0003-4655-4511" -"Miguel Hernan (ORCID)" = "https://orcid.org/0000-0003-1619-8456" +"Miguel Hernán (ORCID)" = "https://orcid.org/0000-0003-1619-8456" "University of Bristol (ROR)" = "https://ror.org/0524sp257" "Harvard University (ROR)" = "https://ror.org/03vek6s52" diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py new file mode 100644 index 0000000..4f0969e --- /dev/null +++ b/tests/test_reproducibility.py @@ -0,0 +1,109 @@ +import os + +import numpy as np +import pytest + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _make_seq(seed, **extra_opts): + data = load_data("SEQdata") + return SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(seed=seed, **extra_opts), + ) + + +def test_hazard_reproducible_with_seed(): + results = [] + for _ in range(2): + s = _make_seq(seed=42, hazard_estimate=True) + s.expand() + s.fit() + s.hazard() + results.append(s.hazard_ratio) + + assert results[0]["Hazard ratio"][0] == results[1]["Hazard ratio"][0] + + +def test_hazard_bootstrap_se_reproducible_with_seed(): + results = [] + for _ in range(2): + s = _make_seq(seed=42, hazard_estimate=True, bootstrap_nboot=3) + s.expand() + s.bootstrap() + s.fit() + s.hazard() + results.append(s.hazard_ratio) + + assert results[0]["Hazard ratio"][0] == results[1]["Hazard ratio"][0] + assert results[0]["LCI"][0] == results[1]["LCI"][0] + assert results[0]["UCI"][0] == results[1]["UCI"][0] + + +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Bootstrap reproducibility test hangs in CI" +) +def test_hazard_bootstrap_percentile_reproducible_with_seed(): + results = [] + for _ in range(2): + s = _make_seq( + seed=42, + hazard_estimate=True, + bootstrap_nboot=3, + bootstrap_CI_method="percentile", + ) + s.expand() + s.bootstrap() + s.fit() + s.hazard() + results.append(s.hazard_ratio) + + assert results[0]["Hazard ratio"][0] == results[1]["Hazard ratio"][0] + assert results[0]["LCI"][0] == results[1]["LCI"][0] + assert results[0]["UCI"][0] == results[1]["UCI"][0] + + +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Reproducibility test hangs in CI" +) +def test_survival_reproducible_with_seed(): + results = [] + for _ in range(2): + s = _make_seq(seed=42, km_curves=True) + s.expand() + s.fit() + s.survival() + results.append(s.km_data) + + np.testing.assert_allclose( + results[0]["pred"].to_numpy(), results[1]["pred"].to_numpy(), atol=1e-14 + ) + + +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Bootstrap reproducibility test hangs in CI" +) +def test_survival_bootstrap_reproducible_with_seed(): + results = [] + for _ in range(2): + s = _make_seq(seed=42, km_curves=True, bootstrap_nboot=3) + s.expand() + s.bootstrap() + s.fit() + s.survival() + results.append(s.km_data) + + for col in ["pred", "SE", "LCI", "UCI"]: + np.testing.assert_allclose( + results[0][col].to_numpy(), results[1][col].to_numpy(), atol=1e-14 + )