@@ -347,8 +347,6 @@ def fit(
347347 ValueError
348348 Missing columns, unbalanced panel, non-absorbing treatment,
349349 or PT-Post without a never-treated group.
350- NotImplementedError
351- If ``covariates`` and ``survey_design`` are both set.
352350 """
353351 self ._validate_params ()
354352
@@ -381,16 +379,6 @@ def fit(
381379
382380 # Bootstrap + survey supported via PSU-level multiplier bootstrap.
383381
384- # Guard covariates + survey (DR path does not yet thread survey weights)
385- if covariates is not None and len (covariates ) > 0 and resolved_survey is not None :
386- raise NotImplementedError (
387- "Survey weights with covariates are not yet supported for "
388- "EfficientDiD. The doubly robust covariate path does not "
389- "thread survey weights through nuisance estimation. "
390- "Use covariates=None with survey_design, or drop survey_design "
391- "when using covariates."
392- )
393-
394382 # Normalize empty covariates list to None (use nocov path)
395383 if covariates is not None and len (covariates ) == 0 :
396384 covariates = None
@@ -583,6 +571,7 @@ def fit(
583571 # Use the resolved survey's weights (already normalized per weight_type)
584572 # subset to unit level via _unit_first_panel_row (aligned to all_units)
585573 unit_level_weights = self ._unit_resolved_survey .weights
574+ self ._unit_level_weights = unit_level_weights
586575
587576 cohort_fractions : Dict [float , float ] = {}
588577 if unit_level_weights is not None :
@@ -617,6 +606,15 @@ def fit(
617606 stacklevel = 2 ,
618607 )
619608
609+ # Guard: never-treated with zero survey weight → no valid comparisons
610+ # Applies to both covariates (DR nuisance) and nocov (weighted means) paths
611+ if cohort_fractions .get (np .inf , 0.0 ) <= 0 and unit_level_weights is not None :
612+ raise ValueError (
613+ "Never-treated group has zero survey weight. EfficientDiD "
614+ "requires a never-treated control group with positive "
615+ "survey weight for estimation."
616+ )
617+
620618 # ----- Covariate preparation (if provided) -----
621619 covariate_matrix : Optional [np .ndarray ] = None
622620 m_hat_cache : Dict [Tuple , np .ndarray ] = {}
@@ -686,6 +684,15 @@ def fit(
686684 else :
687685 effective_p1_col = period_1_col
688686
687+ # Guard: skip cohorts with zero survey weight (all units zero-weighted)
688+ if cohort_fractions [g ] <= 0 :
689+ warnings .warn (
690+ f"Cohort { g } has zero survey weight; skipping." ,
691+ UserWarning ,
692+ stacklevel = 2 ,
693+ )
694+ continue
695+
689696 # Estimate all (g, t) cells including pre-treatment. Under PT-Post,
690697 # pre-treatment cells serve as placebo/pre-trend diagnostics, matching
691698 # the CallawaySantAnna implementation. Users filter to t >= g for
@@ -707,6 +714,15 @@ def fit(
707714 anticipation = self .anticipation ,
708715 )
709716
717+ # Filter out comparison pairs with zero survey weight
718+ if unit_level_weights is not None and pairs :
719+ pairs = [
720+ (gp , tpre ) for gp , tpre in pairs
721+ if np .sum (unit_level_weights [
722+ never_treated_mask if np .isinf (gp ) else cohort_masks [gp ]
723+ ]) > 0
724+ ]
725+
710726 if not pairs :
711727 warnings .warn (
712728 f"No valid comparison pairs for (g={ g } , t={ t } ). " "ATT will be NaN." ,
@@ -742,6 +758,7 @@ def fit(
742758 never_treated_mask ,
743759 t_col_val ,
744760 tpre_col_val ,
761+ unit_weights = unit_level_weights ,
745762 )
746763 # m_{g', tpre, 1}(X)
747764 key_gp_tpre = (gp , tpre_col_val , effective_p1_col )
@@ -755,6 +772,7 @@ def fit(
755772 gp_mask_for_reg ,
756773 tpre_col_val ,
757774 effective_p1_col ,
775+ unit_weights = unit_level_weights ,
758776 )
759777 # r_{g, inf}(X) and r_{g, g'}(X) via sieve (Eq 4.1-4.2)
760778 for comp in {np .inf , gp }:
@@ -770,6 +788,7 @@ def fit(
770788 k_max = self .sieve_k_max ,
771789 criterion = self .sieve_criterion ,
772790 ratio_clip = self .ratio_clip ,
791+ unit_weights = unit_level_weights ,
773792 )
774793
775794 # Per-unit DR generated outcomes: shape (n_units, H)
@@ -801,6 +820,7 @@ def fit(
801820 group_mask_s ,
802821 k_max = self .sieve_k_max ,
803822 criterion = self .sieve_criterion ,
823+ unit_weights = unit_level_weights ,
804824 )
805825
806826 # Conditional Omega*(X) with per-unit propensities (Eq 3.12)
@@ -817,14 +837,19 @@ def fit(
817837 covariate_matrix = covariate_matrix ,
818838 s_hat_cache = s_hat_cache ,
819839 bandwidth = self .kernel_bandwidth ,
840+ unit_weights = unit_level_weights ,
820841 )
821842
822843 # Per-unit weights: (n_units, H)
823844 per_unit_w = compute_per_unit_weights (omega_cond )
824845
825- # ATT = mean_i( w(X_i) @ gen_out[i] )
846+ # ATT = (survey-)weighted mean of per-unit DR scores
826847 if per_unit_w .shape [1 ] > 0 :
827- att_gt = float (np .mean (np .sum (per_unit_w * gen_out , axis = 1 )))
848+ per_unit_scores = np .sum (per_unit_w * gen_out , axis = 1 )
849+ if unit_level_weights is not None :
850+ att_gt = float (np .average (per_unit_scores , weights = unit_level_weights ))
851+ else :
852+ att_gt = float (np .mean (per_unit_scores ))
828853 else :
829854 att_gt = np .nan
830855
@@ -979,6 +1004,7 @@ def fit(
9791004 cluster_indices = unit_cluster_indices ,
9801005 n_clusters = n_clusters ,
9811006 resolved_survey = self ._unit_resolved_survey ,
1007+ unit_level_weights = self ._unit_level_weights ,
9821008 )
9831009 # Update estimates with bootstrap inference
9841010 overall_se = bootstrap_results .overall_att_se
@@ -1140,6 +1166,7 @@ def _compute_wif_contribution(
11401166 unit_cohorts : np .ndarray ,
11411167 cohort_fractions : Dict [float , float ],
11421168 n_units : int ,
1169+ unit_weights : Optional [np .ndarray ] = None ,
11431170 ) -> np .ndarray :
11441171 """Compute weight influence function correction (O(1) scale, matching EIF).
11451172
@@ -1159,6 +1186,9 @@ def _compute_wif_contribution(
11591186 ``{cohort: n_cohort / n}`` for each cohort.
11601187 n_units : int
11611188 Total number of units.
1189+ unit_weights : ndarray, shape (n_units,), optional
1190+ Survey weights at the unit level. When provided, uses the
1191+ survey-weighted WIF formula: IF_i(p_g) = (w_i * 1{G_i=g} - pg_k).
11621192
11631193 Returns
11641194 -------
@@ -1172,10 +1202,19 @@ def _compute_wif_contribution(
11721202 return np .zeros (n_units )
11731203
11741204 indicator = (unit_cohorts [:, None ] == groups_for_keepers [None , :]).astype (float )
1175- indicator_sum = np .sum (indicator - pg_keepers , axis = 1 )
1205+
1206+ if unit_weights is not None :
1207+ # Survey-weighted WIF (matches staggered_aggregation.py:392-401):
1208+ # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT (1{G_i=g} - pg_k)
1209+ weighted_indicator = indicator * unit_weights [:, None ]
1210+ indicator_diff = weighted_indicator - pg_keepers
1211+ indicator_sum = np .sum (indicator_diff , axis = 1 )
1212+ else :
1213+ indicator_diff = indicator - pg_keepers
1214+ indicator_sum = np .sum (indicator_diff , axis = 1 )
11761215
11771216 with np .errstate (divide = "ignore" , invalid = "ignore" , over = "ignore" ):
1178- if1 = ( indicator - pg_keepers ) / sum_pg
1217+ if1 = indicator_diff / sum_pg
11791218 if2 = np .outer (indicator_sum , pg_keepers ) / sum_pg ** 2
11801219 wif_matrix = if1 - if2
11811220 wif_contrib = wif_matrix @ effects
@@ -1232,13 +1271,34 @@ def _aggregate_overall(
12321271
12331272 # WIF correction: accounts for uncertainty in cohort-size weights
12341273 wif = self ._compute_wif_contribution (
1235- keepers , effects , unit_cohorts , cohort_fractions , n_units
1274+ keepers , effects , unit_cohorts , cohort_fractions , n_units ,
1275+ unit_weights = self ._unit_level_weights ,
12361276 )
1237- agg_eif_total = agg_eif + wif # both O(1) scale
1277+ # Compute SE: survey path uses score-level psi to avoid double-weighting
1278+ # (compute_survey_vcov applies w_i internally, which would double-weight
1279+ # the survey-weighted WIF term). Dispatch replicate vs TSL.
1280+ if self ._unit_resolved_survey is not None :
1281+ uw = self ._unit_level_weights
1282+ total_w = float (np .sum (uw ))
1283+ psi_total = uw * agg_eif / total_w + wif / total_w
1284+
1285+ if (hasattr (self ._unit_resolved_survey , 'uses_replicate_variance' )
1286+ and self ._unit_resolved_survey .uses_replicate_variance ):
1287+ from diff_diff .survey import compute_replicate_if_variance
1288+
1289+ variance , _ = compute_replicate_if_variance (
1290+ psi_total , self ._unit_resolved_survey
1291+ )
1292+ else :
1293+ from diff_diff .survey import compute_survey_if_variance
12381294
1239- # SE = sqrt(mean(EIF^2) / n) — standard IF-based SE
1240- # (dispatches to survey TSL or cluster-robust when active)
1241- se = self ._eif_se (agg_eif_total , n_units , cluster_indices , n_clusters )
1295+ variance = compute_survey_if_variance (
1296+ psi_total , self ._unit_resolved_survey
1297+ )
1298+ se = float (np .sqrt (max (variance , 0.0 ))) if np .isfinite (variance ) else np .nan
1299+ else :
1300+ agg_eif_total = agg_eif + wif
1301+ se = self ._eif_se (agg_eif_total , n_units , cluster_indices , n_clusters )
12421302
12431303 return overall_att , se
12441304
@@ -1324,15 +1384,37 @@ def _aggregate_event_study(
13241384 agg_eif += w [k ] * eif_by_gt [gt ]
13251385
13261386 # WIF correction for event-study aggregation
1387+ wif_e = np .zeros (n_units )
13271388 if unit_cohorts is not None :
13281389 es_keepers = [(g , t ) for (g , t ) in gt_pairs ]
13291390 es_effects = effs
1330- wif = self ._compute_wif_contribution (
1331- es_keepers , es_effects , unit_cohorts , cohort_fractions , n_units
1391+ wif_e = self ._compute_wif_contribution (
1392+ es_keepers , es_effects , unit_cohorts , cohort_fractions , n_units ,
1393+ unit_weights = self ._unit_level_weights ,
13321394 )
1333- agg_eif = agg_eif + wif
13341395
1335- agg_se = self ._eif_se (agg_eif , n_units , cluster_indices , n_clusters )
1396+ if self ._unit_resolved_survey is not None :
1397+ uw = self ._unit_level_weights
1398+ total_w = float (np .sum (uw ))
1399+ psi_total = uw * agg_eif / total_w + wif_e / total_w
1400+
1401+ if (hasattr (self ._unit_resolved_survey , 'uses_replicate_variance' )
1402+ and self ._unit_resolved_survey .uses_replicate_variance ):
1403+ from diff_diff .survey import compute_replicate_if_variance
1404+
1405+ variance , _ = compute_replicate_if_variance (
1406+ psi_total , self ._unit_resolved_survey
1407+ )
1408+ else :
1409+ from diff_diff .survey import compute_survey_if_variance
1410+
1411+ variance = compute_survey_if_variance (
1412+ psi_total , self ._unit_resolved_survey
1413+ )
1414+ agg_se = float (np .sqrt (max (variance , 0.0 ))) if np .isfinite (variance ) else np .nan
1415+ else :
1416+ agg_eif = agg_eif + wif_e
1417+ agg_se = self ._eif_se (agg_eif , n_units , cluster_indices , n_clusters )
13361418
13371419 t_stat , p_val , ci = safe_inference (
13381420 agg_eff , agg_se , alpha = self .alpha , df = self ._survey_df
0 commit comments