Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Deferred items from PR reviews that were not addressed before merge.
| TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. | `prep_dgp.py`, `power.py` | #208 | Low |
| ContinuousDiD event-study aggregation anticipation filter — **Resolved**. `_aggregate_event_study()` now filters `e < -anticipation` when `anticipation > 0`, matching CallawaySantAnna behavior. Bootstrap paths also filtered. | `continuous_did.py` | #226 | Resolved |
| Survey design resolution/collapse patterns are inconsistent across panel estimators — ContinuousDiD rebuilds unit-level design in SE code, EfficientDiD builds once in fit(), StackedDiD re-resolves on stacked data; extract shared helpers for panel-to-unit collapse, post-filter re-resolution, and metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low |
| Survey-weighted Silverman bandwidth in EfficientDiD conditional Omega* — `_silverman_bandwidth()` uses unweighted mean/std for bandwidth selection; survey-weighted statistics would better reflect the population distribution, but this is a second-order refinement | `efficient_did_covariates.py` | — | Low |
| Survey metadata formatting dedup — **Resolved**. Extracted `_format_survey_block()` helper in `results.py`, replaced 13 occurrences across 11 files. | `results.py` + 10 results files | — | Resolved |
| TROP: `fit()` and `_fit_global()` share ~150 lines of near-identical data setup (panel pivoting, absorbing-state validation, first-treatment detection, effective rank, NaN warnings). Both bootstrap methods also duplicate the stratified resampling loop. Extract shared helpers to eliminate cross-file sync risk. | `trop.py`, `trop_global.py`, `trop_local.py` | — | Low |
| StaggeredTripleDifference R cross-validation: CSV fixtures not committed (gitignored); tests skip without local R + triplediff. Commit fixtures or generate deterministically. | `tests/test_methodology_staggered_triple_diff.py` | #245 | Medium |
Expand Down
132 changes: 107 additions & 25 deletions diff_diff/efficient_did.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,6 @@ def fit(
ValueError
Missing columns, unbalanced panel, non-absorbing treatment,
or PT-Post without a never-treated group.
NotImplementedError
If ``covariates`` and ``survey_design`` are both set.
"""
self._validate_params()

Expand Down Expand Up @@ -381,16 +379,6 @@ def fit(

# Bootstrap + survey supported via PSU-level multiplier bootstrap.

# Guard covariates + survey (DR path does not yet thread survey weights)
if covariates is not None and len(covariates) > 0 and resolved_survey is not None:
raise NotImplementedError(
"Survey weights with covariates are not yet supported for "
"EfficientDiD. The doubly robust covariate path does not "
"thread survey weights through nuisance estimation. "
"Use covariates=None with survey_design, or drop survey_design "
"when using covariates."
)

# Normalize empty covariates list to None (use nocov path)
if covariates is not None and len(covariates) == 0:
covariates = None
Expand Down Expand Up @@ -583,6 +571,7 @@ def fit(
# Use the resolved survey's weights (already normalized per weight_type)
# subset to unit level via _unit_first_panel_row (aligned to all_units)
unit_level_weights = self._unit_resolved_survey.weights
self._unit_level_weights = unit_level_weights

cohort_fractions: Dict[float, float] = {}
if unit_level_weights is not None:
Expand Down Expand Up @@ -617,6 +606,15 @@ def fit(
stacklevel=2,
)

# Guard: never-treated with zero survey weight → no valid comparisons
# Applies to both covariates (DR nuisance) and nocov (weighted means) paths
if cohort_fractions.get(np.inf, 0.0) <= 0 and unit_level_weights is not None:
raise ValueError(
"Never-treated group has zero survey weight. EfficientDiD "
"requires a never-treated control group with positive "
"survey weight for estimation."
)

# ----- Covariate preparation (if provided) -----
covariate_matrix: Optional[np.ndarray] = None
m_hat_cache: Dict[Tuple, np.ndarray] = {}
Expand Down Expand Up @@ -686,6 +684,15 @@ def fit(
else:
effective_p1_col = period_1_col

# Guard: skip cohorts with zero survey weight (all units zero-weighted)
if cohort_fractions[g] <= 0:
warnings.warn(
f"Cohort {g} has zero survey weight; skipping.",
UserWarning,
stacklevel=2,
)
continue

# Estimate all (g, t) cells including pre-treatment. Under PT-Post,
# pre-treatment cells serve as placebo/pre-trend diagnostics, matching
# the CallawaySantAnna implementation. Users filter to t >= g for
Expand All @@ -707,6 +714,15 @@ def fit(
anticipation=self.anticipation,
)

# Filter out comparison pairs with zero survey weight
if unit_level_weights is not None and pairs:
pairs = [
(gp, tpre) for gp, tpre in pairs
if np.sum(unit_level_weights[
never_treated_mask if np.isinf(gp) else cohort_masks[gp]
]) > 0
]

if not pairs:
warnings.warn(
f"No valid comparison pairs for (g={g}, t={t}). " "ATT will be NaN.",
Expand Down Expand Up @@ -742,6 +758,7 @@ def fit(
never_treated_mask,
t_col_val,
tpre_col_val,
unit_weights=unit_level_weights,
)
# m_{g', tpre, 1}(X)
key_gp_tpre = (gp, tpre_col_val, effective_p1_col)
Expand All @@ -755,6 +772,7 @@ def fit(
gp_mask_for_reg,
tpre_col_val,
effective_p1_col,
unit_weights=unit_level_weights,
)
# r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
for comp in {np.inf, gp}:
Expand All @@ -770,6 +788,7 @@ def fit(
k_max=self.sieve_k_max,
criterion=self.sieve_criterion,
ratio_clip=self.ratio_clip,
unit_weights=unit_level_weights,
)

# Per-unit DR generated outcomes: shape (n_units, H)
Expand Down Expand Up @@ -801,6 +820,7 @@ def fit(
group_mask_s,
k_max=self.sieve_k_max,
criterion=self.sieve_criterion,
unit_weights=unit_level_weights,
)

# Conditional Omega*(X) with per-unit propensities (Eq 3.12)
Expand All @@ -817,14 +837,19 @@ def fit(
covariate_matrix=covariate_matrix,
s_hat_cache=s_hat_cache,
bandwidth=self.kernel_bandwidth,
unit_weights=unit_level_weights,
)

# Per-unit weights: (n_units, H)
per_unit_w = compute_per_unit_weights(omega_cond)

# ATT = mean_i( w(X_i) @ gen_out[i] )
# ATT = (survey-)weighted mean of per-unit DR scores
if per_unit_w.shape[1] > 0:
att_gt = float(np.mean(np.sum(per_unit_w * gen_out, axis=1)))
per_unit_scores = np.sum(per_unit_w * gen_out, axis=1)
if unit_level_weights is not None:
att_gt = float(np.average(per_unit_scores, weights=unit_level_weights))
else:
att_gt = float(np.mean(per_unit_scores))
else:
att_gt = np.nan

Expand Down Expand Up @@ -979,6 +1004,7 @@ def fit(
cluster_indices=unit_cluster_indices,
n_clusters=n_clusters,
resolved_survey=self._unit_resolved_survey,
unit_level_weights=self._unit_level_weights,
)
# Update estimates with bootstrap inference
overall_se = bootstrap_results.overall_att_se
Expand Down Expand Up @@ -1140,6 +1166,7 @@ def _compute_wif_contribution(
unit_cohorts: np.ndarray,
cohort_fractions: Dict[float, float],
n_units: int,
unit_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Compute weight influence function correction (O(1) scale, matching EIF).

Expand All @@ -1159,6 +1186,9 @@ def _compute_wif_contribution(
``{cohort: n_cohort / n}`` for each cohort.
n_units : int
Total number of units.
unit_weights : ndarray, shape (n_units,), optional
Survey weights at the unit level. When provided, uses the
survey-weighted WIF formula: IF_i(p_g) = (w_i * 1{G_i=g} - pg_k).

Returns
-------
Expand All @@ -1172,10 +1202,19 @@ def _compute_wif_contribution(
return np.zeros(n_units)

indicator = (unit_cohorts[:, None] == groups_for_keepers[None, :]).astype(float)
indicator_sum = np.sum(indicator - pg_keepers, axis=1)

if unit_weights is not None:
# Survey-weighted WIF (matches staggered_aggregation.py:392-401):
# IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT (1{G_i=g} - pg_k)
weighted_indicator = indicator * unit_weights[:, None]
indicator_diff = weighted_indicator - pg_keepers
indicator_sum = np.sum(indicator_diff, axis=1)
else:
indicator_diff = indicator - pg_keepers
indicator_sum = np.sum(indicator_diff, axis=1)

with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
if1 = (indicator - pg_keepers) / sum_pg
if1 = indicator_diff / sum_pg
if2 = np.outer(indicator_sum, pg_keepers) / sum_pg**2
wif_matrix = if1 - if2
wif_contrib = wif_matrix @ effects
Expand Down Expand Up @@ -1232,13 +1271,34 @@ def _aggregate_overall(

# WIF correction: accounts for uncertainty in cohort-size weights
wif = self._compute_wif_contribution(
keepers, effects, unit_cohorts, cohort_fractions, n_units
keepers, effects, unit_cohorts, cohort_fractions, n_units,
unit_weights=self._unit_level_weights,
)
agg_eif_total = agg_eif + wif # both O(1) scale
# Compute SE: survey path uses score-level psi to avoid double-weighting
# (compute_survey_vcov applies w_i internally, which would double-weight
# the survey-weighted WIF term). Dispatch replicate vs TSL.
if self._unit_resolved_survey is not None:
uw = self._unit_level_weights
total_w = float(np.sum(uw))
psi_total = uw * agg_eif / total_w + wif / total_w

if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
and self._unit_resolved_survey.uses_replicate_variance):
from diff_diff.survey import compute_replicate_if_variance

variance, _ = compute_replicate_if_variance(
psi_total, self._unit_resolved_survey
)
else:
from diff_diff.survey import compute_survey_if_variance

# SE = sqrt(mean(EIF^2) / n) — standard IF-based SE
# (dispatches to survey TSL or cluster-robust when active)
se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)
variance = compute_survey_if_variance(
psi_total, self._unit_resolved_survey
)
se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
else:
agg_eif_total = agg_eif + wif
se = self._eif_se(agg_eif_total, n_units, cluster_indices, n_clusters)

return overall_att, se

Expand Down Expand Up @@ -1324,15 +1384,37 @@ def _aggregate_event_study(
agg_eif += w[k] * eif_by_gt[gt]

# WIF correction for event-study aggregation
wif_e = np.zeros(n_units)
if unit_cohorts is not None:
es_keepers = [(g, t) for (g, t) in gt_pairs]
es_effects = effs
wif = self._compute_wif_contribution(
es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units
wif_e = self._compute_wif_contribution(
es_keepers, es_effects, unit_cohorts, cohort_fractions, n_units,
unit_weights=self._unit_level_weights,
)
agg_eif = agg_eif + wif

agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)
if self._unit_resolved_survey is not None:
uw = self._unit_level_weights
total_w = float(np.sum(uw))
psi_total = uw * agg_eif / total_w + wif_e / total_w

if (hasattr(self._unit_resolved_survey, 'uses_replicate_variance')
and self._unit_resolved_survey.uses_replicate_variance):
from diff_diff.survey import compute_replicate_if_variance

variance, _ = compute_replicate_if_variance(
psi_total, self._unit_resolved_survey
)
else:
from diff_diff.survey import compute_survey_if_variance

variance = compute_survey_if_variance(
psi_total, self._unit_resolved_survey
)
agg_se = float(np.sqrt(max(variance, 0.0))) if np.isfinite(variance) else np.nan
else:
agg_eif = agg_eif + wif_e
agg_se = self._eif_se(agg_eif, n_units, cluster_indices, n_clusters)

t_stat, p_val, ci = safe_inference(
agg_eff, agg_se, alpha=self.alpha, df=self._survey_df
Expand Down
10 changes: 9 additions & 1 deletion diff_diff/efficient_did_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _run_multiplier_bootstrap(
cluster_indices: Optional[np.ndarray] = None,
n_clusters: Optional[int] = None,
resolved_survey: object = None,
unit_level_weights: Optional[np.ndarray] = None,
) -> EDiDBootstrapResults:
"""Run multiplier bootstrap on stored EIF values.

Expand Down Expand Up @@ -136,11 +137,18 @@ def _run_multiplier_bootstrap(
original_atts = np.array([group_time_effects[gt]["effect"] for gt in gt_pairs])

# Perturbed ATTs: (n_bootstrap, n_gt)
# Under survey design, perturb survey-score object w_i * eif_i / sum(w)
# to match the analytical variance convention (compute_survey_if_variance).
bootstrap_atts = np.zeros((self.n_bootstrap, n_gt))
for j, gt in enumerate(gt_pairs):
eif_gt = eif_by_gt[gt] # shape (n_units,)
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
perturbation = (all_weights @ eif_gt) / n_units
if unit_level_weights is not None:
total_w = float(np.sum(unit_level_weights))
eif_scaled = unit_level_weights * eif_gt / total_w
perturbation = all_weights @ eif_scaled
else:
perturbation = (all_weights @ eif_gt) / n_units
bootstrap_atts[:, j] = original_atts[j] + perturbation

# Post-treatment mask — also exclude NaN effects
Expand Down
Loading
Loading