1+ from __future__ import annotations
2+
13import numpy as np
24from typing import Tuple
35from dte_adj .stratified import (
46 SimpleStratifiedDistributionEstimator ,
57 AdjustedStratifiedDistributionEstimator ,
68)
7- from dte_adj .util import compute_ldte , compute_lpte
9+ from dte_adj .util import ArrayLike , compute_ldte , compute_lpte , _convert_to_ndarray
810
911
1012class SimpleLocalDistributionEstimator (SimpleStratifiedDistributionEstimator ):
@@ -28,25 +30,26 @@ def __init__(self):
2830
2931 def fit (
3032 self ,
31- covariates : np . ndarray ,
32- treatment_arms : np . ndarray ,
33- treatment_indicator : np . ndarray ,
34- outcomes : np . ndarray ,
35- strata : np . ndarray ,
36- ) -> " SimpleLocalDistributionEstimator" :
33+ covariates : ArrayLike ,
34+ treatment_arms : ArrayLike ,
35+ treatment_indicator : ArrayLike ,
36+ outcomes : ArrayLike ,
37+ strata : ArrayLike ,
38+ ) -> SimpleLocalDistributionEstimator :
3739 """
3840 Train the SimpleLocalDistributionEstimator.
3941
4042 Args:
41- covariates (np.ndarray) : Pre-treatment covariates.
42- treatment_arms (np.ndarray) : Treatment assignment variable (Z).
43- treatment_indicator (np.ndarray) : Treatment indicator variable (D).
44- outcomes (np.ndarray) : Scalar-valued observed outcome.
45- strata (np.ndarray) : Stratum indicators.
43+ covariates: Pre-treatment covariates.
44+ treatment_arms: Treatment assignment variable (Z).
45+ treatment_indicator: Treatment indicator variable (D).
46+ outcomes: Scalar-valued observed outcome.
47+ strata: Stratum indicators.
4648
4749 Returns:
4850 SimpleLocalDistributionEstimator: The fitted estimator.
4951 """
52+ treatment_indicator = _convert_to_ndarray (treatment_indicator )
5053 super ().fit (covariates , treatment_arms , outcomes , strata )
5154 self .treatment_indicator = treatment_indicator
5255
@@ -196,25 +199,26 @@ class AdjustedLocalDistributionEstimator(AdjustedStratifiedDistributionEstimator
196199
197200 def fit (
198201 self ,
199- covariates : np . ndarray ,
200- treatment_arms : np . ndarray ,
201- treatment_indicator : np . ndarray ,
202- outcomes : np . ndarray ,
203- strata : np . ndarray ,
204- ) -> " AdjustedLocalDistributionEstimator" :
202+ covariates : ArrayLike ,
203+ treatment_arms : ArrayLike ,
204+ treatment_indicator : ArrayLike ,
205+ outcomes : ArrayLike ,
206+ strata : ArrayLike ,
207+ ) -> AdjustedLocalDistributionEstimator :
205208 """
206209 Train the AdjustedLocalDistributionEstimator.
207210
208211 Args:
209- covariates (np.ndarray) : Pre-treatment covariates.
210- treatment_arms (np.ndarray) : Treatment assignment variable (Z).
211- treatment_indicator (np.ndarray) : Treatment indicator variable (D).
212- outcomes (np.ndarray) : Scalar-valued observed outcome.
213- strata (np.ndarray) : Stratum indicators.
212+ covariates: Pre-treatment covariates.
213+ treatment_arms: Treatment assignment variable (Z).
214+ treatment_indicator: Treatment indicator variable (D).
215+ outcomes: Scalar-valued observed outcome.
216+ strata: Stratum indicators.
214217
215218 Returns:
216219 AdjustedLocalDistributionEstimator: The fitted estimator.
217220 """
221+ treatment_indicator = _convert_to_ndarray (treatment_indicator )
218222 super ().fit (covariates , treatment_arms , outcomes , strata )
219223 self .treatment_indicator = treatment_indicator
220224
0 commit comments