Commit 1f6ade8
Merge pull request #264 from igerber/edid-survey-covariates
Add survey design support to EfficientDiD covariates (DR) path
2 parents c184e14 + ca9522c

8 files changed: 594 additions and 110 deletions


TODO.md

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ Deferred items from PR reviews that were not addressed before merge.
 | TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. | `prep_dgp.py`, `power.py` | #208 | Low |
 | ContinuousDiD event-study aggregation anticipation filter — **Resolved**. `_aggregate_event_study()` now filters `e < -anticipation` when `anticipation > 0`, matching CallawaySantAnna behavior. Bootstrap paths also filtered. | `continuous_did.py` | #226 | Resolved |
 | Survey design resolution/collapse patterns are inconsistent across panel estimators — ContinuousDiD rebuilds unit-level design in SE code, EfficientDiD builds once in fit(), StackedDiD re-resolves on stacked data; extract shared helpers for panel-to-unit collapse, post-filter re-resolution, and metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low |
+| Survey-weighted Silverman bandwidth in EfficientDiD conditional Omega*: `_silverman_bandwidth()` uses unweighted mean/std for bandwidth selection; survey-weighted statistics would better reflect the population distribution, but this is a second-order refinement | `efficient_did_covariates.py` || Low |
 | Survey metadata formatting dedup — **Resolved**. Extracted `_format_survey_block()` helper in `results.py`, replaced 13 occurrences across 11 files. | `results.py` + 10 results files || Resolved |
 | TROP: `fit()` and `_fit_global()` share ~150 lines of near-identical data setup (panel pivoting, absorbing-state validation, first-treatment detection, effective rank, NaN warnings). Both bootstrap methods also duplicate the stratified resampling loop. Extract shared helpers to eliminate cross-file sync risk. | `trop.py`, `trop_global.py`, `trop_local.py` || Low |
 | StaggeredTripleDifference R cross-validation: CSV fixtures not committed (gitignored); tests skip without local R + triplediff. Commit fixtures or generate deterministically. | `tests/test_methodology_staggered_triple_diff.py` | #245 | Medium |
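The deferred Silverman item above could be addressed along these lines. This is a hypothetical sketch, not the library's `_silverman_bandwidth()`: it uses the simplified sigma-only Silverman rule with weighted moments and Kish's effective sample size in place of n (both are assumptions of this sketch).

```python
import numpy as np

def weighted_silverman_bandwidth(x, w):
    """Hypothetical survey-weighted Silverman rule-of-thumb bandwidth.

    Replaces the unweighted mean/std with weighted analogues and uses
    Kish's effective sample size in place of n. Sigma-only variant
    (omits the IQR/1.34 term of the full Silverman rule).
    """
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    mean = np.average(x, weights=w)
    sigma = np.sqrt(np.average((x - mean) ** 2, weights=w))
    n_eff = np.sum(w) ** 2 / np.sum(w ** 2)  # Kish effective sample size
    return 0.9 * sigma * n_eff ** (-1 / 5)

# With equal weights this collapses to the unweighted rule, since the
# weighted moments and n_eff reduce to their unweighted counterparts.
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
h = weighted_silverman_bandwidth(x, np.ones(5))
```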

diff_diff/efficient_did.py

Lines changed: 107 additions & 25 deletions
@@ -347,8 +347,6 @@ def fit(
         ValueError
             Missing columns, unbalanced panel, non-absorbing treatment,
             or PT-Post without a never-treated group.
-        NotImplementedError
-            If ``covariates`` and ``survey_design`` are both set.
         """
         self._validate_params()

@@ -381,16 +379,6 @@ def fit(

         # Bootstrap + survey supported via PSU-level multiplier bootstrap.

-        # Guard covariates + survey (DR path does not yet thread survey weights)
-        if covariates is not None and len(covariates) > 0 and resolved_survey is not None:
-            raise NotImplementedError(
-                "Survey weights with covariates are not yet supported for "
-                "EfficientDiD. The doubly robust covariate path does not "
-                "thread survey weights through nuisance estimation. "
-                "Use covariates=None with survey_design, or drop survey_design "
-                "when using covariates."
-            )
-
         # Normalize empty covariates list to None (use nocov path)
         if covariates is not None and len(covariates) == 0:
             covariates = None
@@ -583,6 +571,7 @@ def fit(
         # Use the resolved survey's weights (already normalized per weight_type)
         # subset to unit level via _unit_first_panel_row (aligned to all_units)
         unit_level_weights = self._unit_resolved_survey.weights
+        self._unit_level_weights = unit_level_weights

         cohort_fractions: Dict[float, float] = {}
         if unit_level_weights is not None:
@@ -617,6 +606,15 @@ def fit(
                 stacklevel=2,
             )

+        # Guard: never-treated with zero survey weight → no valid comparisons
+        # Applies to both covariates (DR nuisance) and nocov (weighted means) paths
+        if cohort_fractions.get(np.inf, 0.0) <= 0 and unit_level_weights is not None:
+            raise ValueError(
+                "Never-treated group has zero survey weight. EfficientDiD "
+                "requires a never-treated control group with positive "
+                "survey weight for estimation."
+            )
+
         # ----- Covariate preparation (if provided) -----
         covariate_matrix: Optional[np.ndarray] = None
         m_hat_cache: Dict[Tuple, np.ndarray] = {}
@@ -686,6 +684,15 @@ def fit(
             else:
                 effective_p1_col = period_1_col

+            # Guard: skip cohorts with zero survey weight (all units zero-weighted)
+            if cohort_fractions[g] <= 0:
+                warnings.warn(
+                    f"Cohort {g} has zero survey weight; skipping.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+                continue
+
             # Estimate all (g, t) cells including pre-treatment. Under PT-Post,
             # pre-treatment cells serve as placebo/pre-trend diagnostics, matching
             # the CallawaySantAnna implementation. Users filter to t >= g for
@@ -707,6 +714,15 @@ def fit(
                     anticipation=self.anticipation,
                 )

+                # Filter out comparison pairs with zero survey weight
+                if unit_level_weights is not None and pairs:
+                    pairs = [
+                        (gp, tpre) for gp, tpre in pairs
+                        if np.sum(unit_level_weights[
+                            never_treated_mask if np.isinf(gp) else cohort_masks[gp]
+                        ]) > 0
+                    ]
+
                if not pairs:
                    warnings.warn(
                        f"No valid comparison pairs for (g={g}, t={t}). " "ATT will be NaN.",
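The zero-weight pair filter added above can be exercised in isolation. A minimal standalone sketch with toy masks and weights (the names mirror the diff, but nothing here comes from the repo):

```python
import numpy as np

# Four units: one never-treated (weight 2.0), one cohort-2004 unit
# (weight 1.0), and two cohort-2006 units carrying zero survey weight.
unit_level_weights = np.array([1.0, 0.0, 0.0, 2.0])
never_treated_mask = np.array([False, False, False, True])
cohort_masks = {2004.0: np.array([True, False, False, False]),
                2006.0: np.array([False, True, True, False])}

pairs = [(np.inf, 2003.0), (2004.0, 2003.0), (2006.0, 2003.0)]

# Drop pairs whose comparison cohort carries zero total survey weight,
# as the guard added in fit() does.
pairs = [
    (gp, tpre) for gp, tpre in pairs
    if np.sum(unit_level_weights[
        never_treated_mask if np.isinf(gp) else cohort_masks[gp]
    ]) > 0
]
# Cohort 2006 has only zero-weight units, so its pair is removed.
```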
@@ -742,6 +758,7 @@ def fit(
                         never_treated_mask,
                         t_col_val,
                         tpre_col_val,
+                        unit_weights=unit_level_weights,
                     )
                     # m_{g', tpre, 1}(X)
                     key_gp_tpre = (gp, tpre_col_val, effective_p1_col)
@@ -755,6 +772,7 @@ def fit(
                         gp_mask_for_reg,
                         tpre_col_val,
                         effective_p1_col,
+                        unit_weights=unit_level_weights,
                     )
                     # r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
                     for comp in {np.inf, gp}:
@@ -770,6 +788,7 @@ def fit(
                             k_max=self.sieve_k_max,
                             criterion=self.sieve_criterion,
                             ratio_clip=self.ratio_clip,
+                            unit_weights=unit_level_weights,
                         )

                 # Per-unit DR generated outcomes: shape (n_units, H)
@@ -801,6 +820,7 @@ def fit(
                         group_mask_s,
                         k_max=self.sieve_k_max,
                         criterion=self.sieve_criterion,
+                        unit_weights=unit_level_weights,
                     )

                 # Conditional Omega*(X) with per-unit propensities (Eq 3.12)
@@ -817,14 +837,19 @@ def fit(
                     covariate_matrix=covariate_matrix,
                     s_hat_cache=s_hat_cache,
                     bandwidth=self.kernel_bandwidth,
+                    unit_weights=unit_level_weights,
                 )

                 # Per-unit weights: (n_units, H)
                 per_unit_w = compute_per_unit_weights(omega_cond)

-                # ATT = mean_i( w(X_i) @ gen_out[i] )
+                # ATT = (survey-)weighted mean of per-unit DR scores
                 if per_unit_w.shape[1] > 0:
-                    att_gt = float(np.mean(np.sum(per_unit_w * gen_out, axis=1)))
+                    per_unit_scores = np.sum(per_unit_w * gen_out, axis=1)
+                    if unit_level_weights is not None:
+                        att_gt = float(np.average(per_unit_scores, weights=unit_level_weights))
+                    else:
+                        att_gt = float(np.mean(per_unit_scores))
                 else:
                     att_gt = np.nan

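The new ATT step above reduces to a choice between `np.average` with weights and a plain `np.mean`. A toy illustration (values invented here):

```python
import numpy as np

# Per-unit DR scores and unit-level survey weights (toy values).
per_unit_scores = np.array([0.5, 1.5, 2.5, 3.5])
unit_level_weights = np.array([2.0, 1.0, 1.0, 0.0])

# Survey path: weighted mean; zero-weight units drop out of the estimate.
att_weighted = float(np.average(per_unit_scores, weights=unit_level_weights))
# Non-survey path: simple mean over units.
att_plain = float(np.mean(per_unit_scores))
```

Note that `np.average` normalizes by the sum of weights, so the weights need not sum to one or to n.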
@@ -979,6 +1004,7 @@ def fit(
                 cluster_indices=unit_cluster_indices,
                 n_clusters=n_clusters,
                 resolved_survey=self._unit_resolved_survey,
+                unit_level_weights=self._unit_level_weights,
             )
             # Update estimates with bootstrap inference
             overall_se = bootstrap_results.overall_att_se
@@ -1140,6 +1166,7 @@ def _compute_wif_contribution(
        unit_cohorts: np.ndarray,
        cohort_fractions: Dict[float, float],
        n_units: int,
+        unit_weights: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Compute weight influence function correction (O(1) scale, matching EIF).

@@ -1159,6 +1186,9 @@ def _compute_wif_contribution(
            ``{cohort: n_cohort / n}`` for each cohort.
        n_units : int
            Total number of units.
+        unit_weights : ndarray, shape (n_units,), optional
+            Survey weights at the unit level. When provided, uses the
+            survey-weighted WIF formula: IF_i(p_g) = (w_i * 1{G_i=g} - pg_k).

        Returns
        -------
@@ -1172,10 +1202,19 @@ def _compute_wif_contribution(
            return np.zeros(n_units)

        indicator = (unit_cohorts[:, None] == groups_for_keepers[None, :]).astype(float)
-        indicator_sum = np.sum(indicator - pg_keepers, axis=1)
+
+        if unit_weights is not None:
+            # Survey-weighted WIF (matches staggered_aggregation.py:392-401):
+            # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT (1{G_i=g} - pg_k)
+            weighted_indicator = indicator * unit_weights[:, None]
+            indicator_diff = weighted_indicator - pg_keepers
+            indicator_sum = np.sum(indicator_diff, axis=1)
+        else:
+            indicator_diff = indicator - pg_keepers
+            indicator_sum = np.sum(indicator_diff, axis=1)

        with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
-            if1 = (indicator - pg_keepers) / sum_pg
+            if1 = indicator_diff / sum_pg
            if2 = np.outer(indicator_sum, pg_keepers) / sum_pg**2
            wif_matrix = if1 - if2
            wif_contrib = wif_matrix @ effects
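The core of the survey-weighted WIF change is that the unit's survey weight enters the indicator term only: `w_i * 1{G_i=g} - pg_k` rather than `1{G_i=g} - pg_k`. A standalone sketch with toy cohorts (values invented; only the broadcasting pattern is taken from the diff):

```python
import numpy as np

# Three units in cohorts 2004/2004/2006; two keeper groups.
unit_cohorts = np.array([2004.0, 2004.0, 2006.0])
groups_for_keepers = np.array([2004.0, 2006.0])
pg_keepers = np.array([0.6, 0.4])          # cohort fractions p_g
unit_weights = np.array([1.5, 0.5, 1.0])   # survey weights (mean 1)

# 1{G_i = g} as an (n_units, n_keepers) matrix.
indicator = (unit_cohorts[:, None] == groups_for_keepers[None, :]).astype(float)

# Survey-weighted building block: w_i * 1{G_i=g} - pg_k per (unit, group).
weighted_indicator = indicator * unit_weights[:, None]
indicator_diff = weighted_indicator - pg_keepers
```

With unit weights of 1 this reduces exactly to the unweighted `indicator - pg_keepers`.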
@@ -1232,13 +1271,34 @@ def _aggregate_overall(

        # WIF correction: accounts for uncertainty in cohort-size weights
        wif = self._compute_wif_contribution(
-            keepers, effects, unit_cohorts, cohort_fractions, n_units
+            keepers, effects, unit_cohorts, cohort_fractions, n_units,
+            unit_weights=self._unit_level_weights,
        )
-        agg_eif_total = agg_eif + wif  # both O(1) scale
+        # Compute SE: survey path uses score-level psi to avoid double-weighting
+        # (compute_survey_vcov applies w_i internally, which would double-weight
+        # the survey-weighted WIF term). Dispatch replicate vs TSL.
+        if self._unit_resolved_survey is not None:
+            uw = self._unit_level_weights
+            total_w = float(np.sum(uw))
+            psi_total = uw * agg_eif / total_w + wif / total_w
+
+            if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
+                    and self._unit_resolved_survey.uses_replicate_variance):
+                from diff_diff.survey import compute_replicate_if_variance
+
+                variance, _ = compute_replicate_if_variance(
+                    psi_total, self._unit_resolved_survey
+                )
+            else:
+                from diff_diff.survey import compute_survey_if_variance

-        # SE = sqrt(mean(EIF^2) / n) — standard IF-based SE
-        # (dispatches to survey TSL or cluster-robust when active)
-        se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)
+                variance = compute_survey_if_variance(
+                    psi_total, self._unit_resolved_survey
+                )
+            se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
+        else:
+            agg_eif_total = agg_eif + wif
+            se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)

        return overall_att, se
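The double-weighting concern behind the score-level psi can be reproduced numerically. In this sketch (toy values; the sum-of-squares variance is an assumption standing in for a simple-random-sample design, not the library's `compute_survey_if_variance`, which dispatches to TSL or replicate estimators):

```python
import numpy as np

# Unit-level EIF contributions, survey weights, and a WIF term that is
# already on the weighted scale (toy values).
agg_eif = np.array([0.2, -0.1, 0.3, -0.4])
uw = np.array([1.0, 2.0, 1.0, 1.0])
wif = np.array([0.05, -0.05, 0.02, -0.02])

total_w = float(np.sum(uw))
# Score-level psi: the survey weight multiplies the EIF exactly once;
# the WIF term is only normalized, never reweighted. Reweighting it
# again downstream would double-weight the survey-weighted WIF.
psi_total = uw * agg_eif / total_w + wif / total_w

# Under an SRS assumption the IF variance reduces to sum(psi_i^2),
# since psi already carries the 1/sum(w) normalization.
variance = float(np.sum(psi_total ** 2))
se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else float("nan")
```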

@@ -1324,15 +1384,37 @@ def _aggregate_event_study(
            agg_eif += w[k] * eif_by_gt[gt]

        # WIF correction for event-study aggregation
+        wif_e = np.zeros(n_units)
        if unit_cohorts is not None:
            es_keepers = [(g, t) for (g, t) in gt_pairs]
            es_effects = effs
-            wif = self._compute_wif_contribution(
-                es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units
+            wif_e = self._compute_wif_contribution(
+                es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units,
+                unit_weights=self._unit_level_weights,
            )
-            agg_eif = agg_eif + wif

-        agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
+        if self._unit_resolved_survey is not None:
+            uw = self._unit_level_weights
+            total_w = float(np.sum(uw))
+            psi_total = uw * agg_eif / total_w + wif_e / total_w
+
+            if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
+                    and self._unit_resolved_survey.uses_replicate_variance):
+                from diff_diff.survey import compute_replicate_if_variance
+
+                variance, _ = compute_replicate_if_variance(
+                    psi_total, self._unit_resolved_survey
+                )
+            else:
+                from diff_diff.survey import compute_survey_if_variance
+
+                variance = compute_survey_if_variance(
+                    psi_total, self._unit_resolved_survey
+                )
+            agg_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
+        else:
+            agg_eif = agg_eif + wif_e
+            agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)

        t_stat, p_val, ci = safe_inference(
            agg_eff, agg_se, alpha=self.alpha, df=self._survey_df

diff_diff/efficient_did_bootstrap.py

Lines changed: 9 additions & 1 deletion
@@ -63,6 +63,7 @@ def _run_multiplier_bootstrap(
        cluster_indices: Optional[np.ndarray] = None,
        n_clusters: Optional[int] = None,
        resolved_survey: object = None,
+        unit_level_weights: Optional[np.ndarray] = None,
    ) -> EDiDBootstrapResults:
        """Run multiplier bootstrap on stored EIF values.

@@ -136,11 +137,18 @@ def _run_multiplier_bootstrap(
        original_atts = np.array([group_time_effects[gt]["effect"] for gt in gt_pairs])

        # Perturbed ATTs: (n_bootstrap, n_gt)
+        # Under survey design, perturb survey-score object w_i * eif_i / sum(w)
+        # to match the analytical variance convention (compute_survey_if_variance).
        bootstrap_atts = np.zeros((self.n_bootstrap, n_gt))
        for j, gt in enumerate(gt_pairs):
            eif_gt = eif_by_gt[gt]  # shape (n_units,)
            with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
-                perturbation = (all_weights @ eif_gt) / n_units
+                if unit_level_weights is not None:
+                    total_w = float(np.sum(unit_level_weights))
+                    eif_scaled = unit_level_weights * eif_gt / total_w
+                    perturbation = all_weights @ eif_scaled
+                else:
+                    perturbation = (all_weights @ eif_gt) / n_units
            bootstrap_atts[:, j] = original_atts[j] + perturbation

        # Post-treatment mask — also exclude NaN effects
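The survey branch of the perturbation step can be sketched end to end. This toy version assumes Rademacher multipliers (the diff does not show which multiplier distribution the library draws); all numbers are invented:

```python
import numpy as np

rng = np.random.default_rng(0)
n_units, n_bootstrap = 4, 1000

eif_gt = np.array([0.3, -0.2, 0.1, -0.1])      # per-unit EIF for one (g, t) cell
unit_level_weights = np.array([1.0, 2.0, 1.0, 1.0])
att_hat = 0.5                                   # point estimate for the cell

# Rademacher multipliers, shape (n_bootstrap, n_units) -- one assumed choice.
all_weights = rng.choice([-1.0, 1.0], size=(n_bootstrap, n_units))

# Survey path: perturb the survey-score object w_i * eif_i / sum(w), so the
# bootstrap spread matches the analytical score-level variance convention.
total_w = float(np.sum(unit_level_weights))
eif_scaled = unit_level_weights * eif_gt / total_w
perturbation = all_weights @ eif_scaled

bootstrap_atts = att_hat + perturbation
boot_se = float(np.std(bootstrap_atts, ddof=1))
```

Because the multipliers are mean-zero, the bootstrap draws center on the point estimate and their standard deviation estimates the SE.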
