Skip to content

Commit 29de8a0

Browse files
authored
Allow fit methods to accept pd.Series and pd.DataFrame (#62) (#92)
* Allow fit methods to accept pd.Series and pd.DataFrame (#62)
* Add dev dependencies for CI
* Fix type hints
* Add a test for utils
* Fix docstrings of fit() methods
1 parent 32e0873 commit 29de8a0

7 files changed

Lines changed: 1216 additions & 987 deletions

File tree

dte_adj/local.py

Lines changed: 27 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,12 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
from typing import Tuple
35
from dte_adj.stratified import (
46
SimpleStratifiedDistributionEstimator,
57
AdjustedStratifiedDistributionEstimator,
68
)
7-
from dte_adj.util import compute_ldte, compute_lpte
9+
from dte_adj.util import ArrayLike, compute_ldte, compute_lpte, _convert_to_ndarray
810

911

1012
class SimpleLocalDistributionEstimator(SimpleStratifiedDistributionEstimator):
@@ -28,25 +30,26 @@ def __init__(self):
2830

2931
def fit(
3032
self,
31-
covariates: np.ndarray,
32-
treatment_arms: np.ndarray,
33-
treatment_indicator: np.ndarray,
34-
outcomes: np.ndarray,
35-
strata: np.ndarray,
36-
) -> "SimpleLocalDistributionEstimator":
33+
covariates: ArrayLike,
34+
treatment_arms: ArrayLike,
35+
treatment_indicator: ArrayLike,
36+
outcomes: ArrayLike,
37+
strata: ArrayLike,
38+
) -> SimpleLocalDistributionEstimator:
3739
"""
3840
Train the SimpleLocalDistributionEstimator.
3941
4042
Args:
41-
covariates (np.ndarray): Pre-treatment covariates.
42-
treatment_arms (np.ndarray): Treatment assignment variable (Z).
43-
treatment_indicator (np.ndarray): Treatment indicator variable (D).
44-
outcomes (np.ndarray): Scalar-valued observed outcome.
45-
strata (np.ndarray): Stratum indicators.
43+
covariates: Pre-treatment covariates.
44+
treatment_arms: Treatment assignment variable (Z).
45+
treatment_indicator: Treatment indicator variable (D).
46+
outcomes: Scalar-valued observed outcome.
47+
strata: Stratum indicators.
4648
4749
Returns:
4850
SimpleLocalDistributionEstimator: The fitted estimator.
4951
"""
52+
treatment_indicator = _convert_to_ndarray(treatment_indicator)
5053
super().fit(covariates, treatment_arms, outcomes, strata)
5154
self.treatment_indicator = treatment_indicator
5255

@@ -196,25 +199,26 @@ class AdjustedLocalDistributionEstimator(AdjustedStratifiedDistributionEstimator
196199

197200
def fit(
198201
self,
199-
covariates: np.ndarray,
200-
treatment_arms: np.ndarray,
201-
treatment_indicator: np.ndarray,
202-
outcomes: np.ndarray,
203-
strata: np.ndarray,
204-
) -> "AdjustedLocalDistributionEstimator":
202+
covariates: ArrayLike,
203+
treatment_arms: ArrayLike,
204+
treatment_indicator: ArrayLike,
205+
outcomes: ArrayLike,
206+
strata: ArrayLike,
207+
) -> AdjustedLocalDistributionEstimator:
205208
"""
206209
Train the AdjustedLocalDistributionEstimator.
207210
208211
Args:
209-
covariates (np.ndarray): Pre-treatment covariates.
210-
treatment_arms (np.ndarray): Treatment assignment variable (Z).
211-
treatment_indicator (np.ndarray): Treatment indicator variable (D).
212-
outcomes (np.ndarray): Scalar-valued observed outcome.
213-
strata (np.ndarray): Stratum indicators.
212+
covariates: Pre-treatment covariates.
213+
treatment_arms: Treatment assignment variable (Z).
214+
treatment_indicator: Treatment indicator variable (D).
215+
outcomes: Scalar-valued observed outcome.
216+
strata: Stratum indicators.
214217
215218
Returns:
216219
AdjustedLocalDistributionEstimator: The fitted estimator.
217220
"""
221+
treatment_indicator = _convert_to_ndarray(treatment_indicator)
218222
super().fit(covariates, treatment_arms, outcomes, strata)
219223
self.treatment_indicator = treatment_indicator
220224

dte_adj/simple.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
from dte_adj.stratified import (
35
SimpleStratifiedDistributionEstimator,
46
AdjustedStratifiedDistributionEstimator,
57
)
8+
from dte_adj.util import ArrayLike, _convert_to_ndarray
69

710

811
class SimpleDistributionEstimator(SimpleStratifiedDistributionEstimator):
@@ -45,19 +48,23 @@ def __init__(self):
4548
super().__init__()
4649

4750
def fit(
48-
self, covariates: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
49-
) -> "SimpleDistributionEstimator":
51+
self, covariates: ArrayLike, treatment_arms: ArrayLike, outcomes: ArrayLike
52+
) -> SimpleDistributionEstimator:
5053
"""
5154
Set parameters.
5255
5356
Args:
54-
covariates (np.ndarray): Pre-treatment covariates.
55-
treatment_arms (np.ndarray): The index of the treatment arm.
56-
outcomes (np.ndarray): Scalar-valued observed outcome.
57+
covariates: Pre-treatment covariates.
58+
treatment_arms: The index of the treatment arm.
59+
outcomes: Scalar-valued observed outcome.
5760
5861
Returns:
5962
SimpleDistributionEstimator: The fitted estimator.
6063
"""
64+
covariates = _convert_to_ndarray(covariates)
65+
treatment_arms = _convert_to_ndarray(treatment_arms)
66+
outcomes = _convert_to_ndarray(outcomes)
67+
6168
if covariates.shape[0] != treatment_arms.shape[0]:
6269
raise ValueError("The shape of covariates and treatment_arm should be same")
6370

@@ -105,19 +112,23 @@ class AdjustedDistributionEstimator(AdjustedStratifiedDistributionEstimator):
105112
"""
106113

107114
def fit(
108-
self, covariates: np.ndarray, treatment_arms: np.ndarray, outcomes: np.ndarray
109-
) -> "AdjustedDistributionEstimator":
115+
self, covariates: ArrayLike, treatment_arms: ArrayLike, outcomes: ArrayLike
116+
) -> AdjustedDistributionEstimator:
110117
"""
111118
Set parameters.
112119
113120
Args:
114-
covariates (np.ndarray): Pre-treatment covariates.
115-
treatment_arms (np.ndarray): The index of the treatment arm.
116-
outcomes (np.ndarray): Scalar-valued observed outcome.
121+
covariates: Pre-treatment covariates.
122+
treatment_arms: The index of the treatment arm.
123+
outcomes: Scalar-valued observed outcome.
117124
118125
Returns:
119126
AdjustedDistributionEstimator: The fitted estimator.
120127
"""
128+
covariates = _convert_to_ndarray(covariates)
129+
treatment_arms = _convert_to_ndarray(treatment_arms)
130+
outcomes = _convert_to_ndarray(outcomes)
131+
121132
if covariates.shape[0] != treatment_arms.shape[0]:
122133
raise ValueError("The shape of covariates and treatment_arm should be same")
123134

dte_adj/stratified.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,39 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
from typing import Tuple, Any
35
from copy import deepcopy
46
from dte_adj.base import DistributionEstimatorBase
7+
from dte_adj.util import ArrayLike, _convert_to_ndarray
58

69

710
class SimpleStratifiedDistributionEstimator(DistributionEstimatorBase):
811
"""A class is for estimating the empirical distribution function and computing the Distributional parameters for CAR."""
912

1013
def fit(
1114
self,
12-
covariates: np.ndarray,
13-
treatment_arms: np.ndarray,
14-
outcomes: np.ndarray,
15-
strata: np.ndarray,
16-
) -> "DistributionEstimatorBase":
15+
covariates: ArrayLike,
16+
treatment_arms: ArrayLike,
17+
outcomes: ArrayLike,
18+
strata: ArrayLike,
19+
) -> DistributionEstimatorBase:
1720
"""
1821
Train the DistributionEstimatorBase.
1922
2023
Args:
21-
covariates (np.ndarray): Pre-treatment covariates.
22-
treatment_arms (np.ndarray): The index of the treatment arm.
23-
outcomes (np.ndarray): Scalar-valued observed outcome.
24+
covariates: Pre-treatment covariates.
25+
treatment_arms: The index of the treatment arm.
26+
outcomes: Scalar-valued observed outcome.
27+
strata: Stratum indicators.
2428
2529
Returns:
2630
DistributionEstimatorBase: The fitted estimator.
2731
"""
32+
covariates = _convert_to_ndarray(covariates)
33+
treatment_arms = _convert_to_ndarray(treatment_arms)
34+
outcomes = _convert_to_ndarray(outcomes)
35+
strata = _convert_to_ndarray(strata)
36+
2837
if covariates.shape[0] != treatment_arms.shape[0]:
2938
raise ValueError("The shape of covariates and treatment_arm should be same")
3039

@@ -168,22 +177,28 @@ def __init__(self, base_model: Any, folds=3, is_multi_task=False):
168177

169178
def fit(
170179
self,
171-
covariates: np.ndarray,
172-
treatment_arms: np.ndarray,
173-
outcomes: np.ndarray,
174-
strata: np.ndarray,
175-
) -> "DistributionEstimatorBase":
180+
covariates: ArrayLike,
181+
treatment_arms: ArrayLike,
182+
outcomes: ArrayLike,
183+
strata: ArrayLike,
184+
) -> DistributionEstimatorBase:
176185
"""
177186
Train the DistributionEstimatorBase.
178187
179188
Args:
180-
covariates (np.ndarray): Pre-treatment covariates.
181-
treatment_arms (np.ndarray): The index of the treatment arm.
182-
outcomes (np.ndarray): Scalar-valued observed outcome.
189+
covariates: Pre-treatment covariates.
190+
treatment_arms: The index of the treatment arm.
191+
outcomes: Scalar-valued observed outcome.
192+
strata: Stratum indicators.
183193
184194
Returns:
185195
DistributionEstimatorBase: The fitted estimator.
186196
"""
197+
covariates = _convert_to_ndarray(covariates)
198+
treatment_arms = _convert_to_ndarray(treatment_arms)
199+
outcomes = _convert_to_ndarray(outcomes)
200+
strata = _convert_to_ndarray(strata)
201+
187202
if covariates.shape[0] != treatment_arms.shape[0]:
188203
raise ValueError("The shape of covariates and treatment_arm should be same")
189204

dte_adj/util.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,36 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
from scipy.stats import norm
3-
from typing import Tuple, TYPE_CHECKING
5+
from typing import Tuple, Union, TYPE_CHECKING
46

57
if TYPE_CHECKING:
8+
import pandas as pd
9+
import polars as pl
10+
611
from dte_adj.local import (
712
SimpleStratifiedDistributionEstimator,
813
AdjustedLocalDistributionEstimator,
914
)
1015

16+
ArrayLike = Union[
17+
np.ndarray,
18+
list,
19+
tuple,
20+
"pd.DataFrame",
21+
"pd.Series",
22+
"pl.DataFrame",
23+
"pl.Series",
24+
]
25+
26+
def _convert_to_ndarray(data: ArrayLike) -> np.ndarray:
27+
"""Convert array-like data to np.ndarray if needed."""
28+
if isinstance(data, np.ndarray):
29+
return data
30+
if hasattr(data, "to_numpy"):
31+
return data.to_numpy()
32+
return np.asarray(data)
33+
1134

1235
def compute_confidence_intervals(
1336
vec_y: np.ndarray,

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ dev = [
3030
"ruff>=0.12.2,<0.16.0",
3131
"sphinx>=7.3.7,<8.2.0",
3232
"scikit-learn>=1.5,<1.9",
33-
"pre-commit>=4.0.1,<4.6.0"
33+
"pre-commit>=4.0.1,<4.6.0",
34+
"pandas>=2.0",
35+
"polars>=1.0"
3436
]
3537

3638
[tool.setuptools.packages.find]
@@ -47,7 +49,9 @@ dev-dependencies = [
4749
"ruff>=0.12.2,<0.16.0",
4850
"sphinx>=7.3.7,<8.2.0",
4951
"scikit-learn>=1.5,<1.9",
50-
"pre-commit>=4.0.1,<4.6.0"
52+
"pre-commit>=4.0.1,<4.6.0",
53+
"pandas>=2.0",
54+
"polars>=1.0"
5155
]
5256

5357
[tool.ruff.lint]

tests/test_utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import unittest
2+
import numpy as np
3+
import pandas as pd
4+
import polars as pl
5+
from dte_adj.util import _convert_to_ndarray
6+
7+
8+
class TestConvertToNdarray(unittest.TestCase):
9+
"""Test that _convert_to_ndarray correctly converts various array-like inputs."""
10+
11+
def test_ndarray(self):
12+
data = np.array([1, 2, 3])
13+
result = _convert_to_ndarray(data)
14+
self.assertIsInstance(result, np.ndarray)
15+
np.testing.assert_array_equal(result, data)
16+
17+
def test_ndarray_2d(self):
18+
data = np.array([[1, 2], [3, 4]])
19+
result = _convert_to_ndarray(data)
20+
self.assertIsInstance(result, np.ndarray)
21+
np.testing.assert_array_equal(result, data)
22+
23+
def test_pandas_series(self):
24+
data = pd.Series([1, 2, 3])
25+
result = _convert_to_ndarray(data)
26+
self.assertIsInstance(result, np.ndarray)
27+
np.testing.assert_array_equal(result, np.array([1, 2, 3]))
28+
29+
def test_pandas_dataframe(self):
30+
data = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
31+
result = _convert_to_ndarray(data)
32+
self.assertIsInstance(result, np.ndarray)
33+
np.testing.assert_array_equal(result, np.array([[1, 3], [2, 4]]))
34+
35+
def test_polars_series(self):
36+
data = pl.Series([1, 2, 3])
37+
result = _convert_to_ndarray(data)
38+
self.assertIsInstance(result, np.ndarray)
39+
np.testing.assert_array_equal(result, np.array([1, 2, 3]))
40+
41+
def test_polars_dataframe(self):
42+
data = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
43+
result = _convert_to_ndarray(data)
44+
self.assertIsInstance(result, np.ndarray)
45+
np.testing.assert_array_equal(result, np.array([[1, 3], [2, 4]]))
46+
47+
def test_list(self):
48+
data = [1, 2, 3]
49+
result = _convert_to_ndarray(data)
50+
self.assertIsInstance(result, np.ndarray)
51+
np.testing.assert_array_equal(result, np.array([1, 2, 3]))
52+
53+
def test_tuple(self):
54+
data = (1, 2, 3)
55+
result = _convert_to_ndarray(data)
56+
self.assertIsInstance(result, np.ndarray)
57+
np.testing.assert_array_equal(result, np.array([1, 2, 3]))

0 commit comments

Comments
 (0)