From 04bd26379a366091386d672772f4ec3a523cfc60 Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 28 Mar 2026 19:32:13 -0400 Subject: [PATCH 01/19] Survey Phase 7: CS IPW/DR covariates, repeated cross-sections, HonestDiD survey variance Phase 7a: Remove NotImplementedError gate for IPW/DR + covariates + survey. Add DRDID panel nuisance IF corrections (PS + OR) for both survey and non-survey DR paths. Extract _safe_inv helper for matrix inversions. Phase 7d: Thread survey df through HonestDiD for t-distribution critical values. Compute full event-study VCV from influence function vectors. Add event_study_vcov to CallawaySantAnnaResults. Phase 7b: Add panel=False for repeated cross-section support in CallawaySantAnna. New _precompute_structures_rc, _compute_att_gt_rc, and three RC estimation methods (reg, ipw, dr) with covariates and survey weights. Canonical index abstraction in aggregation/bootstrap. RCS data generator in generate_staggered_data(panel=False). Co-Authored-By: Claude Opus 4.6 (1M context) --- ROADMAP.md | 33 +- TODO.md | 12 +- diff_diff/honest_did.py | 117 +++- diff_diff/prep_dgp.py | 55 ++ diff_diff/staggered.py | 1012 ++++++++++++++++++++++++++-- diff_diff/staggered_aggregation.py | 168 +++-- diff_diff/staggered_bootstrap.py | 29 +- diff_diff/staggered_results.py | 2 + docs/methodology/REGISTRY.md | 7 +- docs/survey-roadmap.md | 123 +++- tests/test_honest_did.py | 142 +++- tests/test_staggered_rc.py | 351 ++++++++++ tests/test_survey_phase4.py | 69 +- tests/test_survey_phase7a.py | 389 +++++++++++ 14 files changed, 2327 insertions(+), 182 deletions(-) create mode 100644 tests/test_staggered_rc.py create mode 100644 tests/test_survey_phase7a.py diff --git a/ROADMAP.md b/ROADMAP.md index 07b9b9e9..2c0f8682 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -8,7 +8,7 @@ For past changes and release history, see [CHANGELOG.md](CHANGELOG.md). 
## Current Status -diff-diff v2.6.0 is a **production-ready** DiD library with feature parity with R's `did` + `HonestDiD` + `synthdid` ecosystem for core DiD analysis: +diff-diff v2.7.5 is a **production-ready** DiD library with feature parity with R's `did` + `HonestDiD` + `synthdid` ecosystem for core DiD analysis, plus **unique survey support** — design-based variance estimation (Taylor linearization, replicate weights) integrated across all estimators. No R or Python package offers this combination: - **Core estimators**: Basic DiD, TWFE, MultiPeriod, Callaway-Sant'Anna, Sun-Abraham, Borusyak-Jaravel-Spiess Imputation, Synthetic DiD, Triple Difference (DDD), TROP, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing et al. 2024), Continuous DiD (Callaway, Goodman-Bacon & Sant'Anna 2024) - **Valid inference**: Robust SEs, cluster SEs, wild bootstrap, multiplier bootstrap, placebo-based variance @@ -16,11 +16,34 @@ diff-diff v2.6.0 is a **production-ready** DiD library with feature parity with - **Sensitivity analysis**: Honest DiD (Rambachan-Roth), Pre-trends power analysis (Roth 2022) - **Study design**: Power analysis tools - **Data utilities**: Real-world datasets (Card-Krueger, Castle Doctrine, Divorce Laws, MPDTA), DGP functions for all supported designs +- **Survey support**: Full `SurveyDesign` with strata, PSU, FPC, weight types, replicate weights (BRR/Fay/JK1/JKn), Taylor linearization, DEFF diagnostics, subpopulation analysis — integrated across all estimators (see [survey-roadmap.md](docs/survey-roadmap.md)) - **Performance**: Optional Rust backend for accelerated computation; faster than R at scale (see [CHANGELOG.md](CHANGELOG.md) for benchmarks) --- -## Near-Term Enhancements (v2.7) +## Near-Term Enhancements (v2.8) + +### Survey Phase 7: Completing the Survey Story + +Close the remaining gaps for practitioners using major population surveys +(ACS, CPS, BRFSS, MEPS). See [survey-roadmap.md](docs/survey-roadmap.md) for +full details. 
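One concrete consequence of the design-based variance work above: inference uses Student-t critical values with the survey degrees of freedom (df = n_PSU − n_strata, per the honest_did docstring in this patch) instead of normal quantiles. A minimal standalone sketch of the difference — plain scipy; `critical_value` here is an illustrative stand-in, not the library's internal helper:

```python
from scipy import stats

def critical_value(alpha, df=None):
    """Two-sided critical value: Student-t when a survey df is known,
    standard normal otherwise (df = n_PSU - n_strata in a survey design)."""
    if df is not None and df > 0:
        return float(stats.t.ppf(1 - alpha / 2, df))
    return float(stats.norm.ppf(1 - alpha / 2))

# Without design information: normal quantile (~1.96)
z = critical_value(0.05)
# 15 PSUs in 5 strata -> df = 10: t quantile (~2.23), noticeably wider CIs
t10 = critical_value(0.05, df=10)
print(f"z={z:.3f}  t(10)={t10:.3f}")
```

The gap closes as the PSU count grows — with hundreds of PSUs the t and normal quantiles agree to two decimals — so the distinction mainly matters for designs with few clusters.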
+
+- **CS Covariates + IPW/DR + Survey** *(High priority)*: Implement DRDID
+  nuisance IF corrections under survey weights. Until Phase 7a, the recommended
+  DR method raised `NotImplementedError` with covariates + survey. This is the
+  most commonly needed path in applied work (Medicaid expansion, minimum wage).
+- **Repeated Cross-Sections** *(High priority)*: `panel=False` support for
+  CallawaySantAnna, enabling analysis of surveys that don't track units over
+  time (BRFSS, ACS annual, CPS monthly). Uses cross-sectional DRDID
+  (Sant'Anna & Zhao 2020, Section 4).
+- **Survey-Aware DiD Tutorial** *(High priority)*: Jupyter notebook
+  demonstrating the full workflow with realistic survey data. diff-diff is
+  the only package (R or Python) with design-based variance for modern DiD
+  — this makes that capability discoverable.
+- **HonestDiD + Survey Variance** *(Medium priority)*: Pass survey vcov
+  (TSL or replicate) into sensitivity analysis instead of cluster-robust vcov,
+  so sensitivity bounds respect the same variance structure as main estimates.

### Staggered Triple Difference (DDD)

@@ -32,12 +55,6 @@ Extend the existing `TripleDifference` estimator to handle staggered adoption se

**Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). *Working Paper*. R package: `triplediff`.

-### Enhanced Visualization
-
-- Synthetic control weight visualization (bar chart of unit weights)
-- Treatment adoption "staircase" plot for staggered designs
-- Interactive plots with plotly backend option
-
---

## Medium-Term Enhancements
diff --git a/TODO.md b/TODO.md
index f9f3dc16..d84ff2da 100644
--- a/TODO.md
+++ b/TODO.md
@@ -54,7 +54,7 @@ Deferred items from PR reviews that were not addressed before merge.
| Multi-absorb weighted demeaning needs iterative alternating projections for N > 1 absorbed FE with survey weights; unweighted multi-absorb also uses single-pass (pre-existing, exact only for balanced panels) | `estimators.py` | #218 | Medium | | Replicate-weight survey df — **Resolved**. `df_survey = rank(replicate_weights) - 1` matching R's `survey::degf()`. For IF paths, `n_valid - 1` when dropped replicates reduce effective count. | `survey.py` | #238 | Resolved | | CallawaySantAnna survey: strata/PSU/FPC — **Resolved**. Aggregated SEs (overall, event study, group) use `compute_survey_if_variance()`. Bootstrap uses PSU-level multiplier weights. | `staggered.py` | #237 | Resolved | -| CallawaySantAnna survey + covariates + IPW/DR: DRDID panel nuisance-estimation IF corrections not implemented. Currently gated with NotImplementedError. Regression method with covariates works (has WLS nuisance IF correction). | `staggered.py` | #233 | Medium | +| CallawaySantAnna survey + covariates + IPW/DR — **Resolved**. DRDID panel nuisance IF corrections (PS + OR) implemented for both survey and non-survey DR paths (Phase 7a). IPW path unblocked. | `staggered.py` | #233 | Resolved | | SyntheticDiD/TROP survey: strata/PSU/FPC — **Resolved**. Rao-Wu rescaled bootstrap implemented for both. TROP uses cross-classified pseudo-strata. Rust TROP remains pweight-only (Python fallback for full design). | `synthetic_did.py`, `trop.py` | — | Resolved | | EfficientDiD hausman_pretest() clustered covariance stale `n_cl` — **Resolved**. Recompute `n_cl` and remap indices after `row_finite` filtering via `np.unique(return_inverse=True)`. | `efficient_did.py` | #230 | Resolved | | EfficientDiD `control_group="last_cohort"` trims at `last_g - anticipation` but REGISTRY says `t >= last_g`. With `anticipation=0` (default) these are identical. With `anticipation>0`, code is arguably more conservative (excludes anticipation-contaminated periods). 
Either align REGISTRY with code or change code to `t < last_g` — needs design decision. | `efficient_did.py` | #230 | Low | @@ -163,11 +163,11 @@ Spurious RuntimeWarnings ("divide by zero", "overflow", "invalid value") are emi Features in R's `did` package that block porting additional tests: -| Feature | R tests blocked | Priority | -|---------|----------------|----------| -| Repeated cross-sections (`panel=FALSE`) | ~7 tests in test-att_gt.R + test-user_bug_fixes.R | Medium | -| Sampling/population weights | 7 tests incl. all JEL replication | Medium | -| Calendar time aggregation | 1 test in test-att_gt.R | Low | +| Feature | R tests blocked | Priority | Status | +|---------|----------------|----------|--------| +| Repeated cross-sections (`panel=FALSE`) | ~7 tests in test-att_gt.R + test-user_bug_fixes.R | High | **Resolved** — Phase 7b: `panel=False` on CallawaySantAnna | +| Sampling/population weights | 7 tests incl. all JEL replication | Medium | **Resolved** (Phases 1-6 + 7a: CS IPW/DR + covariates + survey) | +| Calendar time aggregation | 1 test in test-att_gt.R | Low | | --- diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index d2a5417b..72528467 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -22,11 +22,12 @@ import numpy as np import pandas as pd -from scipy import optimize, stats +from scipy import optimize from diff_diff.results import ( MultiPeriodDiDResults, ) +from diff_diff.utils import _get_critical_value # ============================================================================= # Delta Restriction Classes @@ -193,6 +194,9 @@ class HonestDiDResults: original_results: Optional[Any] = field(default=None, repr=False) # Event study bounds (optional) event_study_bounds: Optional[Dict[Any, Dict[str, float]]] = field(default=None, repr=False) + # Survey design metadata (Phase 7d) + survey_metadata: Optional[Any] = field(default=None, repr=False) + df_survey: Optional[int] = field(default=None, repr=False) def 
__repr__(self) -> str: sig = "" if self.ci_lb <= 0 <= self.ci_ub else "*" @@ -534,7 +538,7 @@ def plot( def _extract_event_study_params( results: Union[MultiPeriodDiDResults, Any], -) -> Tuple[np.ndarray, np.ndarray, int, int, List[Any], List[Any]]: +) -> Tuple[np.ndarray, np.ndarray, int, int, List[Any], List[Any], Optional[int]]: """ Extract event study parameters from results objects. @@ -557,6 +561,8 @@ def _extract_event_study_params( Pre-period identifiers. post_periods : list Post-period identifiers. + df_survey : int or None + Survey degrees of freedom for t-distribution inference. """ if isinstance(results, MultiPeriodDiDResults): # Extract from MultiPeriodDiD @@ -606,7 +612,20 @@ def _extract_event_study_params( # Fallback: diagonal from SEs sigma = np.diag(np.array(ses) ** 2) - return beta_hat, sigma, num_pre_periods, num_post_periods, pre_periods, post_periods + # Extract survey df if available + df_survey = None + if hasattr(results, "survey_metadata") and results.survey_metadata is not None: + df_survey = getattr(results.survey_metadata, "df_survey", None) + + return ( + beta_hat, + sigma, + num_pre_periods, + num_post_periods, + pre_periods, + post_periods, + df_survey, + ) else: # Try CallawaySantAnnaResults @@ -641,9 +660,29 @@ def _extract_event_study_params( ses.append(event_effects[t]["se"]) beta_hat = np.array(effects) - sigma = np.diag(np.array(ses) ** 2) - return (beta_hat, sigma, len(pre_times), len(post_times), pre_times, post_times) + # Use full event-study VCV if available (Phase 7d), + # otherwise fall back to diagonal from SEs + if hasattr(results, "event_study_vcov") and results.event_study_vcov is not None: + # event_study_vcov is indexed by sorted rel_times + sigma = results.event_study_vcov + else: + sigma = np.diag(np.array(ses) ** 2) + + # Extract survey df + df_survey = None + if hasattr(results, "survey_metadata") and results.survey_metadata is not None: + df_survey = getattr(results.survey_metadata, "df_survey", None) + + 
return ( + beta_hat, + sigma, + len(pre_times), + len(post_times), + pre_times, + post_times, + df_survey, + ) except ImportError: pass @@ -860,7 +899,13 @@ def _solve_bounds_lp( return lb, ub -def _compute_flci(lb: float, ub: float, se: float, alpha: float = 0.05) -> Tuple[float, float]: +def _compute_flci( + lb: float, + ub: float, + se: float, + alpha: float = 0.05, + df: Optional[int] = None, +) -> Tuple[float, float]: """ Compute Fixed Length Confidence Interval (FLCI). @@ -877,6 +922,9 @@ def _compute_flci(lb: float, ub: float, se: float, alpha: float = 0.05) -> Tuple Standard error of the estimator. alpha : float Significance level. + df : int, optional + Degrees of freedom. If provided, uses t-distribution critical value + instead of normal (for survey designs with df = n_PSU - n_strata). Returns ------- @@ -895,7 +943,7 @@ def _compute_flci(lb: float, ub: float, se: float, alpha: float = 0.05) -> Tuple if not (0 < alpha < 1): raise ValueError(f"alpha must be between 0 and 1, got alpha={alpha}") - z = stats.norm.ppf(1 - alpha / 2) + z = _get_critical_value(alpha, df) ci_lb = lb - z * se ci_ub = ub + z * se return ci_lb, ci_ub @@ -909,6 +957,7 @@ def _compute_clf_ci( max_pre_violation: float, alpha: float = 0.05, n_draws: int = 1000, + df: Optional[int] = None, ) -> Tuple[float, float, float, float]: """ Compute Conditional Least Favorable (C-LF) confidence interval. @@ -931,6 +980,8 @@ def _compute_clf_ci( Significance level. n_draws : int Number of Monte Carlo draws for conditional CI. + df : int, optional + Degrees of freedom for t-distribution critical value. 
Returns ------- @@ -956,7 +1007,7 @@ def _compute_clf_ci( ub = theta + bound # CI with estimation uncertainty - z = stats.norm.ppf(1 - alpha / 2) + z = _get_critical_value(alpha, df) ci_lb = lb - z * se ci_ub = ub + z * se @@ -1086,7 +1137,7 @@ def fit( M = M if M is not None else self.M # Extract event study parameters - (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = ( + (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods, df_survey) = ( _extract_event_study_params(results) ) @@ -1137,22 +1188,41 @@ def fit( # Compute bounds based on method if self.method == "smoothness": lb, ub, ci_lb, ci_ub = self._compute_smoothness_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M + beta_post, sigma_post, l_vec, num_pre, num_post, M, df=df_survey ) ci_method = "FLCI" elif self.method == "relative_magnitude": lb, ub, ci_lb, ci_ub = self._compute_rm_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results + beta_post, + sigma_post, + l_vec, + num_pre, + num_post, + M, + pre_periods, + results, + df=df_survey, ) ci_method = "C-LF" else: # combined lb, ub, ci_lb, ci_ub = self._compute_combined_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results + beta_post, + sigma_post, + l_vec, + num_pre, + num_post, + M, + pre_periods, + results, + df=df_survey, ) ci_method = "FLCI" + # Extract survey_metadata for storage on results + survey_metadata = getattr(results, "survey_metadata", None) + return HonestDiDResults( lb=lb, ub=ub, @@ -1165,6 +1235,8 @@ def fit( alpha=self.alpha, ci_method=ci_method, original_results=results, + survey_metadata=survey_metadata, + df_survey=df_survey, ) def _compute_smoothness_bounds( @@ -1175,6 +1247,7 @@ def _compute_smoothness_bounds( num_pre: int, num_post: int, M: float, + df: Optional[int] = None, ) -> Tuple[float, float, float, float]: """Compute bounds under smoothness restriction.""" # Construct constraints @@ -1185,7 +1258,7 @@ def 
_compute_smoothness_bounds( # Compute FLCI se = np.sqrt(l_vec @ sigma_post @ l_vec) - ci_lb, ci_ub = _compute_flci(lb, ub, se, self.alpha) + ci_lb, ci_ub = _compute_flci(lb, ub, se, self.alpha, df=df) return lb, ub, ci_lb, ci_ub @@ -1199,6 +1272,7 @@ def _compute_rm_bounds( Mbar: float, pre_periods: List, results: Any, + df: Optional[int] = None, ) -> Tuple[float, float, float, float]: """Compute bounds under relative magnitudes restriction.""" # Estimate max pre-period violation from pre-trends @@ -1209,12 +1283,18 @@ def _compute_rm_bounds( # No pre-period violations detected - use point estimate theta = np.dot(l_vec, beta_post) se = np.sqrt(l_vec @ sigma_post @ l_vec) - z = stats.norm.ppf(1 - self.alpha / 2) + z = _get_critical_value(self.alpha, df) return theta, theta, theta - z * se, theta + z * se # Compute bounds lb, ub, ci_lb, ci_ub = _compute_clf_ci( - beta_post, sigma_post, l_vec, Mbar, max_pre_violation, self.alpha + beta_post, + sigma_post, + l_vec, + Mbar, + max_pre_violation, + self.alpha, + df=df, ) return lb, ub, ci_lb, ci_ub @@ -1229,16 +1309,17 @@ def _compute_combined_bounds( M: float, pre_periods: List, results: Any, + df: Optional[int] = None, ) -> Tuple[float, float, float, float]: """Compute bounds under combined smoothness + RM restriction.""" # Get smoothness bounds lb_sd, ub_sd, _, _ = self._compute_smoothness_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M + beta_post, sigma_post, l_vec, num_pre, num_post, M, df=df ) # Get RM bounds (use M as Mbar for combined) lb_rm, ub_rm, _, _ = self._compute_rm_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results + beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results, df=df ) # Combined bounds are intersection @@ -1252,7 +1333,7 @@ def _compute_combined_bounds( # Compute FLCI on combined bounds se = np.sqrt(l_vec @ sigma_post @ l_vec) - ci_lb, ci_ub = _compute_flci(lb, ub, se, self.alpha) + ci_lb, ci_ub = _compute_flci(lb, ub, se, 
self.alpha, df=df) return lb, ub, ci_lb, ci_ub diff --git a/diff_diff/prep_dgp.py b/diff_diff/prep_dgp.py index 2aab32c2..5fd42a20 100644 --- a/diff_diff/prep_dgp.py +++ b/diff_diff/prep_dgp.py @@ -136,6 +136,7 @@ def generate_staggered_data( time_trend: float = 0.1, noise_sd: float = 0.5, seed: Optional[int] = None, + panel: bool = True, ) -> pd.DataFrame: """ Generate synthetic data for staggered adoption DiD analysis. @@ -170,6 +171,10 @@ def generate_staggered_data( Standard deviation of idiosyncratic noise. seed : int, optional Random seed for reproducibility. + panel : bool, default=True + If True (default), generate balanced panel data (same units across + all periods). If False, generate repeated cross-section data where + each period draws independent observations with globally unique IDs. Returns ------- @@ -219,6 +224,56 @@ def generate_staggered_data( n_never = int(n_units * never_treated_frac) n_treated = n_units - n_never + if not panel: + # --- Repeated cross-section mode --- + # Each period draws n_units independent observations with unique IDs. + # Cohorts are assigned from the same distribution as panel. 
+ records = [] + for period in range(n_periods): + # For each period, draw fresh cohort assignments + ft_period = np.zeros(n_units, dtype=int) + if n_treated > 0: + cohort_assignments = rng.choice(len(cohort_periods), size=n_treated) + ft_period[n_never:] = [cohort_periods[c] for c in cohort_assignments] + + # Unique unit IDs per period + for i in range(n_units): + uid = f"u{period}_{i}" + unit_first_treat = ft_period[i] + is_ever_treated = unit_first_treat > 0 + + is_treated = is_ever_treated and period >= unit_first_treat + + # Outcome: unit_fe_proxy (drawn fresh) + time trend + treatment + noise + unit_fe_proxy = rng.normal(0, unit_fe_sd) + y = 10.0 + unit_fe_proxy + time_trend * period + + effect = 0.0 + if is_treated: + time_since_treatment = period - unit_first_treat + if dynamic_effects: + effect = treatment_effect * (1 + effect_growth * time_since_treatment) + else: + effect = treatment_effect + y += effect + + y += rng.normal(0, noise_sd) + + records.append( + { + "unit": uid, + "period": period, + "outcome": y, + "first_treat": unit_first_treat, + "treated": int(is_treated), + "treat": int(is_ever_treated), + "true_effect": effect, + } + ) + + return pd.DataFrame(records) + + # --- Panel mode (default) --- # Assign treatment cohorts first_treat = np.zeros(n_units, dtype=int) if n_treated > 0: diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 718dc93b..e6836ba0 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -92,6 +92,14 @@ def _linear_regression( return beta, residuals +def _safe_inv(A: np.ndarray) -> np.ndarray: + """Invert a square matrix with lstsq fallback for near-singular cases.""" + try: + return np.linalg.solve(A, np.eye(A.shape[0])) + except np.linalg.LinAlgError: + return np.linalg.lstsq(A, np.eye(A.shape[0]), rcond=None)[0] + + class CallawaySantAnna( CallawaySantAnnaBootstrapMixin, CallawaySantAnnaAggregationMixin, @@ -262,6 +270,7 @@ def __init__( base_period: str = "varying", cband: bool = True, 
pscore_trim: float = 0.01, + panel: bool = True, ): import warnings @@ -324,6 +333,7 @@ def __init__( self.cband = cband self.pscore_trim = pscore_trim + self.panel = panel self.is_fitted_ = False self.results_: Optional[CallawaySantAnnaResults] = None @@ -501,6 +511,8 @@ def _precompute_structures( "covariate_by_period": covariate_by_period, "time_periods": time_periods, "is_balanced": is_balanced, + "is_panel": True, + "canonical_size": len(all_units), "survey_weights": survey_weights_arr, "resolved_survey": resolved_survey, "resolved_survey_unit": resolved_survey_unit, @@ -875,10 +887,12 @@ def _compute_all_att_gt_vectorized( if task_keys: df_survey_val = precomputed.get("df_survey") # Guard: replicate design with undefined df → NaN inference - if (df_survey_val is None - and precomputed.get("resolved_survey_unit") is not None - and hasattr(precomputed["resolved_survey_unit"], 'uses_replicate_variance') - and precomputed["resolved_survey_unit"].uses_replicate_variance): + if ( + df_survey_val is None + and precomputed.get("resolved_survey_unit") is not None + and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") + and precomputed["resolved_survey_unit"].uses_replicate_variance + ): df_survey_val = 0 t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( np.array(atts), @@ -1217,10 +1231,13 @@ def _compute_all_att_gt_covariate_reg( # Use survey df for replicate designs (propagated from precomputed) _ipw_dr_df = precomputed.get("df_survey") if precomputed is not None else None # Guard: replicate design with undefined df → NaN inference - if (_ipw_dr_df is None and precomputed is not None - and precomputed.get("resolved_survey_unit") is not None - and hasattr(precomputed["resolved_survey_unit"], 'uses_replicate_variance') - and precomputed["resolved_survey_unit"].uses_replicate_variance): + if ( + _ipw_dr_df is None + and precomputed is not None + and precomputed.get("resolved_survey_unit") is not None + and 
hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") + and precomputed["resolved_survey_unit"].uses_replicate_variance + ): _ipw_dr_df = 0 t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( np.array(atts), np.array(ses), alpha=self.alpha, df=_ipw_dr_df @@ -1250,7 +1267,9 @@ def fit( Parameters ---------- data : pd.DataFrame - Panel data with unit and time identifiers. + Panel data with unit and time identifiers. For repeated + cross-sections (``panel=False``), each observation should + have a unique unit ID — units do not repeat across periods. outcome : str Name of outcome variable column. unit : str @@ -1275,8 +1294,10 @@ def fit( survey_design : SurveyDesign, optional Survey design specification. Supports pweight with strata/PSU/FPC. Aggregated SEs (overall, event study, group) use design-based - variance via compute_survey_if_variance(). - Covariates + IPW/DR + survey raises NotImplementedError. + variance via compute_survey_if_variance(). All estimation methods + (reg, ipw, dr) support covariates + survey. For repeated + cross-sections (``panel=False``), survey weights are + per-observation (no unit-level collapse). Returns ------- @@ -1308,7 +1329,8 @@ def fit( # Validate within-unit constancy for panel survey designs if resolved_survey is not None: - _validate_unit_constant_survey(data, unit, survey_design) + if self.panel: + _validate_unit_constant_survey(data, unit, survey_design) if resolved_survey.weight_type != "pweight": raise ValueError( f"CallawaySantAnna survey support requires weight_type='pweight', " @@ -1320,22 +1342,6 @@ def fit( # Bootstrap + survey is now supported via PSU-level multiplier bootstrap. 
- # Guard covariates + survey + IPW/DR (nuisance IF corrections not yet - # implemented to match DRDID panel formula) - if ( - resolved_survey is not None - and covariates is not None - and len(covariates) > 0 - and self.estimation_method in ("ipw", "dr") - ): - raise NotImplementedError( - f"Survey weights with covariates and estimation_method=" - f"'{self.estimation_method}' is not yet supported for " - f"CallawaySantAnna. The DRDID panel nuisance-estimation IF " - f"corrections are not yet implemented. Use estimation_method='reg' " - f"with covariates, or use any method without covariates." - ) - # Validate inputs required_cols = [outcome, unit, time, first_treat] if covariates: @@ -1365,13 +1371,19 @@ def fit( time_periods = sorted(df[time].unique()) treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0]) - # Get unique units - unit_info = ( - df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index() - ) - - n_treated_units = (unit_info[first_treat] > 0).sum() - n_control_units = (unit_info["_never_treated"]).sum() + if self.panel: + # Panel: count unique units + unit_info = ( + df.groupby(unit) + .agg({first_treat: "first", "_never_treated": "first"}) + .reset_index() + ) + n_treated_units = (unit_info[first_treat] > 0).sum() + n_control_units = (unit_info["_never_treated"]).sum() + else: + # RCS: count observations per cohort (no unit tracking) + n_treated_units = int((df[first_treat] > 0).sum()) + n_control_units = int(df["_never_treated"].sum()) if n_control_units == 0 and self.control_group == "never_treated": raise ValueError( @@ -1392,17 +1404,30 @@ def fit( # per-cell SEs use IF-based variance, not TSL. The user's cluster= # parameter is handled by the existing non-survey clustering path. 
# Pre-compute data structures for efficient ATT(g,t) computation - precomputed = self._precompute_structures( - df, - outcome, - unit, - time, - first_treat, - covariates, - time_periods, - treatment_groups, - resolved_survey=resolved_survey, - ) + if self.panel: + precomputed = self._precompute_structures( + df, + outcome, + unit, + time, + first_treat, + covariates, + time_periods, + treatment_groups, + resolved_survey=resolved_survey, + ) + else: + precomputed = self._precompute_structures_rc( + df, + outcome, + unit, + time, + first_treat, + covariates, + time_periods, + treatment_groups, + resolved_survey=resolved_survey, + ) # Recompute survey metadata from the unit-level resolved survey so # that n_psu and df_survey reflect the actual survey design (explicit @@ -1419,16 +1444,67 @@ def fit( # survey df computed in _precompute_structures for consistency. df_survey = precomputed.get("df_survey") # Guard: replicate design with undefined df (rank <= 1) → NaN inference - if (df_survey is None and resolved_survey is not None - and hasattr(resolved_survey, 'uses_replicate_variance') - and resolved_survey.uses_replicate_variance): + if ( + df_survey is None + and resolved_survey is not None + and hasattr(resolved_survey, "uses_replicate_variance") + and resolved_survey.uses_replicate_variance + ): df_survey = 0 # Compute ATT(g,t) for each group-time combination min_period = min(time_periods) has_survey = resolved_survey is not None - if covariates is None and self.estimation_method == "reg": + if not self.panel: + # --- Repeated cross-section path --- + # No vectorized/Cholesky fast paths (panel-only optimizations). + # Loop using _compute_att_gt_rc() for each (g,t). 
+ group_time_effects = {} + influence_func_info = {} + + for g in treatment_groups: + if self.base_period == "universal": + universal_base = g - 1 - self.anticipation + valid_periods = [t for t in time_periods if t != universal_base] + else: + valid_periods = [ + t for t in time_periods if t >= g - self.anticipation or t > min_period + ] + + for t in valid_periods: + att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = self._compute_att_gt_rc( + precomputed, + g, + t, + covariates, + ) + + if att_gt is not None: + t_stat, p_val, ci = safe_inference( + att_gt, + se_gt, + alpha=self.alpha, + df=df_survey, + ) + + gte_entry = { + "effect": att_gt, + "se": se_gt, + "t_stat": t_stat, + "p_value": p_val, + "conf_int": ci, + "n_treated": n_treat, + "n_control": n_ctrl, + } + if sw_sum is not None: + gte_entry["survey_weight_sum"] = sw_sum + group_time_effects[(g, t)] = gte_entry + + if inf_info is not None: + influence_func_info[(g, t)] = inf_info + + elif covariates is None and self.estimation_method == "reg": # Fast vectorized path for the common no-covariates regression case group_time_effects, influence_func_info = self._compute_all_att_gt_vectorized( precomputed, treatment_groups, time_periods, min_period @@ -1523,9 +1599,12 @@ def fit( if survey_metadata.df_survey != df_survey: survey_metadata.df_survey = df_survey # Guard: replicate design with undefined df (rank <= 1) → NaN inference - if (df_survey is None and resolved_survey is not None - and hasattr(resolved_survey, 'uses_replicate_variance') - and resolved_survey.uses_replicate_variance): + if ( + df_survey is None + and resolved_survey is not None + and hasattr(resolved_survey, "uses_replicate_variance") + and resolved_survey.uses_replicate_variance + ): df_survey = 0 overall_t, overall_p, overall_ci = safe_inference( overall_att, @@ -1679,6 +1758,9 @@ def fit( ) # Store results + # Retrieve event-study VCV from aggregation mixin (Phase 7d) + event_study_vcov = getattr(self, "_event_study_vcov", None) + 
self.results_ = CallawaySantAnnaResults( group_time_effects=group_time_effects, overall_att=overall_att, @@ -1700,6 +1782,7 @@ def fit( cband_crit_value=cband_crit_value, pscore_trim=self.pscore_trim, survey_metadata=survey_metadata, + event_study_vcov=event_study_vcov, ) self.is_fitted_ = True @@ -2178,13 +2261,68 @@ def _doubly_robust( att = att_treated_part + augmentation # Step 4: Influence function (survey-weighted DR) + # Start with plug-in IF, then add nuisance parameter corrections + # (Sant'Anna & Zhao 2020, Theorem 3.1) psi_treated = (sw_treated / sw_t_sum) * (treated_change - m_treated - att) psi_control = (weights_control / sw_t_sum) * (m_control - control_change) + inf_func = np.concatenate([psi_treated, psi_control]) - var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) + if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: + # --- PS IF correction (mirrors IPW L1929-1961) --- + # Accounts for propensity score estimation uncertainty + X_all_int = np.column_stack([np.ones(n_t + n_c), X_all]) + pscore_treated_clipped = np.clip( + pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim + ) + pscore_all = np.concatenate([pscore_treated_clipped, pscore_control]) + + # Survey-weighted PS Hessian + W_ps = pscore_all * (1 - pscore_all) + if sw_all is not None: + W_ps = W_ps * sw_all + H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) + H_ps_inv = _safe_inv(H_ps) + + # PS score + D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) + score_ps = (D_all - pscore_all)[:, None] * X_all_int + if sw_all is not None: + score_ps = score_ps * sw_all[:, None] + asy_lin_rep_ps = score_ps @ H_ps_inv # (n_t+n_c, p+1) + + # M2_dr: dATT/dgamma — gradient of DR ATT w.r.t. 
PS parameters
+                # Only the control augmentation term depends on PS via w_ipw;
+                dr_resid_control = m_control - control_change
+                M2_dr = np.sum(
+                    ((weights_control / sw_t_sum) * dr_resid_control)[:, None]
+                    * X_all_int[n_t:],
+                    axis=0,
+                )
+                inf_func = inf_func + asy_lin_rep_ps @ M2_dr
+
+                # --- OR IF correction ---
+                # Accounts for outcome regression estimation uncertainty
+                X_c_int = X_control_with_intercept
+                W_diag = sw_control if sw_control is not None else np.ones(n_c)
+                XtWX = X_c_int.T @ (W_diag[:, None] * X_c_int)
+                bread = _safe_inv(XtWX)
+
+                # M1: dATT/dbeta — gradient of DR ATT w.r.t. OR parameters
+                X_t_int = X_treated_with_intercept
+                M1 = (
+                    -np.sum(sw_treated[:, None] * X_t_int, axis=0)
+                    + np.sum(weights_control[:, None] * X_c_int, axis=0)
+                ) / sw_t_sum
+
+                # OR asymptotic linear representation (control-only)
+                resid_c = control_change - m_control
+                asy_lin_rep_or = (W_diag * resid_c)[:, None] * X_c_int @ bread
+                # Apply to control portion only (treated contribute zero)
+                inf_func[n_t:] += asy_lin_rep_or @ M1
+
+                # Recompute SE from corrected IF
+                var_psi = np.sum(inf_func**2)
             se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0
-
-            inf_func = np.concatenate([psi_treated, psi_control])
         else:
             # IPW weights for control: p(X) / (1 - p(X))
             weights_control = pscore_control / (1 - pscore_control)
@@ -2194,14 +2332,52 @@
             augmentation = float(np.sum(weights_control * (m_control - control_change)) / n_t)
             att = att_treated_part + augmentation

-            # Step 4: Standard error using influence function
+            # Step 4: Influence function with nuisance IF corrections
             psi_treated = (treated_change - m_treated - att) / n_t
             psi_control = (weights_control * (m_control - control_change)) / n_t
+            inf_func = np.concatenate([psi_treated, psi_control])

-            var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2)
-            se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0
+            if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
+                # --- PS IF correction ---
+                X_all_int = np.column_stack([np.ones(n_t + n_c), X_all])
+                pscore_treated_clipped = np.clip(
+                    pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim
+                )
+                pscore_all = np.concatenate([pscore_treated_clipped, pscore_control])
-            inf_func = np.concatenate([psi_treated, psi_control])
+
+                W_ps = pscore_all * (1 - pscore_all)
+                H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int)
+                H_ps_inv = _safe_inv(H_ps)
+
+                D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
+                score_ps = (D_all - pscore_all)[:, None] * X_all_int
+                asy_lin_rep_ps = score_ps @ H_ps_inv
+
+                dr_resid_control = m_control - control_change
+                M2_dr = np.sum(
+                    ((weights_control / n_t) * dr_resid_control)[:, None] * X_all_int[n_t:],
+                    axis=0,
+                )
+                inf_func = inf_func + asy_lin_rep_ps @ M2_dr
+
+                # --- OR IF correction ---
+                X_c_int = X_control_with_intercept
+                XtX = X_c_int.T @ X_c_int
+                bread = _safe_inv(XtX)
+
+                X_t_int = X_treated_with_intercept
+                M1 = (
+                    -np.sum(X_t_int, axis=0)
+                    + np.sum(weights_control[:, None] * X_c_int, axis=0)
+                ) / n_t
+
+                resid_c = control_change - m_control
+                asy_lin_rep_or = resid_c[:, None] * X_c_int @ bread
+                inf_func[n_t:] += asy_lin_rep_or @ M1
+
+                # Recompute SE from corrected IF
+                var_psi = np.sum(inf_func**2)
+                se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0
         else:
             # Without covariates, DR simplifies to difference in means
             if sw_treated is not None:
@@ -2235,6 +2411,721 @@
         return att, se, inf_func

+    # =========================================================================
+    # Repeated Cross-Section (RCS) methods
+    # =========================================================================
+
+    def _precompute_structures_rc(
+        self,
+        df: pd.DataFrame,
+        outcome: str,
+        unit: str,
+        time: str,
+        first_treat: str,
+        covariates: Optional[List[str]],
+        time_periods: List[Any],
+        treatment_groups: List[Any],
+        resolved_survey=None,
+    ) -> PrecomputedData:
+        """
+        Pre-compute observation-level structures for repeated cross-section.
+ + Unlike the panel path, RCS does not pivot to wide format. Each + observation is treated independently (no within-unit differencing). + + Returns + ------- + PrecomputedData + Dictionary with pre-computed structures (observation-level). + """ + n_obs = len(df) + + # Observation-level arrays (no pivot) + obs_time = df[time].values + obs_outcome = df[outcome].values + unit_cohorts = df[first_treat].values + + # "all_units" key holds integer observation indices for backward + # compatibility with aggregation code + all_units = np.arange(n_obs) + + # Pre-compute cohort masks (boolean arrays, observation-level) + cohort_masks = {} + for g in treatment_groups: + cohort_masks[g] = unit_cohorts == g + + # Never-treated mask + never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf) + + # Period-to-column mapping (identity for RCS — used for base period checks) + period_to_col = {t: i for i, t in enumerate(sorted(time_periods))} + + # Covariates (observation-level, not per-period) + obs_covariates = None + if covariates: + obs_covariates = df[covariates].values + + # Survey weights (already per-observation for RCS) + if resolved_survey is not None: + survey_weights_arr = resolved_survey.weights.copy() + else: + survey_weights_arr = None + + # For RCS, the resolved survey is already per-observation + resolved_survey_rc = resolved_survey + + return { + "all_units": all_units, + "unit_to_idx": None, # RCS: obs indices are positions + "unit_cohorts": unit_cohorts, + "canonical_size": n_obs, + "is_panel": False, + "obs_time": obs_time, + "obs_outcome": obs_outcome, + "obs_covariates": obs_covariates, + "cohort_masks": cohort_masks, + "never_treated_mask": never_treated_mask, + "time_periods": time_periods, + "period_to_col": period_to_col, + "is_balanced": False, + "survey_weights": survey_weights_arr, + "resolved_survey": resolved_survey, + "resolved_survey_unit": resolved_survey_rc, + "df_survey": ( + resolved_survey_rc.df_survey + if resolved_survey_rc is not 
None and hasattr(resolved_survey_rc, "df_survey") + else None + ), + } + + def _compute_att_gt_rc( + self, + precomputed: PrecomputedData, + g: Any, + t: Any, + covariates: Optional[List[str]], + ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]], Optional[float]]: + """ + Compute ATT(g,t) for repeated cross-section data. + + For RCS, the 2x2 DiD compares outcomes across two independent + cross-sections (periods t and base period s) rather than + within-unit changes. + + Returns + ------- + att_gt : float or None + se_gt : float + n_treated : int (treated obs at period t) + n_control : int (control obs at period t) + inf_func_info : dict or None + survey_weight_sum : float or None + """ + cohort_masks = precomputed["cohort_masks"] + never_treated_mask = precomputed["never_treated_mask"] + unit_cohorts = precomputed["unit_cohorts"] + obs_time = precomputed["obs_time"] + obs_outcome = precomputed["obs_outcome"] + period_to_col = precomputed["period_to_col"] + + # Base period selection (same logic as panel) + if self.base_period == "universal": + base_period_val = g - 1 - self.anticipation + else: # varying + if t < g - self.anticipation: + base_period_val = t - 1 + else: + base_period_val = g - 1 - self.anticipation + + if base_period_val not in period_to_col or t not in period_to_col: + return None, 0.0, 0, 0, None, None + + # Treated mask = cohort g + treated_mask = cohort_masks[g] + + # Control mask (same logic as panel) + if self.control_group == "never_treated": + control_mask = never_treated_mask + else: # not_yet_treated + nyt_threshold = max(t, base_period_val) + self.anticipation + control_mask = never_treated_mask | ( + (unit_cohorts > nyt_threshold) & (unit_cohorts != g) + ) + + # Period masks + at_t = obs_time == t + at_s = obs_time == base_period_val + + # 4 groups of observations + treated_t = treated_mask & at_t + treated_s = treated_mask & at_s + control_t = control_mask & at_t + control_s = control_mask & at_s + + n_gt = 
int(np.sum(treated_t)) + n_gs = int(np.sum(treated_s)) + n_ct = int(np.sum(control_t)) + n_cs = int(np.sum(control_s)) + + if n_gt == 0 or n_ct == 0 or n_gs == 0 or n_cs == 0: + return None, 0.0, 0, 0, None, None + + # Extract outcomes for each group + y_gt = obs_outcome[treated_t] + y_gs = obs_outcome[treated_s] + y_ct = obs_outcome[control_t] + y_cs = obs_outcome[control_s] + + # Survey weights + survey_w = precomputed.get("survey_weights") + sw_gt = survey_w[treated_t] if survey_w is not None else None + sw_gs = survey_w[treated_s] if survey_w is not None else None + sw_ct = survey_w[control_t] if survey_w is not None else None + sw_cs = survey_w[control_s] if survey_w is not None else None + + # Guard against zero effective mass + if sw_gt is not None: + if np.sum(sw_gt) <= 0 or np.sum(sw_gs) <= 0: + return np.nan, np.nan, 0, 0, None, None + if np.sum(sw_ct) <= 0 or np.sum(sw_cs) <= 0: + return np.nan, np.nan, 0, 0, None, None + + # Get covariates if specified + obs_covariates = precomputed.get("obs_covariates") + has_covariates = covariates is not None and obs_covariates is not None + + if has_covariates: + X_gt = obs_covariates[treated_t] + X_gs = obs_covariates[treated_s] + X_ct = obs_covariates[control_t] + X_cs = obs_covariates[control_s] + + # Check for NaN in covariates + if ( + np.any(np.isnan(X_gt)) + or np.any(np.isnan(X_gs)) + or np.any(np.isnan(X_ct)) + or np.any(np.isnan(X_cs)) + ): + warnings.warn( + f"Missing values in covariates for group {g}, time {t} (RCS). 
" + "Falling back to unconditional estimation.", + UserWarning, + stacklevel=3, + ) + has_covariates = False + + if has_covariates and self.estimation_method == "reg": + att, se, inf_func_all, idx_all = self._outcome_regression_rc( + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=sw_gt, + sw_gs=sw_gs, + sw_ct=sw_ct, + sw_cs=sw_cs, + ) + elif has_covariates and self.estimation_method == "ipw": + att, se, inf_func_all, idx_all = self._ipw_estimation_rc( + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=sw_gt, + sw_gs=sw_gs, + sw_ct=sw_ct, + sw_cs=sw_cs, + ) + elif has_covariates and self.estimation_method == "dr": + att, se, inf_func_all, idx_all = self._doubly_robust_rc( + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=sw_gt, + sw_gs=sw_gs, + sw_ct=sw_ct, + sw_cs=sw_cs, + ) + else: + # No-covariates 2x2 DiD (all methods reduce to same) + att, se, inf_func_all, idx_all = self._rc_2x2_did( + y_gt, + y_gs, + y_ct, + y_cs, + treated_t, + treated_s, + control_t, + control_s, + sw_gt=sw_gt, + sw_gs=sw_gs, + sw_ct=sw_ct, + sw_cs=sw_cs, + ) + + # Build influence function info + # For RCS, treated_idx/control_idx combine obs from BOTH periods + treated_idx = np.concatenate([np.where(treated_t)[0], np.where(treated_s)[0]]) + control_idx = np.concatenate([np.where(control_t)[0], np.where(control_s)[0]]) + + n_treated_combined = len(treated_idx) + inf_func_info = { + "treated_idx": treated_idx, + "control_idx": control_idx, + "treated_units": treated_idx, # For RCS, obs indices = "units" + "control_units": control_idx, + "treated_inf": inf_func_all[:n_treated_combined], + "control_inf": inf_func_all[n_treated_combined:], + } + + sw_sum = float(np.sum(sw_gt)) if sw_gt is not None else None + return att, se, n_gt, n_ct, inf_func_info, sw_sum + + def _rc_2x2_did( + self, + y_gt, + y_gs, + y_ct, + y_cs, + mask_gt, + mask_gs, + mask_ct, + mask_cs, + sw_gt=None, + sw_gs=None, + sw_ct=None, + sw_cs=None, + ): + """ + 
Compute the basic 2x2 DiD for RCS (no covariates). + + ATT = (mean(Y_treated_t) - mean(Y_control_t)) + - (mean(Y_treated_s) - mean(Y_control_s)) + + Returns (att, se, inf_func_concat, idx_concat) where inf_func_concat + has treated obs (both periods) first, then control obs (both periods). + """ + n_gt = len(y_gt) + n_gs = len(y_gs) + n_ct = len(y_ct) + n_cs = len(y_cs) + + if sw_gt is not None: + sw_gt_norm = sw_gt / np.sum(sw_gt) + sw_gs_norm = sw_gs / np.sum(sw_gs) + sw_ct_norm = sw_ct / np.sum(sw_ct) + sw_cs_norm = sw_cs / np.sum(sw_cs) + + mu_gt = float(np.sum(sw_gt_norm * y_gt)) + mu_gs = float(np.sum(sw_gs_norm * y_gs)) + mu_ct = float(np.sum(sw_ct_norm * y_ct)) + mu_cs = float(np.sum(sw_cs_norm * y_cs)) + + att = (mu_gt - mu_ct) - (mu_gs - mu_cs) + + # Influence function for 4 groups (survey-weighted) + inf_gt = sw_gt_norm * (y_gt - mu_gt) + inf_ct = -sw_ct_norm * (y_ct - mu_ct) + inf_gs = -sw_gs_norm * (y_gs - mu_gs) + inf_cs = sw_cs_norm * (y_cs - mu_cs) + else: + mu_gt = float(np.mean(y_gt)) + mu_gs = float(np.mean(y_gs)) + mu_ct = float(np.mean(y_ct)) + mu_cs = float(np.mean(y_cs)) + + att = (mu_gt - mu_ct) - (mu_gs - mu_cs) + + # Influence function for 4 groups + inf_gt = (y_gt - mu_gt) / n_gt + inf_ct = -(y_ct - mu_ct) / n_ct + inf_gs = -(y_gs - mu_gs) / n_gs + inf_cs = (y_cs - mu_cs) / n_cs + + # Concatenate: treated (t then s), control (t then s) + inf_treated = np.concatenate([inf_gt, inf_gs]) + inf_control = np.concatenate([inf_ct, inf_cs]) + inf_all = np.concatenate([inf_treated, inf_control]) + + # SE from influence function + se = float(np.sqrt(np.sum(inf_all**2))) + + idx_all = np.concatenate( + [ + np.where(mask_gt)[0], + np.where(mask_gs)[0], + np.where(mask_ct)[0], + np.where(mask_cs)[0], + ] + ) + + return att, se, inf_all, idx_all + + def _outcome_regression_rc( + self, + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=None, + sw_gs=None, + sw_ct=None, + sw_cs=None, + ): + """ + Cross-sectional outcome regression for 
ATT(g,t). + + Two outcome models: E[Y|X] on controls at t, E[Y|X] on controls at s. + Predict counterfactual for treated at each period. + ATT = mean(Y_t - m_t(X_t)) for treated at t + - mean(Y_s - m_s(X_s)) for treated at s + + Returns (att, se, inf_func_concat, idx_concat). + """ + n_gt = len(y_gt) + n_gs = len(y_gs) + n_ct = len(y_ct) + n_cs = len(y_cs) + + # Fit outcome model on controls at period t + beta_t, resid_ct = _linear_regression( + X_ct, + y_ct, + rank_deficient_action=self.rank_deficient_action, + weights=sw_ct, + ) + beta_t = np.where(np.isfinite(beta_t), beta_t, 0.0) + + # Fit outcome model on controls at base period s + beta_s, resid_cs = _linear_regression( + X_cs, + y_cs, + rank_deficient_action=self.rank_deficient_action, + weights=sw_cs, + ) + beta_s = np.where(np.isfinite(beta_s), beta_s, 0.0) + + # Predict counterfactual for treated + X_gt_int = np.column_stack([np.ones(n_gt), X_gt]) + X_gs_int = np.column_stack([np.ones(n_gs), X_gs]) + m_gt = X_gt_int @ beta_t # counterfactual at t + m_gs = X_gs_int @ beta_s # counterfactual at s + + # Treated residuals + resid_treated_t = y_gt - m_gt + resid_treated_s = y_gs - m_gs + + if sw_gt is not None: + sw_gt_norm = sw_gt / np.sum(sw_gt) + sw_gs_norm = sw_gs / np.sum(sw_gs) + sw_ct_norm = sw_ct / np.sum(sw_ct) + sw_cs_norm = sw_cs / np.sum(sw_cs) + + att_t = float(np.sum(sw_gt_norm * resid_treated_t)) + att_s = float(np.sum(sw_gs_norm * resid_treated_s)) + att = att_t - att_s + + # Influence function + inf_gt = sw_gt_norm * (resid_treated_t - att_t) + inf_gs = -sw_gs_norm * (resid_treated_s - att_s) + inf_ct = -sw_ct_norm * resid_ct + inf_cs = sw_cs_norm * resid_cs + else: + att_t = float(np.mean(resid_treated_t)) + att_s = float(np.mean(resid_treated_s)) + att = att_t - att_s + + # Influence function + inf_gt = (resid_treated_t - att_t) / n_gt + inf_gs = -(resid_treated_s - att_s) / n_gs + inf_ct = -resid_ct / n_ct + inf_cs = resid_cs / n_cs + + # Concatenate: treated (t then s), control (t then s) 
+ inf_treated = np.concatenate([inf_gt, inf_gs]) + inf_control = np.concatenate([inf_ct, inf_cs]) + inf_all = np.concatenate([inf_treated, inf_control]) + + se = float(np.sqrt(np.sum(inf_all**2))) + + idx_all = None # caller builds idx from masks + return att, se, inf_all, idx_all + + def _ipw_estimation_rc( + self, + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=None, + sw_gs=None, + sw_ct=None, + sw_cs=None, + ): + """ + Cross-sectional IPW estimation for ATT(g,t). + + Propensity score P(G=g | X) estimated on pooled treated+control + observations from both periods. Reweight controls in each period. + + Returns (att, se, inf_func_concat, idx_concat). + """ + n_gt = len(y_gt) + n_gs = len(y_gs) + n_ct = len(y_ct) + n_cs = len(y_cs) + + # Pool treated and control for propensity score + X_all = np.vstack([X_gt, X_gs, X_ct, X_cs]) + D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)]) + + sw_all = None + if sw_gt is not None: + sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs]) + + try: + beta_logistic, pscore = solve_logit( + X_all, + D_all, + rank_deficient_action=self.rank_deficient_action, + weights=sw_all, + ) + _check_propensity_diagnostics(pscore, self.pscore_trim) + except (np.linalg.LinAlgError, ValueError): + if self.rank_deficient_action == "error": + raise + warnings.warn( + "Propensity score estimation failed (RCS IPW). 
" + "Falling back to unconditional estimation.", + UserWarning, + stacklevel=4, + ) + p_treat = (n_gt + n_gs) / len(D_all) + pscore = np.full(len(D_all), p_treat) + + # Clip propensity scores + pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim) + + # Split propensity scores (treated ps not used — only control IPW weights) + ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct] + ps_cs = pscore[n_gt + n_gs + n_ct :] + + # IPW weights for controls + w_ct = ps_ct / (1 - ps_ct) + w_cs = ps_cs / (1 - ps_cs) + + if sw_gt is not None: + w_ct = sw_ct * w_ct + w_cs = sw_cs * w_cs + + w_ct_norm = w_ct / np.sum(w_ct) if np.sum(w_ct) > 0 else w_ct + w_cs_norm = w_cs / np.sum(w_cs) if np.sum(w_cs) > 0 else w_cs + + if sw_gt is not None: + sw_gt_norm = sw_gt / np.sum(sw_gt) + sw_gs_norm = sw_gs / np.sum(sw_gs) + mu_gt = float(np.sum(sw_gt_norm * y_gt)) + mu_gs = float(np.sum(sw_gs_norm * y_gs)) + else: + mu_gt = float(np.mean(y_gt)) + mu_gs = float(np.mean(y_gs)) + + mu_ct_ipw = float(np.sum(w_ct_norm * y_ct)) + mu_cs_ipw = float(np.sum(w_cs_norm * y_cs)) + + att = (mu_gt - mu_ct_ipw) - (mu_gs - mu_cs_ipw) + + # Influence function + if sw_gt is not None: + inf_gt = sw_gt_norm * (y_gt - mu_gt) + inf_gs = -sw_gs_norm * (y_gs - mu_gs) + else: + inf_gt = (y_gt - mu_gt) / n_gt + inf_gs = -(y_gs - mu_gs) / n_gs + + inf_ct = -w_ct_norm * (y_ct - mu_ct_ipw) + inf_cs = w_cs_norm * (y_cs - mu_cs_ipw) + + inf_treated = np.concatenate([inf_gt, inf_gs]) + inf_control = np.concatenate([inf_ct, inf_cs]) + inf_all = np.concatenate([inf_treated, inf_control]) + + se = float(np.sqrt(np.sum(inf_all**2))) + + idx_all = None + return att, se, inf_all, idx_all + + def _doubly_robust_rc( + self, + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=None, + sw_gs=None, + sw_ct=None, + sw_cs=None, + ): + """ + Cross-sectional doubly robust estimation for ATT(g,t). + + Combines outcome regression and IPW. 
Consistent if either the + outcome model or the propensity model is correctly specified. + + Returns (att, se, inf_func_concat, idx_concat). + """ + n_gt = len(y_gt) + n_gs = len(y_gs) + n_ct = len(y_ct) + n_cs = len(y_cs) + + # --- Outcome regression component --- + beta_t, resid_ct = _linear_regression( + X_ct, + y_ct, + rank_deficient_action=self.rank_deficient_action, + weights=sw_ct, + ) + beta_t = np.where(np.isfinite(beta_t), beta_t, 0.0) + + beta_s, resid_cs = _linear_regression( + X_cs, + y_cs, + rank_deficient_action=self.rank_deficient_action, + weights=sw_cs, + ) + beta_s = np.where(np.isfinite(beta_s), beta_s, 0.0) + + X_gt_int = np.column_stack([np.ones(n_gt), X_gt]) + X_gs_int = np.column_stack([np.ones(n_gs), X_gs]) + X_ct_int = np.column_stack([np.ones(n_ct), X_ct]) + X_cs_int = np.column_stack([np.ones(n_cs), X_cs]) + + m_gt = X_gt_int @ beta_t + m_gs = X_gs_int @ beta_s + m_ct = X_ct_int @ beta_t + m_cs = X_cs_int @ beta_s + + # --- Propensity score component --- + X_all = np.vstack([X_gt, X_gs, X_ct, X_cs]) + D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)]) + sw_all = None + if sw_gt is not None: + sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs]) + + try: + beta_logistic, pscore = solve_logit( + X_all, + D_all, + rank_deficient_action=self.rank_deficient_action, + weights=sw_all, + ) + _check_propensity_diagnostics(pscore, self.pscore_trim) + except (np.linalg.LinAlgError, ValueError): + if self.rank_deficient_action == "error": + raise + warnings.warn( + "Propensity score estimation failed (RCS DR). 
" + "Falling back to unconditional propensity.", + UserWarning, + stacklevel=4, + ) + p_treat = (n_gt + n_gs) / len(D_all) + pscore = np.full(len(D_all), p_treat) + + pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim) + + ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct] + ps_cs = pscore[n_gt + n_gs + n_ct :] + + # IPW weights for controls + w_ct = ps_ct / (1 - ps_ct) + w_cs = ps_cs / (1 - ps_cs) + + if sw_gt is not None: + w_ct = sw_ct * w_ct + w_cs = sw_cs * w_cs + + # --- DR ATT --- + if sw_gt is not None: + sw_gt_sum = np.sum(sw_gt) + sw_gs_sum = np.sum(sw_gs) + + # Period t component + att_t_or = float(np.sum(sw_gt * (y_gt - m_gt)) / sw_gt_sum) + att_t_aug = float(np.sum(w_ct * (m_ct - y_ct)) / sw_gt_sum) + att_t = att_t_or + att_t_aug + + # Period s component + att_s_or = float(np.sum(sw_gs * (y_gs - m_gs)) / sw_gs_sum) + att_s_aug = float(np.sum(w_cs * (m_cs - y_cs)) / sw_gs_sum) + att_s = att_s_or + att_s_aug + + att = att_t - att_s + + # Influence function (plug-in) + sw_gt_norm = sw_gt / sw_gt_sum + sw_gs_norm = sw_gs / sw_gs_sum + + inf_gt = sw_gt_norm * (y_gt - m_gt - att_t) + inf_gs = -sw_gs_norm * (y_gs - m_gs - att_s) + inf_ct = (w_ct / sw_gt_sum) * (m_ct - y_ct) + inf_cs = -(w_cs / sw_gs_sum) * (m_cs - y_cs) + else: + # Period t component + att_t_or = float(np.mean(y_gt - m_gt)) + att_t_aug = float(np.sum(w_ct * (m_ct - y_ct)) / n_gt) + att_t = att_t_or + att_t_aug + + # Period s component + att_s_or = float(np.mean(y_gs - m_gs)) + att_s_aug = float(np.sum(w_cs * (m_cs - y_cs)) / n_gs) + att_s = att_s_or + att_s_aug + + att = att_t - att_s + + # Influence function (plug-in) + inf_gt = (y_gt - m_gt - att_t) / n_gt + inf_gs = -(y_gs - m_gs - att_s) / n_gs + inf_ct = (w_ct * (m_ct - y_ct)) / n_gt + inf_cs = -(w_cs * (m_cs - y_cs)) / n_gs + + # Concatenate: treated (t then s), control (t then s) + inf_treated = np.concatenate([inf_gt, inf_gs]) + inf_control = np.concatenate([inf_ct, inf_cs]) + inf_all = np.concatenate([inf_treated, 
inf_control]) + + se = float(np.sqrt(np.sum(inf_all**2))) + + idx_all = None + return att, se, inf_all, idx_all + def get_params(self) -> Dict[str, Any]: """Get estimator parameters (sklearn-compatible).""" return { @@ -2252,6 +3143,7 @@ def get_params(self) -> Dict[str, Any]: "base_period": self.base_period, "cband": self.cband, "pscore_trim": self.pscore_trim, + "panel": self.panel, } def set_params(self, **params) -> "CallawaySantAnna": diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 3b75bd8b..23632764 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -5,7 +5,7 @@ group-time average treatment effects into summary measures. """ -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union import numpy as np import pandas as pd @@ -250,8 +250,15 @@ def _compute_combined_influence_function( return np.zeros(n_global_units), None return np.zeros(0), None + # Detect RCS mode via explicit flag. In RCS, obs indices ARE array positions. 
+ _is_rcs = precomputed is not None and not precomputed.get("is_panel", True) + # Build unit index mapping (local or global) - if global_unit_to_idx is not None and n_global_units is not None: + if _is_rcs and n_global_units is not None: + # RCS: direct indexing — obs indices are the array positions + n_units = n_global_units + all_units = None + elif global_unit_to_idx is not None and n_global_units is not None: n_units = n_global_units all_units = None # caller already has the unit list else: @@ -287,6 +294,12 @@ def _compute_combined_influence_function( mask_g = precomputed_cohorts == g group_sizes[g] = float(np.sum(survey_w[mask_g])) total_weight = float(np.sum(survey_w)) + elif _is_rcs: + # RCS without survey: count observations per cohort + precomputed_cohorts = precomputed["unit_cohorts"] + for g in unique_groups: + group_sizes[g] = int(np.sum(precomputed_cohorts == g)) + total_weight = float(n_units) else: for g in unique_groups: treated_in_g = df[df["first_treat"] == g][unit].nunique() @@ -325,21 +338,31 @@ def _compute_combined_influence_function( # Build unit-group array: normalize iterator to (idx, uid) pairs unit_groups_array = np.full(n_units, -1, dtype=np.float64) - idx_uid_pairs = ( - [(idx, uid) for uid, idx in global_unit_to_idx.items()] - if global_unit_to_idx is not None - else list(enumerate(all_units)) - ) - if precomputed is not None: + if _is_rcs: + # RCS: direct vectorized assignment — obs indices are positions precomputed_cohorts = precomputed["unit_cohorts"] - precomputed_unit_to_idx = precomputed["unit_to_idx"] - for idx, uid in idx_uid_pairs: - if uid in precomputed_unit_to_idx: - cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]] - if cohort in unique_groups_set: - unit_groups_array[idx] = cohort + for g in unique_groups: + mask_g = precomputed_cohorts == g + unit_groups_array[mask_g] = g + elif global_unit_to_idx is not None: + idx_uid_pairs = [(idx, uid) for uid, idx in global_unit_to_idx.items()] + + if precomputed is not 
None: + precomputed_cohorts = precomputed["unit_cohorts"] + precomputed_unit_to_idx = precomputed["unit_to_idx"] + for idx, uid in idx_uid_pairs: + if uid in precomputed_unit_to_idx: + cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]] + if cohort in unique_groups_set: + unit_groups_array[idx] = cohort + else: + for idx, uid in idx_uid_pairs: + unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0] + if unit_first_treat in unique_groups_set: + unit_groups_array[idx] = unit_first_treat else: + idx_uid_pairs = list(enumerate(all_units)) for idx, uid in idx_uid_pairs: unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0] if unit_first_treat in unique_groups_set: @@ -357,10 +380,14 @@ def _compute_combined_influence_function( # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT s_i * (1{G_i=g} - pg_k). # The pg subtraction is NOT weighted by s_i because pg is already # the population-level expected value of w_i * 1{G_i=g}. - if global_unit_to_idx is not None and precomputed is not None: + if _is_rcs and precomputed is not None: + # RCS: survey weights are already per-observation, direct indexing + unit_sw = survey_w + elif global_unit_to_idx is not None and precomputed is not None: unit_sw = np.zeros(n_units) precomputed_unit_to_idx_local = precomputed["unit_to_idx"] - for idx, uid in idx_uid_pairs: + idx_uid_pairs_sw = [(idx, uid) for uid, idx in global_unit_to_idx.items()] + for idx, uid in idx_uid_pairs_sw: if uid in precomputed_unit_to_idx_local: pc_idx = precomputed_unit_to_idx_local[uid] unit_sw[idx] = survey_w[pc_idx] @@ -420,7 +447,8 @@ def _compute_aggregated_se_with_wif( df: pd.DataFrame, unit: str, precomputed: Optional["PrecomputedData"] = None, - ) -> float: + return_psi: bool = False, + ) -> "Union[float, Tuple[float, np.ndarray]]": """ Compute SE with weight influence function (wif) adjustment. 
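As a compact reference for the estimand the repeated-cross-section hunks above implement, here is a minimal standalone sketch of the no-covariate RCS 2x2 DiD with its influence-function standard error. It mirrors the unweighted branch of the new `_rc_2x2_did` method (ATT as a difference of group-mean differences across period t and base period s, SE as `sqrt(sum(psi**2))`); the function name and signature here are illustrative only and are not part of the patch.

```python
import numpy as np

def rc_2x2_did(y_gt, y_gs, y_ct, y_cs):
    """Illustrative no-covariate 2x2 DiD on repeated cross-sections.

    y_gt/y_gs: treated-cohort outcomes at period t / base period s.
    y_ct/y_cs: control outcomes at period t / base period s.
    """
    mu_gt, mu_gs = float(np.mean(y_gt)), float(np.mean(y_gs))
    mu_ct, mu_cs = float(np.mean(y_ct)), float(np.mean(y_cs))
    att = (mu_gt - mu_ct) - (mu_gs - mu_cs)
    # Influence-function signs follow d(ATT)/d(mu) for each group mean:
    # +1 for treated-at-t, -1 for treated-at-s, -1 for control-at-t,
    # +1 for control-at-s; each term is scaled by its own group size.
    inf = np.concatenate([
        (y_gt - mu_gt) / len(y_gt),
        -(y_gs - mu_gs) / len(y_gs),
        -(y_ct - mu_ct) / len(y_ct),
        (y_cs - mu_cs) / len(y_cs),
    ])
    se = float(np.sqrt(np.sum(inf ** 2)))
    return att, se, inf
```

By construction the influence function sums to zero, so the variance is a pure sum of squared per-observation contributions; the survey-weighted branch in the patch replaces the `1/n` scalings with normalized survey weights but keeps the same sign structure.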
@@ -443,8 +471,10 @@ def _compute_aggregated_se_with_wif( global_unit_to_idx = None n_global_units = None if precomputed is not None: - global_unit_to_idx = precomputed["unit_to_idx"] - n_global_units = len(precomputed["all_units"]) + global_unit_to_idx = precomputed["unit_to_idx"] # None for RCS + n_global_units = precomputed.get( + "canonical_size", len(precomputed.get("all_units", [])) + ) elif df is not None and unit is not None: n_global_units = df[unit].nunique() @@ -462,18 +492,22 @@ def _compute_aggregated_se_with_wif( ) if len(psi_total) == 0: - return 0.0 + return (0.0, psi_total) if return_psi else 0.0 # Check for NaN propagation from non-finite WIF if not np.all(np.isfinite(psi_total)): - return np.nan + return (np.nan, psi_total) if return_psi else np.nan # Use design-based variance when full survey design is available # Use unit-level resolved survey (panel IF is indexed by unit, not obs) resolved_survey = ( precomputed.get("resolved_survey_unit") if precomputed is not None else None ) - if resolved_survey is not None and hasattr(resolved_survey, "uses_replicate_variance") and resolved_survey.uses_replicate_variance: + if ( + resolved_survey is not None + and hasattr(resolved_survey, "uses_replicate_variance") + and resolved_survey.uses_replicate_variance + ): from diff_diff.survey import compute_replicate_if_variance variance, n_valid_rep = compute_replicate_if_variance(psi_total, resolved_survey) @@ -481,8 +515,10 @@ def _compute_aggregated_se_with_wif( if precomputed is not None and n_valid_rep < resolved_survey.n_replicates: precomputed["df_survey"] = n_valid_rep - 1 if n_valid_rep > 1 else None if np.isnan(variance): - return np.nan - return np.sqrt(max(variance, 0.0)) + se = np.nan + else: + se = np.sqrt(max(variance, 0.0)) + return (se, psi_total) if return_psi else se if resolved_survey is not None and ( resolved_survey.strata is not None @@ -493,11 +529,14 @@ def _compute_aggregated_se_with_wif( variance = 
compute_survey_if_variance(psi_total, resolved_survey) if np.isnan(variance): - return np.nan - return np.sqrt(max(variance, 0.0)) + se = np.nan + else: + se = np.sqrt(max(variance, 0.0)) + return (se, psi_total) if return_psi else se variance = np.sum(psi_total**2) - return np.sqrt(variance) + se = np.sqrt(variance) + return (se, psi_total) if return_psi else se def _aggregate_event_study( self, @@ -583,6 +622,7 @@ def _aggregate_event_study( agg_effects_list = [] agg_ses_list = [] agg_n_groups = [] + _psi_vectors = [] # Per-event-time combined IF vectors for VCV for e, effect_list in sorted_periods: gt_pairs = [x[0] for x in effect_list] effs = np.array([x[1] for x in effect_list]) @@ -605,23 +645,36 @@ def _aggregate_event_study( # Compute SE with WIF adjustment (matching R's did::aggte) groups_for_gt = np.array([g for (g, t) in gt_pairs]) - agg_se = self._compute_aggregated_se_with_wif( - gt_pairs, weights, effs, groups_for_gt, influence_func_info, df, unit, precomputed + result = self._compute_aggregated_se_with_wif( + gt_pairs, + weights, + effs, + groups_for_gt, + influence_func_info, + df, + unit, + precomputed, + return_psi=True, ) + agg_se, psi_e = result agg_effects_list.append(agg_effect) agg_ses_list.append(agg_se) agg_n_groups.append(len(effect_list)) + _psi_vectors.append(psi_e) # Batch inference for all relative periods if not agg_effects_list: return {} df_survey_val = precomputed.get("df_survey") if precomputed is not None else None # Guard: replicate design with undefined df → NaN inference - if (df_survey_val is None and precomputed is not None - and precomputed.get("resolved_survey_unit") is not None - and hasattr(precomputed["resolved_survey_unit"], 'uses_replicate_variance') - and precomputed["resolved_survey_unit"].uses_replicate_variance): + if ( + df_survey_val is None + and precomputed is not None + and precomputed.get("resolved_survey_unit") is not None + and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") + and 
precomputed["resolved_survey_unit"].uses_replicate_variance + ): df_survey_val = 0 t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( np.array(agg_effects_list), @@ -654,6 +707,41 @@ def _aggregate_event_study( "n_groups": 0, } + # Compute full event-study VCV from per-event-time IF vectors (Phase 7d) + # This enables HonestDiD to use the full covariance structure + event_study_vcov = None + valid_psi = [p for p in _psi_vectors if len(p) > 0] + if valid_psi: + try: + Psi = np.column_stack(valid_psi) # (n_units, n_event_times) + resolved_survey = ( + precomputed.get("resolved_survey_unit") if precomputed is not None else None + ) + if ( + resolved_survey is not None + and not ( + hasattr(resolved_survey, "uses_replicate_variance") + and resolved_survey.uses_replicate_variance + ) + and ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None + ) + ): + from diff_diff.survey import _compute_stratified_psu_meat + + meat, _, _ = _compute_stratified_psu_meat(Psi, resolved_survey) + event_study_vcov = meat + else: + # Simple sum-of-outer-products (no survey or replicate-only) + event_study_vcov = Psi.T @ Psi + except (ValueError, np.linalg.LinAlgError): + pass # Fall back to diagonal (None) + + # Attach VCV to self for CallawaySantAnna to pick up + self._event_study_vcov = event_study_vcov + return event_study_effects def _aggregate_by_group( @@ -704,8 +792,7 @@ def _aggregate_by_group( # Use WIF-adjusted SE (with survey design support) groups_for_gt = np.array([gg for (gg, t) in gt_pairs]) agg_se = self._compute_aggregated_se_with_wif( - gt_pairs, weights, effs, groups_for_gt, - influence_func_info, df, unit, precomputed + gt_pairs, weights, effs, groups_for_gt, influence_func_info, df, unit, precomputed ) group_data_list.append((g, agg_effect, agg_se, len(g_effects))) @@ -717,10 +804,13 @@ def _aggregate_by_group( agg_ses = np.array([x[2] for x in group_data_list]) df_survey_val = 
precomputed.get("df_survey") if precomputed is not None else None # Guard: replicate design with undefined df → NaN inference - if (df_survey_val is None and precomputed is not None - and precomputed.get("resolved_survey_unit") is not None - and hasattr(precomputed["resolved_survey_unit"], 'uses_replicate_variance') - and precomputed["resolved_survey_unit"].uses_replicate_variance): + if ( + df_survey_val is None + and precomputed is not None + and precomputed.get("resolved_survey_unit") is not None + and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") + and precomputed["resolved_survey_unit"].uses_replicate_variance + ): df_survey_val = 0 t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( agg_effects, diff --git a/diff_diff/staggered_bootstrap.py b/diff_diff/staggered_bootstrap.py index fc54cc42..48f40f21 100644 --- a/diff_diff/staggered_bootstrap.py +++ b/diff_diff/staggered_bootstrap.py @@ -196,8 +196,8 @@ def _run_multiplier_bootstrap( # units don't appear in any influence function. 
if precomputed is not None: all_units = precomputed["all_units"] - n_units = len(all_units) - unit_to_idx = precomputed["unit_to_idx"] + n_units = precomputed.get("canonical_size", len(all_units)) + unit_to_idx = precomputed["unit_to_idx"] # None for RCS else: # Fallback: collect units from influence functions all_units_set = set() @@ -235,9 +235,7 @@ def _run_multiplier_bootstrap( g = gt[0] if g not in _cohort_mass_cache: _cohort_mass_cache[g] = float(np.sum(survey_w[unit_cohorts == g])) - all_n_treated = np.array( - [_cohort_mass_cache[gt[0]] for gt in gt_pairs], dtype=float - ) + all_n_treated = np.array([_cohort_mass_cache[gt[0]] for gt in gt_pairs], dtype=float) else: all_n_treated = np.array( [group_time_effects[gt]["n_treated"] for gt in gt_pairs], dtype=float @@ -426,8 +424,9 @@ def _run_multiplier_bootstrap( # Batch compute bootstrap statistics for ATT(g,t) batch_ses, batch_ci_lo, batch_ci_hi, batch_pv = _compute_effect_bootstrap_stats_batch_func( - original_atts, bootstrap_atts_gt, alpha=self.alpha, - + original_atts, + bootstrap_atts_gt, + alpha=self.alpha, ) gt_ses = {} gt_cis = {} @@ -444,8 +443,10 @@ def _run_multiplier_bootstrap( overall_p_value = np.nan else: overall_se, overall_ci, overall_p_value = _compute_effect_bootstrap_stats_func( - original_overall, bootstrap_overall, alpha=self.alpha, context="overall ATT", - + original_overall, + bootstrap_overall, + alpha=self.alpha, + context="overall ATT", ) # Batch compute bootstrap statistics for event study effects @@ -457,8 +458,9 @@ def _run_multiplier_bootstrap( es_effects = np.array([event_study_info[e]["effect"] for e in rel_periods]) es_boot_matrix = np.column_stack([bootstrap_event_study[e] for e in rel_periods]) es_ses, es_ci_lo, es_ci_hi, es_pv = _compute_effect_bootstrap_stats_batch_func( - es_effects, es_boot_matrix, alpha=self.alpha, - + es_effects, + es_boot_matrix, + alpha=self.alpha, ) event_study_ses = {e: float(es_ses[i]) for i, e in enumerate(rel_periods)} event_study_cis = { @@ 
-475,8 +477,9 @@ def _run_multiplier_bootstrap( grp_effects = np.array([group_agg_info[g]["effect"] for g in group_list]) grp_boot_matrix = np.column_stack([bootstrap_group[g] for g in group_list]) grp_ses, grp_ci_lo, grp_ci_hi, grp_pv = _compute_effect_bootstrap_stats_batch_func( - grp_effects, grp_boot_matrix, alpha=self.alpha, - + grp_effects, + grp_boot_matrix, + alpha=self.alpha, ) group_effect_ses = {g: float(grp_ses[i]) for i, g in enumerate(group_list)} group_effect_cis = { diff --git a/diff_diff/staggered_results.py b/diff_diff/staggered_results.py index 3fea9cc8..d4f23429 100644 --- a/diff_diff/staggered_results.py +++ b/diff_diff/staggered_results.py @@ -114,6 +114,8 @@ class CallawaySantAnnaResults: event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None) group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None) influence_functions: Optional["np.ndarray"] = field(default=None, repr=False) + # Full event-study VCV matrix (Phase 7d): indexed by sorted relative times + event_study_vcov: Optional["np.ndarray"] = field(default=None, repr=False) bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False) cband_crit_value: Optional[float] = None pscore_trim: float = 0.01 diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 076947ff..12685a95 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -416,10 +416,13 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: a base period later than `t` (matching R's `did::att_gt()`) - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) -- **Note:** CallawaySantAnna survey support: weights, strata, PSU, and FPC are all supported. Analytical (`n_bootstrap=0`): aggregated SEs use design-based variance via `compute_survey_if_variance()`. 
Bootstrap (`n_bootstrap>0`): PSU-level multiplier weights replace analytical SEs for aggregated quantities. Regression method supports covariates; IPW/DR support no-covariate only (covariates+IPW/DR raises NotImplementedError — DRDID nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Per-unit survey weights are extracted via `groupby(unit).first()` from the panel-normalized pweight array; on unbalanced panels the pweight normalization (`w * n_obs / sum(w)`) preserves relative unit weights since all IF/WIF formulas use weight ratios (`sw_i / sum(sw)`) where the normalization constant cancels. Scale-invariance tests pass on both balanced and unbalanced panels. +- **Note:** CallawaySantAnna survey support: weights, strata, PSU, and FPC are all supported for all estimation methods (reg, ipw, dr) with or without covariates. Analytical (`n_bootstrap=0`): aggregated SEs use design-based variance via `compute_survey_if_variance()`. Bootstrap (`n_bootstrap>0`): PSU-level multiplier weights replace analytical SEs for aggregated quantities. IPW and DR with covariates use DRDID panel nuisance IF corrections (Phase 7a: PS IF correction via survey-weighted Hessian/score, OR IF correction via WLS bread and gradient; Sant'Anna & Zhao 2020, Theorem 3.1). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Per-unit survey weights are extracted via `groupby(unit).first()` from the panel-normalized pweight array; on unbalanced panels the pweight normalization (`w * n_obs / sum(w)`) preserves relative unit weights since all IF/WIF formulas use weight ratios (`sw_i / sum(sw)`) where the normalization constant cancels. Scale-invariance tests pass on both balanced and unbalanced panels. - **Note (deviation from R):** CallawaySantAnna survey reg+covariates per-cell SE uses a conservative plug-in IF based on WLS residuals. 
The treated IF is `inf_treated_i = (sw_i/sum(sw_treated)) * (resid_i - ATT)` (normalized by treated weight sum, matching unweighted `(resid-ATT)/n_t`). The control IF is `inf_control_i = -(sw_i/sum(sw_control)) * wls_resid_i` (normalized by control weight sum, matching unweighted `-resid/n_c`). SE is computed as `sqrt(sum(sw_t_norm * (resid_t - ATT)^2) + sum(sw_c_norm * resid_c^2))`, the weighted analogue of the unweighted `sqrt(var_t/n_t + var_c/n_c)`. This omits the semiparametrically efficient nuisance correction from DRDID's `reg_did_panel` — WLS residuals are orthogonal to the weighted design matrix by construction, so the first-order IF term is asymptotically valid but may be conservative. SEs pass weight-scale-invariance tests. The efficient DRDID correction is deferred to future work. - **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization. When strata/PSU/FPC are present, analytical aggregated SEs (`n_bootstrap=0`) use `compute_survey_if_variance()` on the combined IF/WIF; bootstrap aggregated SEs (`n_bootstrap>0`) use PSU-level multiplier weights. +- **Note:** Repeated cross-sections (`panel=False`, Phase 7b): supports surveys like BRFSS, ACS annual, and CPS monthly where units are not followed over time. Uses cross-sectional DRDID (Sant'Anna & Zhao 2020, Section 4): two outcome models (one per period) instead of one on ΔY, and per-observation influence functions instead of per-unit. All three estimation methods (reg, ipw, dr) supported with and without covariates. Aggregation and bootstrap use the "canonical index" abstraction where the index space is observations (not units). Survey weights are per-observation (no unit-level collapse). Data generated via `generate_staggered_data(panel=False)`. 
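The conservative plug-in IF formulas quoted in the note above can be checked numerically. This is an illustrative sketch with synthetic residuals and weights (`plug_in_att_se` is a hypothetical helper name, not diff-diff internal code); it verifies weight-scale invariance and the reduction to the unweighted `sqrt(var_t/n_t + var_c/n_c)` under uniform weights:

```python
import numpy as np

rng = np.random.default_rng(0)
n_t, n_c = 40, 60
resid_t = rng.normal(0.5, 1.0, n_t)   # hypothetical treated WLS residuals
resid_c = rng.normal(0.0, 1.0, n_c)   # hypothetical control WLS residuals
resid_c -= resid_c.mean()             # WLS-with-intercept residuals are centered
sw_t = rng.uniform(0.5, 2.0, n_t)     # survey weights (treated)
sw_c = rng.uniform(0.5, 2.0, n_c)     # survey weights (control)

def plug_in_att_se(sw_t, sw_c, resid_t, resid_c):
    # ATT = weighted mean of treated residuals
    att = np.sum(sw_t * resid_t) / sw_t.sum()
    inf_t = (sw_t / sw_t.sum()) * (resid_t - att)   # treated IF, as in the note
    inf_c = -(sw_c / sw_c.sum()) * resid_c          # control IF, as in the note
    return att, np.sqrt(np.sum(inf_t**2) + np.sum(inf_c**2))

att, se = plug_in_att_se(sw_t, sw_c, resid_t, resid_c)

# Weight-scale invariance: rescaling all weights leaves ATT and SE unchanged,
# because every formula uses weight ratios sw_i / sum(sw)
att2, se2 = plug_in_att_se(7.0 * sw_t, 7.0 * sw_c, resid_t, resid_c)
assert np.isclose(att, att2) and np.isclose(se, se2)

# Under uniform weights the SE reduces to the unweighted sqrt(var_t/n_t + var_c/n_c)
att_u, se_u = plug_in_att_se(np.ones(n_t), np.ones(n_c), resid_t, resid_c)
assert np.isclose(att_u, resid_t.mean())
assert np.isclose(se_u, np.sqrt(resid_t.var() / n_t + resid_c.var() / n_c))
```

The same ratio-cancellation argument is what makes the pweight normalization (`w * n_obs / sum(w)`) harmless on unbalanced panels.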
+- **Note:** The non-survey DR path also includes nuisance IF corrections (PS + OR), matching the survey path structure (Phase 7a); previously it used the plug-in IF only.
+
 **Reference implementation(s):**
 - R: `did::att_gt()` (Callaway & Sant'Anna's official package)
 - Stata: `csdid`
@@ -430,6 +433,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1:
 - [ ] Aggregations: simple, event_study, group all implemented
 - [ ] Doubly robust estimation when covariates provided
 - [ ] Multiplier bootstrap preserves panel structure
+- [x] Repeated cross-sections (`panel=False`) for non-panel surveys (Phase 7b)
 
 ---
 
@@ -1630,6 +1634,7 @@ Confidence intervals:
 - Breakdown point: smallest M where CI includes zero
 - M=0: reduces to standard parallel trends
 - Negative M: not valid (constraints become infeasible)
+- **Note (Phase 7d):** Survey variance support. When input results carry `survey_metadata` with `df_survey`, HonestDiD uses t-distribution critical values (via `_get_critical_value(alpha, df)`) instead of normal. CallawaySantAnnaResults now stores `event_study_vcov` (the full cross-event-time VCV from IF vectors), which HonestDiD uses instead of the diagonal fallback. For replicate-weight designs, the event-study VCV falls back to diagonal (multivariate replicate VCV deferred).
 
 **Reference implementation(s):**
 - R: `HonestDiD` package (Rambachan & Roth's official package)
diff --git a/docs/survey-roadmap.md b/docs/survey-roadmap.md
index a521d4ef..41f1f358 100644
--- a/docs/survey-roadmap.md
+++ b/docs/survey-roadmap.md
@@ -72,10 +72,10 @@ TripleDifference IPW/DR from Phase 3 deferred work.
### Phase 5 Deferred Work -| Estimator | Deferred Capability | Blocker | -|-----------|-------------------|---------| -| SyntheticDiD | strata/PSU/FPC + survey-aware bootstrap | Bootstrap + Survey Interaction | -| TROP | strata/PSU/FPC + survey-aware bootstrap | Bootstrap + Survey Interaction | +| Estimator | Deferred Capability | Status | +|-----------|-------------------|--------| +| SyntheticDiD | strata/PSU/FPC + survey-aware bootstrap | Resolved in Phase 6 | +| TROP | strata/PSU/FPC + survey-aware bootstrap | Resolved in Phase 6 | ## Phase 6: Advanced Features @@ -104,7 +104,8 @@ JKn requires explicit `replicate_strata` (per-replicate stratum assignment). TripleDifference (analytical only, no bootstrap). Rejected with `NotImplementedError` in DifferenceInDifferences, TwoWayFixedEffects, MultiPeriodDiD, StackedDiD, SunAbraham, ImputationDiD, TwoStageDiD, - SyntheticDiD, TROP. + SyntheticDiD, TROP. Expansion to regression-based estimators (SA, + Imputation, TwoStage, Stacked) is straightforward but deferred. ### DEFF Diagnostics ✅ (2026-03-26) Per-coefficient design effects comparing survey vcov to SRS (HC1) vcov. @@ -120,3 +121,115 @@ estimation (unlike simple subsetting, which would drop design information). - Mask: bool array/Series, column name, or callable - Returns (SurveyDesign, DataFrame) pair with synthetic `_subpop_weight` column - Weight validation relaxed: zero weights allowed (negative still rejected) + +--- + +## Phase 7: Completing the Survey Story + +These items close the remaining gaps that matter for practitioners using major +population surveys (ACS, CPS, BRFSS, MEPS) with modern DiD methods. Together +they make diff-diff the only package — R or Python — with full design-based +variance estimation for heterogeneity-robust DiD estimators. + +### 7a. CallawaySantAnna Covariates + IPW/DR + Survey ✅ + +**Priority: High.** This is the single highest-impact gap. 
The Callaway-Sant'Anna
+`reg` method with covariates already works under survey designs, but the
+recommended IPW and DR methods previously raised `NotImplementedError`. Most
+applied work (Medicaid expansion, minimum wage studies) uses DR with
+covariates.
+
+**What's needed:**
+- Implement DRDID panel nuisance-estimation influence function corrections
+  under survey weights (Sant'Anna & Zhao 2020, Theorem 3.1)
+- Survey-weighted propensity score estimation via `solve_logit()` (already
+  available from Phase 4)
+- Survey-weighted outcome regression for the imputation step
+- Correct IF that accounts for nuisance parameter estimation uncertainty
+  under the survey design
+- Thread `ResolvedSurveyDesign` through the IPW and DR paths in
+  `_estimate_att_gt()`
+
+**Reference:** Sant'Anna, P.H.C. & Zhao, J. (2020). "Doubly Robust
+Difference-in-Differences Estimators." *Journal of Econometrics* 219(1).
+
+**Former gate (removed in 7a):** `staggered.py` raised `NotImplementedError`
+when `estimation_method in ('ipw', 'dr')` and covariates were provided with
+`survey_design`.
+
+### 7b. Repeated Cross-Sections ✅
+
+**Priority: High.** Many major surveys (BRFSS, ACS annual cross-sections,
+CPS monthly) are not panels — units are not followed over time. The R `did`
+package supports `panel=FALSE` for these settings. diff-diff previously
+required panel data for all staggered estimators.
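The core difference from the panel case can be seen in a toy no-covariate example: with repeated cross-sections there are no within-unit differences, so ATT(g,t) is built from per-period group means. A minimal numpy sketch (synthetic data, not diff-diff's DRDID implementation; without covariates reg/ipw/dr all reduce to this difference of means):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 500  # fresh sample each period; no unit is observed twice

def draw(period, treated_share=0.4, att=2.0):
    """One cross-section: cohort membership G, outcome Y."""
    g = rng.random(n) < treated_share
    base = 1.0 + 0.5 * g                      # time-invariant group difference
    y = base + 0.3 * period + rng.normal(0, 1, n)
    if period == 1:
        y = y + att * g                       # effect only in the post period
    return g, y

g0, y0 = draw(period=0)   # base-period cross-section
g1, y1 = draw(period=1)   # post-period cross-section

# ATT(g,t) without covariates: difference of per-period group means
att_hat = (y1[g1].mean() - y1[~g1].mean()) - (y0[g0].mean() - y0[~g0].mean())
```

Group-level differences and the common time trend cancel, so `att_hat` recovers the simulated effect of 2.0 up to sampling noise.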
+ +**What's needed:** +- `panel` parameter on `CallawaySantAnna` (default `True`, set `False` for + repeated cross-sections) +- Repeated cross-section ATT(g,t) estimation using cross-sectional DRDID + (Sant'Anna & Zhao 2020, Section 4) +- Cross-sectional propensity score: model P(G=g | X) instead of panel + first-difference +- Cross-sectional outcome regression: model E[Y | X, G, T] instead of + E[ΔY | X, G] +- Influence functions for cross-sectional case (different from panel IF) +- Survey weight support for repeated cross-sections (weights apply per + observation, no unit-level collapse) +- Update `generate_did_data()` / `generate_staggered_data()` with + `panel=False` option for testing + +**Reference:** Sant'Anna, P.H.C. & Zhao, J. (2020). Sections 3 (panel) vs +4 (repeated cross-sections). Callaway, B. & Sant'Anna, P.H.C. (2021). +Section 4.1. + +**Scope:** CallawaySantAnna only. Other staggered estimators (SunAbraham, +ImputationDiD, TwoStageDiD, StackedDiD) are inherently panel methods. + +### 7c. Survey-Aware DiD Tutorial + +**Priority: High.** diff-diff is the only package (R or Python) with +design-based variance estimation for modern DiD estimators, but no one +knows this. A tutorial demonstrating the full workflow with realistic +survey data would make the capability discoverable. + +**What's needed:** +- Jupyter notebook: `docs/tutorials/16_survey_did.ipynb` +- Sections: + 1. Why survey design matters for DiD (variance inflation from clustering, + weight effects on point estimates — cite Solon, Haider & Wooldridge 2015) + 2. Setting up `SurveyDesign` (weights, strata, PSU, FPC) + 3. Basic DiD with survey design (compare naive vs. design-based SEs) + 4. Staggered DiD with survey weights (CallawaySantAnna) + 5. Replicate weights workflow (BRR/JKn for MEPS/ACS PUMS users) + 6. Subpopulation analysis + 7. DEFF diagnostics — interpreting design effects + 8. 
Comparison: show that R's `did` package with `weightsname` gives
+   survey-naive variance while diff-diff gives design-based variance
+- Use realistic synthetic data mimicking ACS/CPS structure (stratified
+  multi-stage design with known treatment effect)
+- Cross-reference from README, choosing_estimator.rst, and quickstart.rst
+
+### 7d. HonestDiD with Survey Variance ✅
+
+**Priority: Medium.** Sensitivity analysis (Rambachan & Roth 2023) previously
+used cluster-robust or HC variance. Under complex survey designs, the
+variance-covariance matrix should come from TSL or replicate weights.
+
+**What's needed:**
+- Accept optional `survey_design` parameter in `HonestDiD`
+- When provided, use survey vcov matrix instead of cluster-robust vcov
+  for computing sensitivity bounds
+- Degrees of freedom from survey design (n_PSU - n_strata) for
+  t-distribution critical values
+- Propagate through all three methods: relative magnitudes, smoothness,
+  and conditional least favorable (C-LF)
+- The core optimization (LP/QP for bounds) is unchanged — only the input
+  vcov and df change
+
+**Reference:** Rambachan, A. & Roth, J. (2023). "A More Credible Approach
+to Parallel Trends." *Review of Economic Studies* 90(5).
+
+**Why it matters:** A practitioner who runs CS with survey design but then
+runs HonestDiD sensitivity analysis with cluster-robust SEs gets
+inconsistent inference. The sensitivity bounds should respect the same
+variance structure as the main estimates.
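The design degrees-of-freedom rule above (n_PSU - n_strata) is a one-liner in practice; the sketch below uses made-up stratum/PSU labels, counting PSUs as distinct (stratum, PSU) pairs:

```python
import numpy as np

# Hypothetical design: 2 strata with 2 PSUs each
strata = np.array([0, 0, 0, 0, 1, 1, 1, 1])
psu    = np.array([0, 0, 1, 1, 2, 2, 3, 3])

# Design df = (distinct (stratum, PSU) pairs) - (number of strata)
n_psu = len(set(zip(strata.tolist(), psu.tolist())))
n_strata = len(np.unique(strata))
df = n_psu - n_strata
# df == 2 here; t_{0.975, df=2} ~ 4.30, far above z_{0.975} ~ 1.96,
# so survey-aware CIs widen substantially at low df
```

This is why the Phase 7d tests construct a 2-strata/4-PSU design to force df=2 and then assert that the resulting CIs are wider than normal-based ones.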
diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index 050ae27b..c757b545 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -292,7 +292,7 @@ class TestParameterExtraction: def test_extract_from_multiperiod(self, mock_multiperiod_results): """Test extraction from MultiPeriodDiDResults.""" - (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = ( + (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods, _df) = ( _extract_event_study_params(mock_multiperiod_results) ) @@ -667,7 +667,7 @@ def test_multiperiod_sub_vcov_extraction(self, simple_panel_data): reference_period=3, ) - (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = ( + (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods, _df) = ( _extract_event_study_params(results) ) @@ -890,7 +890,9 @@ def test_honest_did_filters_nan_pre_period_effects(self): ) # _extract_event_study_params should filter out period 0 (NaN) - beta_hat, sigma, num_pre, num_post, pre_p, post_p = _extract_event_study_params(results) + beta_hat, sigma, num_pre, num_post, pre_p, post_p, _df = _extract_event_study_params( + results + ) assert len(beta_hat) == 3 # periods 1, 3, 4 (period 0 filtered) assert num_pre == 1 # only period 1 assert num_post == 2 # periods 3, 4 @@ -1090,7 +1092,9 @@ def test_honest_did_nonmonotone_period_labels(self): interaction_indices=interaction_indices, ) - beta_hat, sigma, num_pre, num_post, pre_p, post_p = _extract_event_study_params(results) + beta_hat, sigma, num_pre, num_post, pre_p, post_p, _df = _extract_event_study_params( + results + ) # Pre-periods: 5, 6 (7 is reference, omitted) assert num_pre == 2 @@ -1110,6 +1114,136 @@ def test_honest_did_nonmonotone_period_labels(self): assert sigma[3, 3] == pytest.approx(0.2025) # period 2 variance +# ============================================================================= +# Tests for Phase 7d: Survey variance support +# 
============================================================================= + + +class TestSurveyVariance: + """Tests for HonestDiD survey variance support (Phase 7d).""" + + def test_df_survey_extracted_from_cs_results(self): + """df_survey is extracted from CallawaySantAnna survey_metadata.""" + from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + unit_ids = data["unit"].unique() + n_units = len(unit_ids) + unit_map = {uid: i for i, uid in enumerate(unit_ids)} + idx = data["unit"].map(unit_map).values + + data["weight"] = (1.0 + 0.5 * (np.arange(n_units) % 4))[idx] + data["stratum"] = (np.arange(n_units) // 25)[idx] + data["psu"] = (np.arange(n_units) // 5)[idx] + + sd = SurveyDesign(weights="weight", strata="stratum", psu="psu") + cs_result = CallawaySantAnna().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + aggregate="event_study", + ) + + honest = HonestDiD(method="smoothness", M=0.0) + h_result = honest.fit(cs_result) + + # df_survey should be propagated + assert h_result.df_survey is not None + assert h_result.df_survey > 0 + assert h_result.survey_metadata is not None + + def test_event_study_vcov_computed(self): + """CallawaySantAnna event_study_vcov is computed and used by HonestDiD.""" + from diff_diff import CallawaySantAnna, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=6, seed=42) + cs_result = CallawaySantAnna().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + ) + + # VCV should be computed + assert cs_result.event_study_vcov is not None + n_effects = len( + [ + e + for e, d in cs_result.event_study_effects.items() + if d.get("n_groups", 1) > 0 and np.isfinite(d.get("se", np.nan)) + ] + ) + assert cs_result.event_study_vcov.shape == (n_effects, n_effects) + + # Diagonal should match squared SEs + for i, e in 
enumerate(sorted(cs_result.event_study_effects.keys())): + info = cs_result.event_study_effects[e] + if info.get("n_groups", 1) > 0 and np.isfinite(info.get("se", np.nan)): + # VCV diagonal should be close to SE^2 (not exact due to IF aggregation) + assert cs_result.event_study_vcov[i, i] > 0 + + def test_survey_df_widens_bounds(self): + """Survey df (t-distribution) should give wider CIs than normal.""" + from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + unit_ids = data["unit"].unique() + n_units = len(unit_ids) + unit_map = {uid: i for i, uid in enumerate(unit_ids)} + idx = data["unit"].map(unit_map).values + + data["weight"] = (1.0 + 0.3 * (np.arange(n_units) % 3))[idx] + # 2 strata, 4 PSUs total -> df = 4 - 2 = 2 (very low df) + data["stratum"] = (np.arange(n_units) // 50)[idx] + data["psu"] = (np.arange(n_units) // 25)[idx] + + sd = SurveyDesign(weights="weight", strata="stratum", psu="psu") + cs_result = CallawaySantAnna().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + aggregate="event_study", + ) + + honest = HonestDiD(method="smoothness", M=0.0) + h_result = honest.fit(cs_result) + + # With df=2, t critical value (~4.3) >> z critical value (1.96) + # So CI width should be wider than 2*1.96*SE + ci_width = h_result.ci_ub - h_result.ci_lb + # Lower bound: normal-based CI width + normal_width = 2 * 1.96 * h_result.original_se + assert ci_width > normal_width + + def test_no_survey_gives_none_df(self): + """Without survey, df_survey should be None.""" + from diff_diff import CallawaySantAnna, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + cs_result = CallawaySantAnna().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + ) + + honest = HonestDiD(method="smoothness", M=0.0) + h_result = honest.fit(cs_result) + + assert 
h_result.df_survey is None + assert h_result.survey_metadata is None + + # ============================================================================= # Tests for Visualization (without matplotlib) # ============================================================================= diff --git a/tests/test_staggered_rc.py b/tests/test_staggered_rc.py new file mode 100644 index 00000000..777eaa45 --- /dev/null +++ b/tests/test_staggered_rc.py @@ -0,0 +1,351 @@ +"""Tests for Phase 7b: CallawaySantAnna repeated cross-section support. + +Covers: panel=False mode with reg/ipw/dr, covariates, survey weights, +aggregation, bootstrap, control group options, base period options, +and edge cases. +""" + +import numpy as np +import pytest + +from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def rc_data(): + """Basic repeated cross-section data.""" + return generate_staggered_data(n_units=200, n_periods=6, panel=False, seed=42) + + +@pytest.fixture(scope="module") +def rc_data_with_covariates(): + """RCS data with a covariate.""" + data = generate_staggered_data(n_units=200, n_periods=6, panel=False, seed=42) + rng = np.random.default_rng(42) + data["x1"] = rng.normal(0, 1, len(data)) + return data + + +@pytest.fixture(scope="module") +def rc_data_with_survey(): + """RCS data with survey weights.""" + data = generate_staggered_data(n_units=200, n_periods=6, panel=False, seed=42) + rng = np.random.default_rng(42) + data["x1"] = rng.normal(0, 1, len(data)) + data["weight"] = rng.uniform(0.5, 2.0, len(data)) + return data + + +# ============================================================================= +# Basic Fit +# ============================================================================= + + +class TestBasicFit: + """Basic repeated 
cross-section fit tests.""" + + def test_basic_reg(self, rc_data): + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + @pytest.mark.parametrize("method", ["reg", "ipw", "dr"]) + def test_all_methods(self, rc_data, method): + result = CallawaySantAnna(estimation_method=method, panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + def test_panel_param_in_get_params(self): + cs = CallawaySantAnna(panel=False) + params = cs.get_params() + assert params["panel"] is False + + +# ============================================================================= +# Methods Agree Without Covariates +# ============================================================================= + + +class TestMethodsAgreeNoCovariates: + """Without covariates, reg/ipw/dr should give identical ATTs in RCS.""" + + def test_no_covariate_methods_agree(self, rc_data): + results = {} + for method in ["reg", "ipw", "dr"]: + r = CallawaySantAnna(estimation_method=method, panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + results[method] = r.overall_att + + np.testing.assert_allclose(results["reg"], results["ipw"], atol=1e-10) + np.testing.assert_allclose(results["reg"], results["dr"], atol=1e-10) + + +# ============================================================================= +# Treatment Effect Recovery +# ============================================================================= + + +class TestTreatmentEffectRecovery: + """Known DGP should recover approximately correct treatment effect.""" + + def test_positive_effect(self, rc_data): + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + 
"outcome", + "unit", + "period", + "first_treat", + ) + # DGP has treatment_effect=2.0 by default + assert result.overall_att > 0 + assert abs(result.overall_att - 2.0) < 2.0 # within 2 SE roughly + + +# ============================================================================= +# Aggregation +# ============================================================================= + + +class TestAggregation: + """Aggregation types work with RCS.""" + + @pytest.mark.parametrize("agg", ["simple", "event_study", "group"]) + def test_aggregation(self, rc_data, agg): + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate=agg, + ) + assert np.isfinite(result.overall_att) + if agg == "event_study": + assert result.event_study_effects is not None + for e, info in result.event_study_effects.items(): + if info["n_groups"] > 0: + assert np.isfinite(info["effect"]) + if agg == "group": + assert result.group_effects is not None + + +# ============================================================================= +# Covariates +# ============================================================================= + + +class TestCovariates: + """Covariate-adjusted estimation in RCS.""" + + @pytest.mark.parametrize("method", ["reg", "ipw", "dr"]) + def test_with_covariates(self, rc_data_with_covariates, method): + result = CallawaySantAnna(estimation_method=method, panel=False).fit( + rc_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + +# ============================================================================= +# Survey Weights +# ============================================================================= + + +class TestSurveyWeights: + """Survey weights work with RCS (per-observation).""" + + def test_survey_weights_pweight(self, 
rc_data_with_survey): + sd = SurveyDesign(weights="weight") + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data_with_survey, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_survey_covariates_dr(self, rc_data_with_survey): + """Combined: survey + covariates + DR + RCS.""" + sd = SurveyDesign(weights="weight") + result = CallawaySantAnna(estimation_method="dr", panel=False).fit( + rc_data_with_survey, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + +# ============================================================================= +# Control Group Options +# ============================================================================= + + +class TestControlGroup: + """Control group options work with RCS.""" + + def test_not_yet_treated(self, rc_data): + result = CallawaySantAnna( + estimation_method="reg", panel=False, control_group="not_yet_treated" + ).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + + def test_never_treated(self, rc_data): + result = CallawaySantAnna( + estimation_method="reg", panel=False, control_group="never_treated" + ).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + + +# ============================================================================= +# Base Period Options +# ============================================================================= + + +class TestBasePeriod: + """Base period options work with RCS.""" + + def test_universal(self, rc_data): + result = CallawaySantAnna( + estimation_method="reg", panel=False, base_period="universal" + ).fit( + 
rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + + def test_varying(self, rc_data): + result = CallawaySantAnna(estimation_method="reg", panel=False, base_period="varying").fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + + +# ============================================================================= +# Bootstrap +# ============================================================================= + + +class TestBootstrap: + """Bootstrap works with RCS.""" + + def test_bootstrap_reg(self, rc_data): + result = CallawaySantAnna( + estimation_method="reg", panel=False, n_bootstrap=49, seed=42 + ).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert result.bootstrap_results is not None + assert np.isfinite(result.overall_att) + + +# ============================================================================= +# Data Generator +# ============================================================================= + + +class TestDataGenerator: + """Test the RCS data generator.""" + + def test_rc_data_structure(self): + data = generate_staggered_data(n_units=100, n_periods=5, panel=False, seed=99) + # Each observation should have a unique unit ID + assert data["unit"].nunique() == len(data) + # Should have n_units * n_periods rows + assert len(data) == 100 * 5 + # Each period should have n_units observations + assert all(data.groupby("period")["unit"].count() == 100) + + def test_panel_data_unchanged(self): + """panel=True (default) should produce panel data.""" + data = generate_staggered_data(n_units=50, n_periods=4, panel=True, seed=42) + # Units should repeat across periods + assert data["unit"].nunique() < len(data) + assert data["unit"].nunique() == 50 + + +# ============================================================================= +# Edge Cases +# 
============================================================================= + + +class TestEdgeCases: + """Edge cases for RCS.""" + + def test_empty_cell_nan(self): + """(g,t) cell with no observations should be NaN, not crash.""" + data = generate_staggered_data(n_units=50, n_periods=4, panel=False, seed=42) + # This should handle cells with few/no observations gracefully + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + ) + # Should produce at least some finite effects + finite_effects = [ + v["effect"] for v in result.group_time_effects.values() if np.isfinite(v["effect"]) + ] + assert len(finite_effects) > 0 diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index f4781b69..b8ca4dc5 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -926,35 +926,39 @@ def test_bootstrap_survey_supported(self, staggered_survey_data, survey_design_w assert np.isfinite(result.overall_att) assert np.isfinite(result.overall_se) - def test_ipw_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): - """IPW + covariates + survey should raise NotImplementedError.""" + def test_ipw_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): + """IPW + covariates + survey works (Phase 7a: nuisance IF corrections).""" data = staggered_survey_data.copy() data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) - with pytest.raises(NotImplementedError, match="covariates"): - CallawaySantAnna(estimation_method="ipw").fit( - data, - "outcome", - "unit", - "period", - "first_treat", - covariates=["x1"], - survey_design=survey_design_weights_only, - ) + result = CallawaySantAnna(estimation_method="ipw").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert 
np.isfinite(result.overall_se) + assert result.overall_se > 0 - def test_dr_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): - """DR + covariates + survey should raise NotImplementedError.""" + def test_dr_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): + """DR + covariates + survey works (Phase 7a: nuisance IF corrections).""" data = staggered_survey_data.copy() data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) - with pytest.raises(NotImplementedError, match="covariates"): - CallawaySantAnna(estimation_method="dr").fit( - data, - "outcome", - "unit", - "period", - "first_treat", - covariates=["x1"], - survey_design=survey_design_weights_only, - ) + result = CallawaySantAnna(estimation_method="dr").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 def test_reg_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): """Regression + covariates + survey should work (has nuisance IF correction).""" @@ -1517,13 +1521,22 @@ class TestCallawaySantAnnaFullDesignBootstrap: def test_bootstrap_full_design_cs(self, staggered_survey_data): """Bootstrap + full survey (strata+PSU+FPC) works for CS.""" sd = SurveyDesign( - weights="weight", strata="stratum", psu="psu", fpc="fpc", + weights="weight", + strata="stratum", + psu="psu", + fpc="fpc", ) result = CallawaySantAnna( - estimation_method="reg", n_bootstrap=30, seed=42, + estimation_method="reg", + n_bootstrap=30, + seed=42, ).fit( - staggered_survey_data, "outcome", "unit", "period", - "first_treat", survey_design=sd, + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, ) assert np.isfinite(result.overall_att) assert np.isfinite(result.overall_se) diff --git 
a/tests/test_survey_phase7a.py b/tests/test_survey_phase7a.py new file mode 100644 index 00000000..3c71efc8 --- /dev/null +++ b/tests/test_survey_phase7a.py @@ -0,0 +1,389 @@ +"""Tests for Phase 7a: CallawaySantAnna IPW/DR + covariates + survey support. + +Covers: DR nuisance IF corrections (PS + OR), IPW unblocking, +scale invariance, uniform-weight equivalence, aggregation, bootstrap, +and edge cases. +""" + +import numpy as np +import pytest + +from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def staggered_data_with_covariates(): + """Staggered panel with a covariate and survey columns.""" + data = generate_staggered_data(n_units=200, n_periods=6, seed=42) + rng = np.random.default_rng(42) + + unit_ids = data["unit"].unique() + n_units = len(unit_ids) + unit_map = {uid: i for i, uid in enumerate(unit_ids)} + idx = data["unit"].map(unit_map).values + + # Covariate: unit-level (constant within unit, varies across units) + unit_x = rng.normal(0, 1, n_units) + data["x1"] = unit_x[idx] + + # Survey design columns (unit-level) + data["weight"] = (1.0 + 0.5 * (np.arange(n_units) % 5))[idx] + data["stratum"] = (np.arange(n_units) // 40)[idx] + data["psu"] = (np.arange(n_units) // 10)[idx] + + return data + + +@pytest.fixture(scope="module") +def survey_weights_only(): + return SurveyDesign(weights="weight") + + +@pytest.fixture(scope="module") +def survey_full_design(): + return SurveyDesign(weights="weight", strata="stratum", psu="psu") + + +# ============================================================================= +# Smoke Tests: IPW and DR with covariates + survey produce finite results +# ============================================================================= + + +class TestSmokeIPWDRSurvey: + """Basic smoke tests 
that IPW/DR + covariates + survey runs and returns + finite results.""" + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_finite_results(self, staggered_data_with_covariates, survey_weights_only, method): + result = CallawaySantAnna(estimation_method=method).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_event_study(self, staggered_data_with_covariates, survey_weights_only, method): + result = CallawaySantAnna(estimation_method=method).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + aggregate="event_study", + ) + assert result.event_study_effects is not None + for e, info in result.event_study_effects.items(): + if info["n_groups"] > 0: + assert np.isfinite(info["effect"]) + assert np.isfinite(info["se"]) + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_full_design(self, staggered_data_with_covariates, survey_full_design, method): + """Strata/PSU design with covariates + IPW/DR.""" + result = CallawaySantAnna(estimation_method=method).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_full_design, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + +# ============================================================================= +# Scale Invariance: doubling all weights doesn't change ATT +# ============================================================================= + + +class TestScaleInvariance: + """Multiplying all survey weights by a constant should not change ATT.""" + + 
@pytest.mark.parametrize("method", ["ipw", "dr", "reg"]) + def test_double_weights_same_att(self, staggered_data_with_covariates, method): + data = staggered_data_with_covariates.copy() + + # Fit with original weights + sd1 = SurveyDesign(weights="weight") + r1 = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd1, + ) + + # Fit with doubled weights + data["weight2"] = data["weight"] * 2 + sd2 = SurveyDesign(weights="weight2") + r2 = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd2, + ) + + np.testing.assert_allclose(r1.overall_att, r2.overall_att, atol=1e-10) + + +# ============================================================================= +# Uniform Weights Match Unweighted +# ============================================================================= + + +class TestUniformWeightsMatchUnweighted: + """Survey weights = 1.0 for all should match the no-survey path.""" + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_uniform_weights(self, method): + data = generate_staggered_data(n_units=100, n_periods=5, seed=123) + rng = np.random.default_rng(123) + data["x1"] = rng.normal(0, 1, len(data)) + data["weight_ones"] = 1.0 + + # No survey + r_no_survey = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + + # Survey with uniform weights + sd = SurveyDesign(weights="weight_ones") + r_survey = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd, + ) + + # ATTs should be very close (not exact due to weight normalization) + np.testing.assert_allclose(r_no_survey.overall_att, r_survey.overall_att, rtol=1e-6) + + +# 
============================================================================= +# IF Correction Non-Zero: corrected IF differs from plug-in +# ============================================================================= + + +class TestIFCorrectionNonZero: + """The DR nuisance IF corrections should make a difference (non-trivial).""" + + def test_dr_se_differs_from_reg(self, staggered_data_with_covariates, survey_weights_only): + """DR and reg methods should give different SEs (DR has PS correction).""" + r_reg = CallawaySantAnna(estimation_method="reg").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + r_dr = CallawaySantAnna(estimation_method="dr").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + # SEs should differ because DR has additional IF corrections + assert r_reg.overall_se != r_dr.overall_se + + +# ============================================================================= +# All 3 Methods Agree Without Covariates +# ============================================================================= + + +class TestMethodsAgreeNoCovariates: + """Without covariates, reg/ipw/dr should give identical ATTs under survey.""" + + def test_no_covariate_methods_agree(self, staggered_data_with_covariates, survey_weights_only): + results = {} + for method in ["reg", "ipw", "dr"]: + r = CallawaySantAnna(estimation_method=method).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_weights_only, + ) + results[method] = r.overall_att + + np.testing.assert_allclose(results["reg"], results["ipw"], atol=1e-10) + np.testing.assert_allclose(results["reg"], results["dr"], atol=1e-10) + + +# ============================================================================= +# Aggregation +# 
============================================================================= + + +class TestAggregation: + """Aggregation types work with DR + covariates + survey.""" + + @pytest.mark.parametrize("agg", ["simple", "event_study", "group"]) + def test_aggregation(self, staggered_data_with_covariates, survey_weights_only, agg): + result = CallawaySantAnna(estimation_method="dr").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + aggregate=agg, + ) + assert np.isfinite(result.overall_att) + if agg == "event_study": + assert result.event_study_effects is not None + if agg == "group": + assert result.group_effects is not None + + +# ============================================================================= +# Bootstrap +# ============================================================================= + + +class TestBootstrap: + """Bootstrap with IPW/DR + covariates + survey.""" + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_bootstrap_runs(self, staggered_data_with_covariates, survey_weights_only, method): + result = CallawaySantAnna(estimation_method=method, n_bootstrap=49, seed=42).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + assert np.isfinite(result.overall_att) + assert result.bootstrap_results is not None + + +# ============================================================================= +# Edge Cases +# ============================================================================= + + +class TestEdgeCases: + """Edge cases for IPW/DR + survey + covariates.""" + + def test_single_covariate(self): + """Works with a single binary covariate.""" + data = generate_staggered_data(n_units=100, seed=99) + rng = np.random.default_rng(99) + data["binary_x"] = rng.choice([0, 1], len(data)) + data["weight"] = rng.uniform(0.5, 2.0, len(data)).round(1) + 
# Make weights constant within unit + unit_w = data.groupby("unit")["weight"].first() + data["weight"] = data["unit"].map(unit_w) + + sd = SurveyDesign(weights="weight") + for method in ["ipw", "dr"]: + result = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["binary_x"], + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + + def test_not_yet_treated_control(self, staggered_data_with_covariates, survey_weights_only): + """control_group='not_yet_treated' with DR + covariates + survey.""" + result = CallawaySantAnna(estimation_method="dr", control_group="not_yet_treated").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + assert np.isfinite(result.overall_att) + + def test_universal_base_period(self, staggered_data_with_covariates, survey_weights_only): + """base_period='universal' with DR + covariates + survey.""" + result = CallawaySantAnna(estimation_method="dr", base_period="universal").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + assert np.isfinite(result.overall_att) + + +# ============================================================================= +# Non-Survey DR IF Corrections +# ============================================================================= + + +class TestNonSurveyDRIFCorrections: + """Verify that non-survey DR path also has IF corrections.""" + + def test_dr_se_differs_from_reg_no_survey(self): + """DR and reg should give different SEs without survey (IF corrections).""" + data = generate_staggered_data(n_units=150, n_periods=6, seed=42) + rng = np.random.default_rng(42) + data["x1"] = rng.normal(0, 1, len(data)) + + r_reg = CallawaySantAnna(estimation_method="reg").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + 
covariates=["x1"], + ) + r_dr = CallawaySantAnna(estimation_method="dr").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + # SEs should differ (DR has nuisance IF corrections) + assert r_reg.overall_se != r_dr.overall_se + # But ATTs should be similar (both consistent under correct specification) + assert abs(r_reg.overall_att - r_dr.overall_att) < 1.0 From 4bf566d1f94ebca6500b5ffe806c87302aa9bb9a Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 28 Mar 2026 20:14:22 -0400 Subject: [PATCH 02/19] Address CI review: RCS IF corrections, aggregation weights, replicate VCV, panel on results Fix 5 findings from PR #240 CI review: - Add cross-sectional nuisance IF corrections (PS + OR) to _ipw_estimation_rc and _doubly_robust_rc, matching panel path methodology - Use fixed full-sample cohort masses for unweighted RCS aggregation weights (consistency with WIF group-share denominator) - Guard replicate-weight designs from full event-study VCV (diagonal fallback) - Add panel field to CallawaySantAnnaResults, fix summary labels for RCS - Add panel to class docstring, replicate VCV test, RCS IF correction test Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/staggered.py | 106 +++++++++++++++++++++++++++++ diff_diff/staggered_aggregation.py | 52 ++++++++++++-- diff_diff/staggered_results.py | 5 +- tests/test_honest_did.py | 37 ++++++++++ tests/test_staggered_rc.py | 54 +++++++++++++++ 5 files changed, 245 insertions(+), 9 deletions(-) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index e6836ba0..9d714520 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -181,6 +181,13 @@ class CallawaySantAnna( Trimming bound for propensity scores. Scores are clipped to ``[pscore_trim, 1 - pscore_trim]`` before weight computation in IPW and DR estimation. Must be in ``(0, 0.5)``. + panel : bool, default=True + Whether the data is a balanced/unbalanced panel (units observed + across multiple time periods). 
Set to ``False`` for repeated + cross-sections where each observation has a unique unit ID and + units do not repeat across periods. Uses cross-sectional DRDID + (Sant'Anna & Zhao 2020, Section 4) with per-observation influence + functions. Attributes ---------- @@ -1783,6 +1790,7 @@ def fit( pscore_trim=self.pscore_trim, survey_metadata=survey_metadata, event_study_vcov=event_study_vcov, + panel=self.panel, ) self.is_fitted_ = True @@ -2972,6 +2980,40 @@ def _ipw_estimation_rc( inf_control = np.concatenate([inf_ct, inf_cs]) inf_all = np.concatenate([inf_treated, inf_control]) + # PS IF correction for cross-sectional IPW + X_all_int = np.column_stack([np.ones(len(D_all)), X_all]) + pscore_all = pscore # already computed and clipped + + W_ps = pscore_all * (1 - pscore_all) + if sw_all is not None: + W_ps = W_ps * sw_all + H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) + H_ps_inv = _safe_inv(H_ps) + + score_ps = (D_all - pscore_all)[:, None] * X_all_int + if sw_all is not None: + score_ps = score_ps * sw_all[:, None] + asy_lin_rep_ps = score_ps @ H_ps_inv # (n_all, p+1) + + # M2: gradient of IPW ATT w.r.t. 
PS parameters + # Control IPW residuals from both periods + ipw_resid_ct = w_ct_norm * (y_ct - mu_ct_ipw) + ipw_resid_cs = w_cs_norm * (y_cs - mu_cs_ipw) + # Zero for treated observations + M2_rc = np.zeros(X_all_int.shape[1]) + # Control-t contribution + M2_rc += np.mean( + ipw_resid_ct[:, None] * X_all_int[n_gt + n_gs : n_gt + n_gs + n_ct], + axis=0, + ) + # Control-s contribution (opposite sign -- base period) + M2_rc -= np.mean( + ipw_resid_cs[:, None] * X_all_int[n_gt + n_gs + n_ct :], + axis=0, + ) + + inf_all = inf_all + asy_lin_rep_ps @ M2_rc + se = float(np.sqrt(np.sum(inf_all**2))) idx_all = None @@ -3121,6 +3163,70 @@ def _doubly_robust_rc( inf_control = np.concatenate([inf_ct, inf_cs]) inf_all = np.concatenate([inf_treated, inf_control]) + # --- PS IF correction --- + X_all_int = np.column_stack([np.ones(len(D_all)), X_all]) + pscore_all = pscore + + W_ps = pscore_all * (1 - pscore_all) + if sw_all is not None: + W_ps = W_ps * sw_all + H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) + H_ps_inv = _safe_inv(H_ps) + + score_ps = (D_all - pscore_all)[:, None] * X_all_int + if sw_all is not None: + score_ps = score_ps * sw_all[:, None] + asy_lin_rep_ps = score_ps @ H_ps_inv + + # M2_dr: uses DR residuals (m-y) instead of raw y + dr_resid_ct = m_ct - y_ct # control period-t DR residuals + dr_resid_cs = m_cs - y_cs # control period-s DR residuals + normalizer = np.sum(sw_gt) if sw_gt is not None else n_gt + M2_dr = np.zeros(X_all_int.shape[1]) + # Control-t: (w_ct/normalizer) * (m_ct - y_ct) * X + ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct) + M2_dr += np.mean( + ((w_ct / normalizer) * dr_resid_ct)[:, None] * X_all_int[ct_slice], + axis=0, + ) + # Control-s: -(w_cs/normalizer) * (m_cs - y_cs) * X (opposite sign) + cs_slice = slice(n_gt + n_gs + n_ct, None) + M2_dr -= np.mean( + ((w_cs / normalizer) * dr_resid_cs)[:, None] * X_all_int[cs_slice], + axis=0, + ) + + inf_all = inf_all + asy_lin_rep_ps @ M2_dr + + # --- OR IF correction -- period t model --- + 
W_t = sw_ct if sw_ct is not None else np.ones(n_ct) + bread_t = _safe_inv(X_ct_int.T @ (W_t[:, None] * X_ct_int)) + + # M1_t: dATT/dbeta_t (from treated-t prediction and control-t augmentation) + sw_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt) + M1_t = ( + -np.sum(sw_gt_vals[:, None] * X_gt_int, axis=0) + + np.sum(w_ct[:, None] * X_ct_int, axis=0) + ) / normalizer + + asy_lin_rep_or_t = (W_t * (y_ct - m_ct))[:, None] * X_ct_int @ bread_t + # Apply only to control-t portion of inf_all + inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_or_t @ M1_t + + # --- OR IF correction -- period s model --- + W_s = sw_cs if sw_cs is not None else np.ones(n_cs) + bread_s = _safe_inv(X_cs_int.T @ (W_s[:, None] * X_cs_int)) + + sw_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs) + M1_s = ( + np.sum(sw_gs_vals[:, None] * X_gs_int, axis=0) + - np.sum(w_cs[:, None] * X_cs_int, axis=0) + ) / normalizer + + asy_lin_rep_or_s = (W_s * (y_cs - m_cs))[:, None] * X_cs_int @ bread_s + # Apply only to control-s portion of inf_all + inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_or_s @ M1_s + se = float(np.sqrt(np.sum(inf_all**2))) idx_all = None diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 23632764..b8f1d7b1 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -73,15 +73,31 @@ def _aggregate_simple( if g > 0: # exclude never-treated (0) survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g])) + # For unweighted RCS: use fixed full-sample cohort counts so that + # aggregation weights match the WIF group-share denominator. 
+ rcs_cohort_counts = None + if ( + precomputed is not None + and not precomputed.get("is_panel", True) + and survey_cohort_weights is None + ): + unit_cohorts = precomputed["unit_cohorts"] + rcs_cohort_counts = {} + for g in np.unique(unit_cohorts): + if g > 0: + rcs_cohort_counts[g] = int(np.sum(unit_cohorts == g)) + for (g, t), data in group_time_effects.items(): # Only include post-treatment effects (t >= g - anticipation) # Pre-treatment effects are for parallel trends, not overall ATT if t < g - self.anticipation: continue effects.append(data["effect"]) - # Use fixed cohort-level survey weight sum for aggregation + # Use fixed cohort-level weights for aggregation if survey_cohort_weights is not None and g in survey_cohort_weights: weights_list.append(survey_cohort_weights[g]) + elif rcs_cohort_counts is not None and g in rcs_cohort_counts: + weights_list.append(rcs_cohort_counts[g]) else: weights_list.append(data["n_treated"]) gt_pairs.append((g, t)) @@ -571,15 +587,29 @@ def _aggregate_event_study( if g > 0: survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g])) + # For unweighted RCS: fixed full-sample cohort counts (matching WIF) + rcs_cohort_counts = None + if ( + precomputed is not None + and not precomputed.get("is_panel", True) + and survey_cohort_weights is None + ): + unit_cohorts_es = precomputed["unit_cohorts"] + rcs_cohort_counts = {} + for g in np.unique(unit_cohorts_es): + if g > 0: + rcs_cohort_counts[g] = int(np.sum(unit_cohorts_es == g)) + for (g, t), data in group_time_effects.items(): e = t - g # Relative time if e not in effects_by_e: effects_by_e[e] = [] - w = ( - survey_cohort_weights[g] - if survey_cohort_weights is not None and g in survey_cohort_weights - else data["n_treated"] - ) + if survey_cohort_weights is not None and g in survey_cohort_weights: + w = survey_cohort_weights[g] + elif rcs_cohort_counts is not None and g in rcs_cohort_counts: + w = rcs_cohort_counts[g] + else: + w = data["n_treated"] 
effects_by_e[e].append( ( (g, t), # Keep track of the (g,t) pair @@ -733,8 +763,16 @@ def _aggregate_event_study( meat, _, _ = _compute_stratified_psu_meat(Psi, resolved_survey) event_study_vcov = meat + elif ( + resolved_survey is not None + and hasattr(resolved_survey, "uses_replicate_variance") + and resolved_survey.uses_replicate_variance + ): + # Replicate-weight: fall back to None (diagonal in HonestDiD) + # until multivariate replicate VCV is implemented + event_study_vcov = None else: - # Simple sum-of-outer-products (no survey or replicate-only) + # No survey: simple sum-of-outer-products event_study_vcov = Psi.T @ Psi except (ValueError, np.linalg.LinAlgError): pass # Fall back to diagonal (None) diff --git a/diff_diff/staggered_results.py b/diff_diff/staggered_results.py index d4f23429..b21af0df 100644 --- a/diff_diff/staggered_results.py +++ b/diff_diff/staggered_results.py @@ -111,6 +111,7 @@ class CallawaySantAnnaResults: alpha: float = 0.05 control_group: str = "never_treated" base_period: str = "varying" + panel: bool = True event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None) group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None) influence_functions: Optional["np.ndarray"] = field(default=None, repr=False) @@ -155,8 +156,8 @@ def summary(self, alpha: Optional[float] = None) -> str: "=" * 85, "", f"{'Total observations:':<30} {self.n_obs:>10}", - f"{'Treated units:':<30} {self.n_treated_units:>10}", - f"{'Never-treated units:':<30} {self.n_control_units:>10}", + f"{'Treated ' + ('obs:' if not self.panel else 'units:'):<30} {self.n_treated_units:>10}", + f"{'Control ' + ('obs:' if not self.panel else 'units:'):<30} {self.n_control_units:>10}", f"{'Treatment cohorts:':<30} {len(self.groups):>10}", f"{'Time periods:':<30} {len(self.time_periods):>10}", f"{'Control group:':<30} {self.control_group:>10}", diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index c757b545..1a36ce09 100644 --- 
a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -1243,6 +1243,43 @@ def test_no_survey_gives_none_df(self): assert h_result.df_survey is None assert h_result.survey_metadata is None + def test_replicate_weight_uses_diagonal_fallback(self): + """Replicate-weight designs should NOT produce full event_study_vcov.""" + from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + unit_ids = data["unit"].unique() + n_units = len(unit_ids) + unit_map = {uid: i for i, uid in enumerate(unit_ids)} + idx = data["unit"].map(unit_map).values + + # Create replicate weights (4 replicates) + rng = np.random.default_rng(42) + data["weight"] = (1.0 + 0.3 * (np.arange(n_units) % 3))[idx] + for k in range(4): + data[f"repwt_{k}"] = data["weight"] * rng.uniform(0.8, 1.2, len(data)) + # Make constant within unit + unit_rw = data.groupby("unit")[f"repwt_{k}"].first() + data[f"repwt_{k}"] = data["unit"].map(unit_rw) + + sd = SurveyDesign( + weights="weight", + replicate_weights=[f"repwt_{k}" for k in range(4)], + replicate_method="JK1", + ) + cs_result = CallawaySantAnna().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + aggregate="event_study", + ) + + # event_study_vcov should be None (diagonal fallback for replicate designs) + assert cs_result.event_study_vcov is None + # ============================================================================= # Tests for Visualization (without matplotlib) diff --git a/tests/test_staggered_rc.py b/tests/test_staggered_rc.py index 777eaa45..34e0d4fd 100644 --- a/tests/test_staggered_rc.py +++ b/tests/test_staggered_rc.py @@ -349,3 +349,57 @@ def test_empty_cell_nan(self): v["effect"] for v in result.group_time_effects.values() if np.isfinite(v["effect"]) ] assert len(finite_effects) > 0 + + +# ============================================================================= +# Methodology: IF corrections change SE 
+# ============================================================================= + + +class TestIFCorrections: + """Verify RCS DR/IPW IF corrections are non-trivial.""" + + def test_dr_se_differs_from_reg_rc(self, rc_data_with_covariates): + """DR and reg should give different SEs in RCS (DR has IF corrections).""" + r_reg = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + r_dr = CallawaySantAnna(estimation_method="dr", panel=False).fit( + rc_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + # SEs should differ (DR has nuisance IF corrections) + assert r_reg.overall_se != r_dr.overall_se + + def test_panel_field_on_results(self, rc_data): + """panel=False should be reflected on CallawaySantAnnaResults.""" + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert result.panel is False + + def test_summary_labels_rcs(self, rc_data): + """Summary should use 'obs' labels for RCS, not 'units'.""" + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + summary = result.summary() + assert "obs:" in summary + assert "units:" not in summary.split("\n")[3] # Treated line From 6080f927c1c76a760b7825b59d5f46032da64b1c Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 28 Mar 2026 20:50:47 -0400 Subject: [PATCH 03/19] Fix DR RC normalizer mismatch, holistic RCS cohort-mass weighting, unequal-count tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use separate normalizer_t/normalizer_s in _doubly_robust_rc() IF corrections (post vs base period treated denominators must match the DR estimator) - Precompute rcs_cohort_masses in _precompute_structures_rc() and return cohort mass as n_treated from 
_compute_att_gt_rc() — fixes all downstream consumers (aggregation,
bootstrap, balance_e) at the source instead of per-consumer patches
- Remove now-unnecessary rcs_cohort_counts blocks from aggregation
- Add unequal cohort count test fixture and regression tests

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 diff_diff/staggered.py             | 32 ++++++++-----
 diff_diff/staggered_aggregation.py | 45 ++++-------------
 tests/test_staggered_rc.py         | 77 ++++++++++++++++++++++++++++++
 3 files changed, 106 insertions(+), 48 deletions(-)

diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py
index 9d714520..7986488c 100644
--- a/diff_diff/staggered.py
+++ b/diff_diff/staggered.py
@@ -2482,6 +2482,12 @@ def _precompute_structures_rc(
         # For RCS, the resolved survey is already per-observation
         resolved_survey_rc = resolved_survey
 
+        # Fixed cohort masses: total observations per cohort across all periods.
+        # Used as aggregation weights so that n_treated is consistent with WIF.
+        rcs_cohort_masses = {}
+        for g in treatment_groups:
+            rcs_cohort_masses[g] = int(np.sum(unit_cohorts == g))
+
         return {
             "all_units": all_units,
             "unit_to_idx": None,  # RCS: obs indices are positions
@@ -2504,6 +2510,7 @@
                 if resolved_survey_rc is not None and hasattr(resolved_survey_rc, "df_survey")
                 else None
             ),
+            "rcs_cohort_masses": rcs_cohort_masses,
         }
 
     def _compute_att_gt_rc(
@@ -2701,7 +2708,9 @@
         }
         sw_sum = float(np.sum(sw_gt)) if sw_gt is not None else None
 
-        return att, se, n_gt, n_ct, inf_func_info, sw_sum
+        # Use fixed cohort mass as n_treated for aggregation weight consistency
+        cohort_mass = precomputed.get("rcs_cohort_masses", {}).get(g, n_gt)
+        return att, se, cohort_mass, n_ct, inf_func_info, sw_sum
 
     def _rc_2x2_did(
         self,
@@ -3179,20 +3188,21 @@
         asy_lin_rep_ps = score_ps @ H_ps_inv
 
         # M2_dr: uses DR residuals (m-y) instead of raw y
-        dr_resid_ct = m_ct - y_ct  # control period-t DR residuals
-        dr_resid_cs = m_cs - y_cs  # control period-s DR residuals
-        normalizer = np.sum(sw_gt) if sw_gt is not None else n_gt
+        # Use separate normalizers for post vs base period (RCS has different
+        # treated counts per period — using a single normalizer mis-scales)
+        dr_resid_ct = m_ct - y_ct
+        dr_resid_cs = m_cs - y_cs
+        normalizer_t = np.sum(sw_gt) if sw_gt is not None else n_gt
+        normalizer_s = np.sum(sw_gs) if sw_gs is not None else n_gs
         M2_dr = np.zeros(X_all_int.shape[1])
-        # Control-t: (w_ct/normalizer) * (m_ct - y_ct) * X
         ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct)
         M2_dr += np.mean(
-            ((w_ct / normalizer) * dr_resid_ct)[:, None] * X_all_int[ct_slice],
+            ((w_ct / normalizer_t) * dr_resid_ct)[:, None] * X_all_int[ct_slice],
             axis=0,
         )
-        # Control-s: -(w_cs/normalizer) * (m_cs - y_cs) * X (opposite sign)
         cs_slice = slice(n_gt + n_gs + n_ct, None)
         M2_dr -= np.mean(
-            ((w_cs / normalizer) * dr_resid_cs)[:, None] * X_all_int[cs_slice],
+            ((w_cs / normalizer_s) * dr_resid_cs)[:, None] * X_all_int[cs_slice],
             axis=0,
         )
@@ -3202,15 +3212,13 @@
         W_t = sw_ct if sw_ct is not None else np.ones(n_ct)
         bread_t = _safe_inv(X_ct_int.T @ (W_t[:, None] * X_ct_int))
 
-        # M1_t: dATT/dbeta_t (from treated-t prediction and control-t augmentation)
         sw_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt)
         M1_t = (
             -np.sum(sw_gt_vals[:, None] * X_gt_int, axis=0)
             + np.sum(w_ct[:, None] * X_ct_int, axis=0)
-        ) / normalizer
+        ) / normalizer_t
 
         asy_lin_rep_or_t = (W_t * (y_ct - m_ct))[:, None] * X_ct_int @ bread_t
-        # Apply only to control-t portion of inf_all
         inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_or_t @ M1_t
 
         # --- OR IF correction -- period s model ---
@@ -3221,7 +3229,7 @@
         M1_s = (
             np.sum(sw_gs_vals[:, None] * X_gs_int, axis=0)
             - np.sum(w_cs[:, None] * X_cs_int, axis=0)
-        ) / normalizer
+        ) / normalizer_s
 
         asy_lin_rep_or_s = (W_s * (y_cs - m_cs))[:, None] * X_cs_int @ bread_s
         # Apply only to control-s portion of inf_all
         inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_or_s @ M1_s
diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py
index b8f1d7b1..19fef7ad 100644
--- a/diff_diff/staggered_aggregation.py
+++ b/diff_diff/staggered_aggregation.py
@@ -73,31 +73,17 @@ def _aggregate_simple(
                 if g > 0:  # exclude never-treated (0)
                     survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
 
-        # For unweighted RCS: use fixed full-sample cohort counts so that
-        # aggregation weights match the WIF group-share denominator.
-        rcs_cohort_counts = None
-        if (
-            precomputed is not None
-            and not precomputed.get("is_panel", True)
-            and survey_cohort_weights is None
-        ):
-            unit_cohorts = precomputed["unit_cohorts"]
-            rcs_cohort_counts = {}
-            for g in np.unique(unit_cohorts):
-                if g > 0:
-                    rcs_cohort_counts[g] = int(np.sum(unit_cohorts == g))
-
         for (g, t), data in group_time_effects.items():
             # Only include post-treatment effects (t >= g - anticipation)
             # Pre-treatment effects are for parallel trends, not overall ATT
             if t < g - self.anticipation:
                 continue
             effects.append(data["effect"])
-            # Use fixed cohort-level weights for aggregation
+            # Use fixed cohort-level survey weight sum for aggregation.
+            # For RCS, data["n_treated"] is already the fixed cohort mass
+            # (set at the source in _compute_att_gt_rc), so the fallback is correct.
             if survey_cohort_weights is not None and g in survey_cohort_weights:
                 weights_list.append(survey_cohort_weights[g])
-            elif rcs_cohort_counts is not None and g in rcs_cohort_counts:
-                weights_list.append(rcs_cohort_counts[g])
             else:
                 weights_list.append(data["n_treated"])
             gt_pairs.append((g, t))
@@ -587,29 +573,16 @@
                 if g > 0:
                     survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
 
-        # For unweighted RCS: fixed full-sample cohort counts (matching WIF)
-        rcs_cohort_counts = None
-        if (
-            precomputed is not None
-            and not precomputed.get("is_panel", True)
-            and survey_cohort_weights is None
-        ):
-            unit_cohorts_es = precomputed["unit_cohorts"]
-            rcs_cohort_counts = {}
-            for g in np.unique(unit_cohorts_es):
-                if g > 0:
-                    rcs_cohort_counts[g] = int(np.sum(unit_cohorts_es == g))
-
         for (g, t), data in group_time_effects.items():
             e = t - g  # Relative time
             if e not in effects_by_e:
                 effects_by_e[e] = []
-            if survey_cohort_weights is not None and g in survey_cohort_weights:
-                w = survey_cohort_weights[g]
-            elif rcs_cohort_counts is not None and g in rcs_cohort_counts:
-                w = rcs_cohort_counts[g]
-            else:
-                w = data["n_treated"]
+            # For RCS, data["n_treated"] is already the fixed cohort mass
+            w = (
+                survey_cohort_weights[g]
+                if survey_cohort_weights is not None and g in survey_cohort_weights
+                else data["n_treated"]
+            )
             effects_by_e[e].append(
                 (
                     (g, t),  # Keep track of the (g,t) pair
diff --git a/tests/test_staggered_rc.py b/tests/test_staggered_rc.py
index 34e0d4fd..519d1404 100644
--- a/tests/test_staggered_rc.py
+++ b/tests/test_staggered_rc.py
@@ -403,3 +403,80 @@ def test_summary_labels_rcs(self, rc_data):
         summary = result.summary()
         assert "obs:" in summary
         assert "units:" not in summary.split("\n")[3]  # Treated line
+
+
+# =============================================================================
+# Unequal Cohort Counts Across Periods
+# =============================================================================
+
+
+class TestUnequalCohortCounts:
+    """Tests with n_gt != n_gs — catches normalizer/weight bugs."""
+
+    @pytest.fixture
+    def unequal_rc_data(self):
+        """RCS data where cohort sizes differ across periods."""
+        rng = np.random.default_rng(77)
+        records = []
+        # 4 periods, cohort g=2 treated at period 2
+        for period in range(4):
+            # Vary n_per_period so cohort counts differ across periods
+            n_per_period = 100 + period * 30  # 100, 130, 160, 190
+            for i in range(n_per_period):
+                # ~30% treated (cohort 2), ~70% never-treated
+                ft = 2 if rng.random() < 0.3 else 0
+                treated = (ft > 0) and (period >= ft)
+                y = rng.normal(0, 1) + (2.0 if treated else 0.0)
+                records.append(
+                    {
+                        "unit": f"u{period}_{i}",
+                        "period": period,
+                        "outcome": y,
+                        "first_treat": ft,
+                    }
+                )
+        import pandas as pd
+
+        return pd.DataFrame(records)
+
+    @pytest.mark.parametrize("method", ["reg", "ipw", "dr"])
+    def test_finite_results_unequal(self, unequal_rc_data, method):
+        result = CallawaySantAnna(estimation_method=method, panel=False).fit(
+            unequal_rc_data,
+            "outcome",
+            "unit",
+            "period",
+            "first_treat",
+        )
+        assert np.isfinite(result.overall_att)
+        assert np.isfinite(result.overall_se)
+        assert result.overall_se > 0
+
+    def test_dr_covariates_unequal(self, unequal_rc_data):
+        """DR with covariates under unequal cohort counts."""
+        rng = np.random.default_rng(77)
+        unequal_rc_data["x1"] = rng.normal(0, 1, len(unequal_rc_data))
+        result = CallawaySantAnna(estimation_method="dr", panel=False).fit(
+            unequal_rc_data,
+            "outcome",
+            "unit",
+            "period",
+            "first_treat",
+            covariates=["x1"],
+        )
+        assert np.isfinite(result.overall_att)
+        assert np.isfinite(result.overall_se)
+
+    def test_bootstrap_unequal(self, unequal_rc_data):
+        """Bootstrap with unequal cohort counts."""
+        result = CallawaySantAnna(
+            estimation_method="reg", panel=False, n_bootstrap=49, seed=42
+        ).fit(
+            unequal_rc_data,
+            "outcome",
+            "unit",
+            "period",
+            "first_treat",
+        )
+        assert result.bootstrap_results is not None
+        assert np.isfinite(result.overall_att)

From b623deeed20d79ab05bdf255819745c93140371a Mon Sep 17 00:00:00 2001
From: igerber
Date: Sat, 28 Mar 2026 21:47:47 -0400
Subject: [PATCH 04/19] Rewrite RC reg/DR to match DRDID::reg_did_rc and
 DRDID::drdid_rc formulas

_outcome_regression_rc: Pool all treated obs for OR correction term
(was: separate per-period averages). Period-specific treated means for Y.
Matches Sant'Anna & Zhao (2020) Eq 2.2 / R reg_did_rc exactly.

_doubly_robust_rc: Fit 4 OLS models (control+treated, pre+post) for
locally efficient DR estimator (was: 2 control-only). Implements tau_1
(AIPW) + tau_2 (local efficiency adjustment) with full 11-component IF.
Matches Sant'Anna & Zhao (2020) Eq 3.3+3.4 / R drdid_rc exactly.

Add agg_weight field to group_time_effects for RCS aggregation weight
(cohort mass), separate from n_treated (per-cell display count).
Aggregation uses data.get("agg_weight", data["n_treated"]) for backward
compatibility with panel data.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 diff_diff/staggered.py             | 522 +++++++++++++++++++++--------
 diff_diff/staggered_aggregation.py |  13 +-
 2 files changed, 381 insertions(+), 154 deletions(-)

diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py
index 7986488c..447e26ad 100644
--- a/diff_diff/staggered.py
+++ b/diff_diff/staggered.py
@@ -1480,12 +1480,14 @@ def fit(
             ]
 
             for t in valid_periods:
-                att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = self._compute_att_gt_rc(
+                rc_result = self._compute_att_gt_rc(
                     precomputed,
                     g,
                     t,
                     covariates,
                 )
+                att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = rc_result[:6]
+                agg_w = rc_result[6] if len(rc_result) > 6 else n_treat
 
                 if att_gt is not None:
                     t_stat, p_val, ci = safe_inference(
@@ -1503,6 +1505,7 @@
                         "conf_int": ci,
                         "n_treated": n_treat,
                         "n_control": n_ctrl,
+                        "agg_weight": agg_w,
                     }
                     if sw_sum is not None:
                         gte_entry["survey_weight_sum"] = sw_sum
@@ -2708,9 +2711,10 @@
         }
         sw_sum = float(np.sum(sw_gt)) if sw_gt is not None else None
 
-        # Use fixed cohort mass as n_treated for aggregation weight consistency
+        # n_treated = per-cell treated count at period t (for display).
+        # cohort_mass = total treated across all periods (for aggregation weights).
         cohort_mass = precomputed.get("rcs_cohort_masses", {}).get(g, n_gt)
-        return att, se, cohort_mass, n_ct, inf_func_info, sw_sum
+        return att, se, n_gt, n_ct, inf_func_info, sw_sum, cohort_mass
 
     def _rc_2x2_did(
         self,
@@ -2810,10 +2814,11 @@
         """
         Cross-sectional outcome regression for ATT(g,t).
 
-        Two outcome models: E[Y|X] on controls at t, E[Y|X] on controls at s.
-        Predict counterfactual for treated at each period.
-        ATT = mean(Y_t - m_t(X_t)) for treated at t
-              - mean(Y_s - m_s(X_s)) for treated at s
+        Matches R DRDID::reg_did_rc (Sant'Anna & Zhao 2020, Eq 2.2).
+
+        Two OLS models fit on controls (period t and base period s).
+        Predictions made for ALL treated (both periods).
+        OR correction pools ALL treated observations across both periods.
 
        Returns (att, se, inf_func_concat, idx_concat).
        """
@@ -2821,8 +2826,9 @@
        n_gs = len(y_gs)
        n_ct = len(y_ct)
        n_cs = len(y_cs)
+        n_all = n_gt + n_gs + n_ct + n_cs
 
-        # Fit outcome model on controls at period t
+        # --- Fit 2 OLS on control groups (period t and s separately) ---
        beta_t, resid_ct = _linear_regression(
            X_ct,
            y_ct,
@@ -2831,7 +2837,6 @@
        )
        beta_t = np.where(np.isfinite(beta_t), beta_t, 0.0)
 
-        # Fit outcome model on controls at base period s
        beta_s, resid_cs = _linear_regression(
            X_cs,
            y_cs,
@@ -2840,41 +2845,91 @@
        )
        beta_s = np.where(np.isfinite(beta_s), beta_s, 0.0)
 
-        # Predict counterfactual for treated
+        # --- Predict counterfactual for ALL treated (both periods) ---
        X_gt_int = np.column_stack([np.ones(n_gt), X_gt])
        X_gs_int = np.column_stack([np.ones(n_gs), X_gs])
-        m_gt = X_gt_int @ beta_t  # counterfactual at t
-        m_gs = X_gs_int @ beta_s  # counterfactual at s
+        X_ct_int = np.column_stack([np.ones(n_ct), X_ct])
+        X_cs_int = np.column_stack([np.ones(n_cs), X_cs])
 
-        # Treated residuals
-        resid_treated_t = y_gt - m_gt
-        resid_treated_s = y_gs - m_gs
+        # mu_hat_{0,t}(X) and mu_hat_{0,s}(X) for each treated obs
+        mu_post_gt = X_gt_int @ beta_t  # treated-post predicted at post model
+        mu_pre_gt = X_gt_int @ beta_s  # treated-post predicted at pre model
+        mu_post_gs = X_gs_int @ beta_t  # treated-pre predicted at post model
+        mu_pre_gs = X_gs_int @ beta_s  # treated-pre predicted at pre model
 
+        # --- Group weights (R: w.treat.pre, w.treat.post, w.cont = w.D) ---
        if sw_gt is not None:
-            sw_gt_norm = sw_gt / np.sum(sw_gt)
-            sw_gs_norm = sw_gs / np.sum(sw_gs)
-            sw_ct_norm = sw_ct / np.sum(sw_ct)
-            sw_cs_norm = sw_cs / np.sum(sw_cs)
-
-            att_t = float(np.sum(sw_gt_norm * resid_treated_t))
-            att_s = float(np.sum(sw_gs_norm * resid_treated_s))
-            att = att_t - att_s
-
-            # Influence function
-            inf_gt = sw_gt_norm * (resid_treated_t - att_t)
-            inf_gs = -sw_gs_norm * (resid_treated_s - att_s)
-            inf_ct = -sw_ct_norm * resid_ct
-            inf_cs = sw_cs_norm * resid_cs
+            w_treat_post = sw_gt  # treated at t
+            w_treat_pre = sw_gs  # treated at s
+            w_D_gt = sw_gt  # ALL treated: t portion
+            w_D_gs = sw_gs  # ALL treated: s portion
        else:
-            att_t = float(np.mean(resid_treated_t))
-            att_s = float(np.mean(resid_treated_s))
-            att = att_t - att_s
-
-            # Influence function
-            inf_gt = (resid_treated_t - att_t) / n_gt
-            inf_gs = -(resid_treated_s - att_s) / n_gs
-            inf_ct = -resid_ct / n_ct
-            inf_cs = resid_cs / n_cs
+            w_treat_post = np.ones(n_gt)
+            w_treat_pre = np.ones(n_gs)
+            w_D_gt = np.ones(n_gt)
+            w_D_gs = np.ones(n_gs)
+
+        sum_w_treat_post = np.sum(w_treat_post)
+        sum_w_treat_pre = np.sum(w_treat_pre)
+        sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs)  # pool ALL treated
+
+        # --- Treated means (period-specific Hajek means) ---
+        eta_treat_post = np.sum(w_treat_post * y_gt) / sum_w_treat_post
+        eta_treat_pre = np.sum(w_treat_pre * y_gs) / sum_w_treat_pre
+
+        # --- OR correction: pools ALL treated ---
+        # out.y.post - out.y.pre for each treated obs
+        or_diff_gt = mu_post_gt - mu_pre_gt  # treated at t
+        or_diff_gs = mu_post_gs - mu_pre_gs  # treated at s
+        eta_cont = (np.sum(w_D_gt * or_diff_gt) + np.sum(w_D_gs * or_diff_gs)) / sum_w_D
+
+        # --- Point estimate ---
+        att = float(eta_treat_post - eta_treat_pre - eta_cont)
+
+        # --- Influence function (matches R reg_did_rc.R) ---
+        # All IF components are n_all-length, nonzero only for their group.
+
+        # Treated IF components (period-specific)
+        inf_treat_post = w_treat_post * (y_gt - eta_treat_post) / sum_w_treat_post
+        inf_treat_pre = w_treat_pre * (y_gs - eta_treat_pre) / sum_w_treat_pre
+
+        # inf_treat = inf_treat_post - inf_treat_pre (across groups)
+        # inf_treat_post lives at gt positions, inf_treat_pre at gs positions
+
+        # Control IF: leading term (nonzero only for treated obs)
+        inf_cont_1_gt = w_D_gt * (or_diff_gt - eta_cont) / sum_w_D
+        inf_cont_1_gs = w_D_gs * (or_diff_gs - eta_cont) / sum_w_D
+
+        # Control IF: estimation effect from OLS
+        # bread_t = (X_ctrl_t' @ diag(W_ctrl_t) @ X_ctrl_t)^{-1}
+        W_ct = sw_ct if sw_ct is not None else np.ones(n_ct)
+        W_cs = sw_cs if sw_cs is not None else np.ones(n_cs)
+        bread_t = _safe_inv(X_ct_int.T @ (W_ct[:, None] * X_ct_int))
+        bread_s = _safe_inv(X_cs_int.T @ (W_cs[:, None] * X_cs_int))
+
+        # M1 = colMeans(w_D * X) / mean(w_D) — gradient, same X basis for both
+        # In R: M1 = colMeans(w.cont * out.x) / mean(w.cont)
+        # w.cont = i.weights * D across all obs; for treated obs, out.x is their X
+        M1 = (
+            np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
+        ) / sum_w_D
+
+        # asy_lin_rep_ols_t: nonzero only for control-t obs
+        # = W_i * (1-D_i) * 1{T=t} * (y_i - X_i'*beta_t) * X_i @ bread_t
+        asy_lin_rep_ols_t = (W_ct * resid_ct)[:, None] * X_ct_int @ bread_t
+        # asy_lin_rep_ols_s: nonzero only for control-s obs
+        asy_lin_rep_ols_s = (W_cs * resid_cs)[:, None] * X_cs_int @ bread_s
+
+        inf_cont_2_ct = asy_lin_rep_ols_t @ M1  # (n_ct,)
+        inf_cont_2_cs = asy_lin_rep_ols_s @ M1  # (n_cs,)
+
+        # --- Assemble per-group IF ---
+        # R: inf_cont = (inf_cont_1 + inf_cont_2_post - inf_cont_2_pre) / mean(w_D)
+        # Our convention divides by sum (not mean), so estimation effects need / sum_w_D
+        inf_gt = inf_treat_post - inf_cont_1_gt
+        inf_gs = -inf_treat_pre - inf_cont_1_gs
+        inf_ct = -(inf_cont_2_ct / sum_w_D)
+        inf_cs = inf_cont_2_cs / sum_w_D
 
        # Concatenate: treated (t then s), control (t then s)
        inf_treated = np.concatenate([inf_gt, inf_gs])
@@ -3046,8 +3101,9 @@
        """
        Cross-sectional doubly robust estimation for ATT(g,t).
 
-        Combines outcome regression and IPW. Consistent if either the
-        outcome model or the propensity model is correctly specified.
+        Matches R DRDID::drdid_rc (Sant'Anna & Zhao 2020, Eq 3.1).
+        Locally efficient DR estimator with 4 OLS fits (control pre/post,
+        treated pre/post) plus propensity score.
 
        Returns (att, se, inf_func_concat, idx_concat).
        """
@@ -3055,35 +3111,77 @@
        n_gs = len(y_gs)
        n_ct = len(y_ct)
        n_cs = len(y_cs)
+        n_all = n_gt + n_gs + n_ct + n_cs
 
-        # --- Outcome regression component ---
-        beta_t, resid_ct = _linear_regression(
+        # =====================================================================
+        # 1. Outcome regression: 4 OLS fits
+        # =====================================================================
+        # Control OLS: E[Y|X, D=0, T=t] and E[Y|X, D=0, T=s]
+        beta_ct, resid_ct = _linear_regression(
            X_ct,
            y_ct,
            rank_deficient_action=self.rank_deficient_action,
            weights=sw_ct,
        )
-        beta_t = np.where(np.isfinite(beta_t), beta_t, 0.0)
+        beta_ct = np.where(np.isfinite(beta_ct), beta_ct, 0.0)
 
-        beta_s, resid_cs = _linear_regression(
+        beta_cs, resid_cs = _linear_regression(
            X_cs,
            y_cs,
            rank_deficient_action=self.rank_deficient_action,
            weights=sw_cs,
        )
-        beta_s = np.where(np.isfinite(beta_s), beta_s, 0.0)
+        beta_cs = np.where(np.isfinite(beta_cs), beta_cs, 0.0)
+
+        # Treated OLS: E[Y|X, D=1, T=t] and E[Y|X, D=1, T=s]
+        beta_gt, resid_gt = _linear_regression(
+            X_gt,
+            y_gt,
+            rank_deficient_action=self.rank_deficient_action,
+            weights=sw_gt,
+        )
+        beta_gt = np.where(np.isfinite(beta_gt), beta_gt, 0.0)
+
+        beta_gs, resid_gs = _linear_regression(
+            X_gs,
+            y_gs,
+            rank_deficient_action=self.rank_deficient_action,
+            weights=sw_gs,
+        )
+        beta_gs = np.where(np.isfinite(beta_gs), beta_gs, 0.0)
 
+        # Intercept-augmented design matrices
        X_gt_int = np.column_stack([np.ones(n_gt), X_gt])
        X_gs_int = np.column_stack([np.ones(n_gs), X_gs])
        X_ct_int = np.column_stack([np.ones(n_ct), X_ct])
        X_cs_int = np.column_stack([np.ones(n_cs), X_cs])
 
-        m_gt = X_gt_int @ beta_t
-        m_gs = X_gs_int @ beta_s
-        m_ct = X_ct_int @ beta_t
-        m_cs = X_cs_int @ beta_s
-
-        # --- Propensity score component ---
+        # Control OR predictions for all groups
+        mu0_post_gt = X_gt_int @ beta_ct  # mu_{0,1}(X) for treated-post
+        mu0_pre_gt = X_gt_int @ beta_cs  # mu_{0,0}(X) for treated-post
+        mu0_post_gs = X_gs_int @ beta_ct  # mu_{0,1}(X) for treated-pre
+        mu0_pre_gs = X_gs_int @ beta_cs  # mu_{0,0}(X) for treated-pre
+        mu0_post_ct = X_ct_int @ beta_ct  # mu_{0,1}(X) for control-post
+        mu0_pre_ct = X_ct_int @ beta_cs  # mu_{0,0}(X) for control-post
+        mu0_post_cs = X_cs_int @ beta_ct  # mu_{0,1}(X) for control-pre
+        mu0_pre_cs = X_cs_int @ beta_cs  # mu_{0,0}(X) for control-pre
+
+        # Treated OR predictions for all groups (for local efficiency adjustment)
+        mu1_post_gt = X_gt_int @ beta_gt  # mu_{1,1}(X) for treated-post
+        mu1_pre_gt = X_gt_int @ beta_gs  # mu_{1,0}(X) for treated-post
+        mu1_post_gs = X_gs_int @ beta_gt  # mu_{1,1}(X) for treated-pre
+        mu1_pre_gs = X_gs_int @ beta_gs  # mu_{1,0}(X) for treated-pre
+
+        # mu_{0,Y}(T_i, X_i): control OR evaluated at own period
+        # For post-period obs: mu_{0,1}(X), for pre-period obs: mu_{0,0}(X)
+        mu0Y_gt = mu0_post_gt  # treated-post → use post control model
+        mu0Y_gs = mu0_pre_gs  # treated-pre → use pre control model
+        mu0Y_ct = mu0_post_ct  # control-post → use post control model
+        mu0Y_cs = mu0_pre_cs  # control-pre → use pre control model
+
+        # =====================================================================
+        # 2. Propensity score
+        # =====================================================================
        X_all = np.vstack([X_gt, X_gs, X_ct, X_cs])
        D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)])
        sw_all = None
@@ -3112,128 +3210,256 @@
        pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim)
 
+        # Split propensity scores per group
+        ps_gt = pscore[:n_gt]
+        ps_gs = pscore[n_gt : n_gt + n_gs]
        ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct]
        ps_cs = pscore[n_gt + n_gs + n_ct :]
 
-        # IPW weights for controls
-        w_ct = ps_ct / (1 - ps_ct)
-        w_cs = ps_cs / (1 - ps_cs)
-
+        # =====================================================================
+        # 3. Group weights
+        # =====================================================================
        if sw_gt is not None:
-            w_ct = sw_ct * w_ct
-            w_cs = sw_cs * w_cs
-
-        # --- DR ATT ---
-        if sw_gt is not None:
-            sw_gt_sum = np.sum(sw_gt)
-            sw_gs_sum = np.sum(sw_gs)
-
-            # Period t component
-            att_t_or = float(np.sum(sw_gt * (y_gt - m_gt)) / sw_gt_sum)
-            att_t_aug = float(np.sum(w_ct * (m_ct - y_ct)) / sw_gt_sum)
-            att_t = att_t_or + att_t_aug
-
-            # Period s component
-            att_s_or = float(np.sum(sw_gs * (y_gs - m_gs)) / sw_gs_sum)
-            att_s_aug = float(np.sum(w_cs * (m_cs - y_cs)) / sw_gs_sum)
-            att_s = att_s_or + att_s_aug
-
-            att = att_t - att_s
-
-            # Influence function (plug-in)
-            sw_gt_norm = sw_gt / sw_gt_sum
-            sw_gs_norm = sw_gs / sw_gs_sum
-
-            inf_gt = sw_gt_norm * (y_gt - m_gt - att_t)
-            inf_gs = -sw_gs_norm * (y_gs - m_gs - att_s)
-            inf_ct = (w_ct / sw_gt_sum) * (m_ct - y_ct)
-            inf_cs = -(w_cs / sw_gs_sum) * (m_cs - y_cs)
+            w_treat_post = sw_gt
+            w_treat_pre = sw_gs
+            w_D_gt = sw_gt
+            w_D_gs = sw_gs
        else:
-            # Period t component
-            att_t_or = float(np.mean(y_gt - m_gt))
-            att_t_aug = float(np.sum(w_ct * (m_ct - y_ct)) / n_gt)
-            att_t = att_t_or + att_t_aug
-
-            # Period s component
-            att_s_or = float(np.mean(y_gs - m_gs))
-            att_s_aug = float(np.sum(w_cs * (m_cs - y_cs)) / n_gs)
-            att_s = att_s_or + att_s_aug
+            w_treat_post = np.ones(n_gt)
+            w_treat_pre = np.ones(n_gs)
+            w_D_gt = np.ones(n_gt)
+            w_D_gs = np.ones(n_gs)
+
+        sum_w_treat_post = np.sum(w_treat_post)
+        sum_w_treat_pre = np.sum(w_treat_pre)
+        sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs)
+
+        # IPW control weights: sw * ps/(1-ps) for controls
+        w_ipw_ct = ps_ct / (1 - ps_ct)
+        w_ipw_cs = ps_cs / (1 - ps_cs)
+        if sw_ct is not None:
+            w_ipw_ct = sw_ct * w_ipw_ct
+            w_ipw_cs = sw_cs * w_ipw_cs
+
+        # =====================================================================
+        # 4. Point estimate: tau_1 (AIPW using control ORs)
+        # =====================================================================
+        # Hajek-normalized means of (y - mu0Y) per group
+        eta_treat_post = np.sum(w_treat_post * (y_gt - mu0Y_gt)) / sum_w_treat_post
+        eta_treat_pre = np.sum(w_treat_pre * (y_gs - mu0Y_gs)) / sum_w_treat_pre
+
+        sum_w_ipw_ct = np.sum(w_ipw_ct)
+        sum_w_ipw_cs = np.sum(w_ipw_cs)
+        eta_cont_post = (
+            np.sum(w_ipw_ct * (y_ct - mu0Y_ct)) / sum_w_ipw_ct if sum_w_ipw_ct > 0 else 0.0
+        )
+        eta_cont_pre = (
+            np.sum(w_ipw_cs * (y_cs - mu0Y_cs)) / sum_w_ipw_cs if sum_w_ipw_cs > 0 else 0.0
+        )
 
-        att = att_t - att_s
+        tau_1 = (eta_treat_post - eta_cont_post) - (eta_treat_pre - eta_cont_pre)
+
+        # =====================================================================
+        # 5. Point estimate: local efficiency adjustment (tau_2)
+        # =====================================================================
+        # Differences mu_{1,t}(X) - mu_{0,t}(X) for treated obs
+        or_diff_post_gt = mu1_post_gt - mu0_post_gt  # at treated-post
+        or_diff_post_gs = mu1_post_gs - mu0_post_gs  # at treated-pre
+        or_diff_pre_gt = mu1_pre_gt - mu0_pre_gt  # at treated-post
+        or_diff_pre_gs = mu1_pre_gs - mu0_pre_gs  # at treated-pre
+
+        # att_d_post = mean(w_D * (mu1_post - mu0_post)) / mean(w_D) — all treated
+        att_d_post = (np.sum(w_D_gt * or_diff_post_gt) + np.sum(w_D_gs * or_diff_post_gs)) / sum_w_D
+        # att_dt1_post — treated-post only
+        att_dt1_post = np.sum(w_treat_post * or_diff_post_gt) / sum_w_treat_post
+        # att_d_pre — all treated
+        att_d_pre = (np.sum(w_D_gt * or_diff_pre_gt) + np.sum(w_D_gs * or_diff_pre_gs)) / sum_w_D
+        # att_dt0_pre — treated-pre only
+        att_dt0_pre = np.sum(w_treat_pre * or_diff_pre_gs) / sum_w_treat_pre
+
+        tau_2 = (att_d_post - att_dt1_post) - (att_d_pre - att_dt0_pre)
+
+        att = float(tau_1 + tau_2)
+
+        # =====================================================================
+        # 6. Influence function: tau_1 components
+        # =====================================================================
+        # Treated IF (period-specific Hajek)
+        inf_treat_post = w_treat_post * (y_gt - mu0Y_gt - eta_treat_post) / sum_w_treat_post
+        inf_treat_pre = w_treat_pre * (y_gs - mu0Y_gs - eta_treat_pre) / sum_w_treat_pre
+
+        # Control IF (IPW Hajek)
+        inf_cont_post_ct = (
+            w_ipw_ct * (y_ct - mu0Y_ct - eta_cont_post) / sum_w_ipw_ct
+            if sum_w_ipw_ct > 0
+            else np.zeros(n_ct)
+        )
+        inf_cont_pre_cs = (
+            w_ipw_cs * (y_cs - mu0Y_cs - eta_cont_pre) / sum_w_ipw_cs
+            if sum_w_ipw_cs > 0
+            else np.zeros(n_cs)
+        )
 
-        # Influence function (plug-in)
-        inf_gt = (y_gt - m_gt - att_t) / n_gt
-        inf_gs = -(y_gs - m_gs - att_s) / n_gs
-        inf_ct = (w_ct * (m_ct - y_ct)) / n_gt
-        inf_cs = -(w_cs * (m_cs - y_cs)) / n_gs
+        # tau_1 IF per group (plug-in, before nuisance corrections)
+        inf_gt_tau1 = inf_treat_post
+        inf_gs_tau1 = -inf_treat_pre
+        inf_ct_tau1 = -inf_cont_post_ct
+        inf_cs_tau1 = inf_cont_pre_cs
+
+        # =====================================================================
+        # 7. Influence function: tau_2 leading terms
+        # =====================================================================
+        # att_d_post IF: w_D*(or_diff_post - att_d_post) / sum_w_D
+        inf_d_post_gt = w_D_gt * (or_diff_post_gt - att_d_post) / sum_w_D
+        inf_d_post_gs = w_D_gs * (or_diff_post_gs - att_d_post) / sum_w_D
+        # att_dt1_post IF: w_treat_post*(or_diff_post - att_dt1_post) / sum_w_treat_post
+        inf_dt1_post = w_treat_post * (or_diff_post_gt - att_dt1_post) / sum_w_treat_post
+        # att_d_pre IF
+        inf_d_pre_gt = w_D_gt * (or_diff_pre_gt - att_d_pre) / sum_w_D
+        inf_d_pre_gs = w_D_gs * (or_diff_pre_gs - att_d_pre) / sum_w_D
+        # att_dt0_pre IF
+        inf_dt0_pre = w_treat_pre * (or_diff_pre_gs - att_dt0_pre) / sum_w_treat_pre
+
+        # tau_2 IF per group
+        inf_gt_tau2 = (inf_d_post_gt - inf_dt1_post) - inf_d_pre_gt
+        inf_gs_tau2 = inf_d_post_gs - (-inf_dt0_pre + inf_d_pre_gs)
+        # Control obs don't contribute to tau_2 leading terms (w_D = 0 for controls)
+
+        # =====================================================================
+        # 8. Combined plug-in IF (before nuisance corrections)
+        # =====================================================================
+        inf_gt = inf_gt_tau1 + inf_gt_tau2
+        inf_gs = inf_gs_tau1 + inf_gs_tau2
+        inf_ct = inf_ct_tau1
+        inf_cs = inf_cs_tau1
 
-        # Concatenate: treated (t then s), control (t then s)
        inf_treated = np.concatenate([inf_gt, inf_gs])
        inf_control = np.concatenate([inf_ct, inf_cs])
        inf_all = np.concatenate([inf_treated, inf_control])
 
-        # --- PS IF correction ---
-        X_all_int = np.column_stack([np.ones(len(D_all)), X_all])
-        pscore_all = pscore
+        # =====================================================================
+        # 9. PS IF correction
+        # =====================================================================
+        X_all_int = np.column_stack([np.ones(n_all), X_all])
 
-        W_ps = pscore_all * (1 - pscore_all)
+        W_ps = pscore * (1 - pscore)
        if sw_all is not None:
            W_ps = W_ps * sw_all
        H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int)
        H_ps_inv = _safe_inv(H_ps)
 
-        score_ps = (D_all - pscore_all)[:, None] * X_all_int
+        score_ps = (D_all - pscore)[:, None] * X_all_int
        if sw_all is not None:
            score_ps = score_ps * sw_all[:, None]
-        asy_lin_rep_ps = score_ps @ H_ps_inv
-
-        # M2_dr: uses DR residuals (m-y) instead of raw y
-        # Use separate normalizers for post vs base period (RCS has different
-        # treated counts per period — using a single normalizer mis-scales)
-        dr_resid_ct = m_ct - y_ct
-        dr_resid_cs = m_cs - y_cs
-        normalizer_t = np.sum(sw_gt) if sw_gt is not None else n_gt
-        normalizer_s = np.sum(sw_gs) if sw_gs is not None else n_gs
-        M2_dr = np.zeros(X_all_int.shape[1])
+        asy_lin_rep_ps = score_ps @ H_ps_inv  # (n_all, p+1)
+
+        # M2: gradient of tau_1 control IPW w.r.t. PS parameters
+        # Only control obs contribute to M2 (through their IPW weights)
        ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct)
-        M2_dr += np.mean(
-            ((w_ct / normalizer_t) * dr_resid_ct)[:, None] * X_all_int[ct_slice],
-            axis=0,
-        )
        cs_slice = slice(n_gt + n_gs + n_ct, None)
-        M2_dr -= np.mean(
-            ((w_cs / normalizer_s) * dr_resid_cs)[:, None] * X_all_int[cs_slice],
-            axis=0,
-        )
-
-        inf_all = inf_all + asy_lin_rep_ps @ M2_dr
-
-        # --- OR IF correction -- period t model ---
-        W_t = sw_ct if sw_ct is not None else np.ones(n_ct)
-        bread_t = _safe_inv(X_ct_int.T @ (W_t[:, None] * X_ct_int))
-
-        sw_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt)
-        M1_t = (
-            -np.sum(sw_gt_vals[:, None] * X_gt_int, axis=0)
-            + np.sum(w_ct[:, None] * X_ct_int, axis=0)
-        ) / normalizer_t
-
-        asy_lin_rep_or_t = (W_t * (y_ct - m_ct))[:, None] * X_ct_int @ bread_t
-        inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_or_t @ M1_t
+        dr_resid_ct = y_ct - mu0Y_ct - eta_cont_post
+        dr_resid_cs = y_cs - mu0Y_cs - eta_cont_pre
 
-        # --- OR IF correction -- period s model ---
-        W_s = sw_cs if sw_cs is not None else np.ones(n_cs)
-        bread_s = _safe_inv(X_cs_int.T @ (W_s[:, None] * X_cs_int))
-
-        sw_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs)
-        M1_s = (
-            np.sum(sw_gs_vals[:, None] * X_gs_int, axis=0)
-            - np.sum(w_cs[:, None] * X_cs_int, axis=0)
-        ) / normalizer_s
+        M2 = np.zeros(X_all_int.shape[1])
+        if sum_w_ipw_ct > 0:
+            M2 -= (
+                np.sum(
+                    ((w_ipw_ct * dr_resid_ct / sum_w_ipw_ct)[:, None] * X_all_int[ct_slice]),
+                    axis=0,
+                )
+                / n_all
+            )
+        if sum_w_ipw_cs > 0:
+            M2 += (
+                np.sum(
+                    ((w_ipw_cs * dr_resid_cs / sum_w_ipw_cs)[:, None] * X_all_int[cs_slice]),
+                    axis=0,
+                )
+                / n_all
+            )
 
-        asy_lin_rep_or_s = (W_s * (y_cs - m_cs))[:, None] * X_cs_int @ bread_s
-        # Apply only to control-s portion of inf_all
-        inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_or_s @ M1_s
+        inf_all = inf_all + asy_lin_rep_ps @ M2
+
+        # =====================================================================
+        # 10. Control OR IF corrections (tau_1 estimation effect)
+        # =====================================================================
+        # bread = (X'WX)^{-1} for each control OLS
+        W_ct_vals = sw_ct if sw_ct is not None else np.ones(n_ct)
+        W_cs_vals = sw_cs if sw_cs is not None else np.ones(n_cs)
+        bread_ct = _safe_inv(X_ct_int.T @ (W_ct_vals[:, None] * X_ct_int))
+        bread_cs = _safe_inv(X_cs_int.T @ (W_cs_vals[:, None] * X_cs_int))
+
+        # ALR for control OLS
+        asy_lin_rep_ct = (W_ct_vals * resid_ct)[:, None] * X_ct_int @ bread_ct
+        asy_lin_rep_cs = (W_cs_vals * resid_cs)[:, None] * X_cs_int @ bread_cs
+
+        # M1 for control-post model (beta_ct): gradient from tau_1
+        # Treated-post contributes -w_treat_post*X/sum_w_treat_post (via mu0Y_gt = X@beta_ct)
+        # Control-post contributes -w_ipw_ct*X/sum_w_ipw_ct (via mu0Y_ct = X@beta_ct)
+        # Also contributes from tau_2: att_d_post uses mu0_post, att_dt1_post uses mu0_post
+        # For tau_2: w_D*(-X)/sum_w_D from att_d_post + w_treat_post*X/sum_w_treat_post from att_dt1_post
+        M1_ct = np.zeros(X_all_int.shape[1] - 1 + 1)  # p+1 (with intercept)
+        # From eta_treat_post (mu0Y_gt = X@beta_ct):
+        M1_ct -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post
+        # From eta_cont_post (mu0Y_ct = X@beta_ct):
+        if sum_w_ipw_ct > 0:
+            M1_ct += np.sum(w_ipw_ct[:, None] * X_ct_int, axis=0) / sum_w_ipw_ct
+        # From tau_2 att_d_post: -w_D * X / sum_w_D (mu0_post at all treated)
+        M1_ct -= (
+            np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
+        ) / sum_w_D
+        # From tau_2 att_dt1_post: +w_treat_post * X / sum_w_treat_post (mu0_post at treated-post)
+        M1_ct += np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post
+
+        # M1 for control-pre model (beta_cs):
+        M1_cs = np.zeros(X_all_int.shape[1])
+        # From eta_treat_pre (mu0Y_gs = X@beta_cs):
+        M1_cs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre
+        # From eta_cont_pre (mu0Y_cs = X@beta_cs):
+        if sum_w_ipw_cs > 0:
+            M1_cs -= np.sum(w_ipw_cs[:, None] * X_cs_int, axis=0) / sum_w_ipw_cs
+        # From tau_2 att_d_pre: +w_D * X / sum_w_D (mu0_pre at all treated)
+        M1_cs += (
+            np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
+        ) / sum_w_D
+        # From tau_2 att_dt0_pre: -w_treat_pre * X / sum_w_treat_pre (mu0_pre at treated-pre)
+        M1_cs -= np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre
+
+        inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_ct @ M1_ct
+        inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_cs @ M1_cs
+
+        # =====================================================================
+        # 11. Treated OR IF corrections (tau_2 estimation effect)
+        # =====================================================================
+        W_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt)
+        W_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs)
+        bread_gt = _safe_inv(X_gt_int.T @ (W_gt_vals[:, None] * X_gt_int))
+        bread_gs = _safe_inv(X_gs_int.T @ (W_gs_vals[:, None] * X_gs_int))
+
+        asy_lin_rep_gt = (W_gt_vals * resid_gt)[:, None] * X_gt_int @ bread_gt
+        asy_lin_rep_gs = (W_gs_vals * resid_gs)[:, None] * X_gs_int @ bread_gs
+
+        # M1 for treated-post model (beta_gt): mu_{1,1}(X)
+        # From att_d_post: +w_D*X/sum_w_D (mu1_post at all treated)
+        # From att_dt1_post: -w_treat_post*X/sum_w_treat_post (mu1_post at treated-post)
+        M1_gt = np.zeros(X_all_int.shape[1])
+        M1_gt += (
+            np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
+        ) / sum_w_D
+        M1_gt -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post
+
+        # M1 for treated-pre model (beta_gs): mu_{1,0}(X)
+        # From att_d_pre: -w_D*X/sum_w_D
+        # From att_dt0_pre: +w_treat_pre*X/sum_w_treat_pre
+        M1_gs = np.zeros(X_all_int.shape[1])
+        M1_gs -= (
+            np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0)
+        ) / sum_w_D
+        M1_gs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre
+
+        inf_all[:n_gt] += asy_lin_rep_gt @ M1_gt
+        inf_all[n_gt : n_gt + n_gs] += asy_lin_rep_gs @ M1_gs
 
        se = float(np.sqrt(np.sum(inf_all**2)))
diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py
index 19fef7ad..9bfb193b 100644
--- a/diff_diff/staggered_aggregation.py
+++ b/diff_diff/staggered_aggregation.py
@@ -80,12 +80,12 @@
                 continue
             effects.append(data["effect"])
             # Use fixed cohort-level survey weight sum for aggregation.
-            # For RCS, data["n_treated"] is already the fixed cohort mass
-            # (set at the source in _compute_att_gt_rc), so the fallback is correct.
+            # For RCS, data["agg_weight"] holds the fixed cohort mass;
+            # for panel, fallback to data["n_treated"].
             if survey_cohort_weights is not None and g in survey_cohort_weights:
                 weights_list.append(survey_cohort_weights[g])
             else:
-                weights_list.append(data["n_treated"])
+                weights_list.append(data.get("agg_weight", data["n_treated"]))
             gt_pairs.append((g, t))
             groups_for_gt.append(g)
@@ -577,11 +577,12 @@
             e = t - g  # Relative time
             if e not in effects_by_e:
                 effects_by_e[e] = []
-            # For RCS, data["n_treated"] is already the fixed cohort mass
+            # For RCS, data["agg_weight"] holds the fixed cohort mass;
+            # for panel, fallback to data["n_treated"].
             w = (
                 survey_cohort_weights[g]
                 if survey_cohort_weights is not None and g in survey_cohort_weights
-                else data["n_treated"]
+                else data.get("agg_weight", data["n_treated"])
             )
             effects_by_e[e].append(
                 (
@@ -609,7 +610,7 @@
             w = (
                 survey_cohort_weights[g]
                 if survey_cohort_weights is not None and g in survey_cohort_weights
-                else data["n_treated"]
+                else data.get("agg_weight", data["n_treated"])
             )
             balanced_effects[e].append(
                 (

From 3b405b7f9009afb4d8d90397d931b327a5c65825 Mon Sep 17 00:00:00 2001
From: igerber
Date: Sat, 28 Mar 2026 21:59:16 -0400
Subject: [PATCH 05/19] Fix bootstrap RCS cohort-mass weighting, reset stale
 event-study VCV

- Bootstrap overall/event-study reaggregation now uses agg_weight (fixed
  cohort mass) for panel=False, matching the analytical aggregation path
- Reset self._event_study_vcov = None at start of fit() to prevent stale
  VCV from prior fit leaking into reused estimator objects

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 diff_diff/staggered.py           |  3 +++
 diff_diff/staggered_bootstrap.py | 13 +++++++++++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py
index 447e26ad..954bb68b 100644
--- a/diff_diff/staggered.py
+++ b/diff_diff/staggered.py
@@ -1320,6 +1320,9 @@
        if not (0 < self.pscore_trim < 0.5):
            raise ValueError(f"pscore_trim must be in (0, 0.5), got {self.pscore_trim}")
 
+        # Reset stale state from prior fit (prevents leaking event-study VCV)
+        self._event_study_vcov = None
+
        # Normalize empty covariates list to None
        if covariates is not None and len(covariates) == 0:
            covariates = None
diff --git a/diff_diff/staggered_bootstrap.py b/diff_diff/staggered_bootstrap.py
index 48f40f21..2d94f773 100644
--- a/diff_diff/staggered_bootstrap.py
+++ b/diff_diff/staggered_bootstrap.py
@@ -237,8 +237,14 @@
                _cohort_mass_cache[g] = float(np.sum(survey_w[unit_cohorts == g]))
        all_n_treated = np.array([_cohort_mass_cache[gt[0]] for gt in gt_pairs], dtype=float)
    else:
+        # Use agg_weight if available (RCS: fixed cohort mass);
+        # fall back to n_treated for panel data
        all_n_treated = np.array(
-            [group_time_effects[gt]["n_treated"] for gt in gt_pairs], dtype=float
+            [
+                group_time_effects[gt].get("agg_weight", group_time_effects[gt]["n_treated"])
+                for gt in gt_pairs
+            ],
+            dtype=float,
        )
 
    post_n_treated = all_n_treated[post_treatment_mask]
@@ -572,7 +578,10 @@ def _agg_weight(g: Any, t: Any) -> float:
            if g not in _cohort_mass:
                _cohort_mass[g] = float(np.sum(survey_w[unit_cohorts == g]))
            return _cohort_mass[g]
-        return group_time_effects[(g, t)]["n_treated"]
+        # Use agg_weight if available (RCS: fixed cohort mass)
+        return group_time_effects[(g, t)].get(
+            "agg_weight", group_time_effects[(g, t)]["n_treated"]
+        )
 
    # Organize by relative time
    effects_by_e: Dict[int, List[Tuple[int, float, float]]] = {}

From 53cfd5d5b06dd44f71f5272fe16a5e537543a327 Mon Sep 17 00:00:00 2001
From: igerber
Date: Sat, 28 Mar 2026 22:17:01 -0400
Subject: [PATCH 06/19] Clear analytical event_study_vcov when bootstrap
 overwrites event-study SEs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Prevents HonestDiD from mixing analytical IF-based VCV with bootstrap SEs
on bootstrap-fit CallawaySantAnna results. When n_bootstrap>0, the
event_study_vcov is set to None so HonestDiD falls back to diagonal from
the bootstrap SEs (consistent variance path).

Add regression test: bootstrap CS → HonestDiD asserts vcov is None.
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/staggered.py | 6 +++++- tests/test_honest_did.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 954bb68b..160c7db1 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1771,8 +1771,12 @@ def fit( ) # Store results - # Retrieve event-study VCV from aggregation mixin (Phase 7d) + # Retrieve event-study VCV from aggregation mixin (Phase 7d). + # Clear it when bootstrap overwrites event-study SEs to prevent + # HonestDiD from mixing analytical VCV with bootstrap SEs. event_study_vcov = getattr(self, "_event_study_vcov", None) + if bootstrap_results is not None and event_study_vcov is not None: + event_study_vcov = None self.results_ = CallawaySantAnnaResults( group_time_effects=group_time_effects, diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index 1a36ce09..43c4b368 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -1280,6 +1280,30 @@ def test_replicate_weight_uses_diagonal_fallback(self): # event_study_vcov should be None (diagonal fallback for replicate designs) assert cs_result.event_study_vcov is None + def test_bootstrap_fit_clears_analytical_vcov(self): + """Bootstrap CS results should NOT carry analytical event_study_vcov.""" + from diff_diff import CallawaySantAnna, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + cs_result = CallawaySantAnna(n_bootstrap=49, seed=42).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + ) + + # event_study_vcov should be None when bootstrap is used + # (prevents HonestDiD from mixing analytical VCV with bootstrap SEs) + assert cs_result.event_study_vcov is None + + # HonestDiD should still work (falls back to diagonal from bootstrap SEs) + honest = HonestDiD(method="relative_magnitude", M=1.0) + h_result = honest.fit(cs_result) + 
assert np.isfinite(h_result.original_se)
+        assert h_result.original_se > 0
+
 
 # =============================================================================
 # Tests for Visualization (without matplotlib)

From 9ff21a2859e7713ca00a4d0359f855dc272d0b7f Mon Sep 17 00:00:00 2001
From: igerber
Date: Sun, 29 Mar 2026 09:06:04 -0400
Subject: [PATCH 07/19] Fix RC IF normalization scaling: M1 uses n_all
 denominator, PS M2 uses sum not mean

- _outcome_regression_rc: M1 denominator changed from sum_w_D to n_all
  (matching R colMeans convention); inf_cont_2 / sum_w_D then gives correct
  single normalization by mean_w_D * n_all = sum_w_D
- _ipw_estimation_rc: PS M2 uses np.sum/n_all instead of np.mean (which
  divided by n_ct/n_cs instead of n_all, over-scaling the correction)
- _doubly_robust_rc: PS M2 already correct (np.sum/n_all), no change

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 diff_diff/staggered.py | 35 +++++++++++++++++++++++------------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py
index 160c7db1..52c334e4 100644
--- a/diff_diff/staggered.py
+++ b/diff_diff/staggered.py
@@ -2914,12 +2914,13 @@ def _outcome_regression_rc(
         bread_t = _safe_inv(X_ct_int.T @ (W_ct[:, None] * X_ct_int))
         bread_s = _safe_inv(X_cs_int.T @ (W_cs[:, None] * X_cs_int))
 
-        # M1 = colMeans(w_D * X) / mean(w_D) — gradient, same X basis for both
-        # In R: M1 = colMeans(w.cont * out.x) / mean(w.cont)
-        # w.cont = i.weights * D across all obs; for treated obs, out.x is their X
+        # R: M1 = colMeans(w.cont * out.x) = sum(w_D * X) / n_all
+        # The final control IF divides by mean_w_D = sum_w_D / n_all (once).
+        # In our split convention phi = psi / n_all, the estimation effect is
+        # asy_lin_rep @ M1 / sum_w_D (where M1 uses n_all denominator).
M1 = ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) - ) / sum_w_D + ) / n_all # asy_lin_rep_ols_t: nonzero only for control-t obs # = W_i * (1-D_i) * 1{T=t} * (y_i - X_i'*beta_t) * X_i @ bread_t @@ -2975,6 +2976,7 @@ def _ipw_estimation_rc( n_gs = len(y_gs) n_ct = len(y_ct) n_cs = len(y_cs) + n_all = n_gt + n_gs + n_ct + n_cs # Pool treated and control for propensity score X_all = np.vstack([X_gt, X_gs, X_ct, X_cs]) @@ -3067,20 +3069,29 @@ def _ipw_estimation_rc( asy_lin_rep_ps = score_ps @ H_ps_inv # (n_all, p+1) # M2: gradient of IPW ATT w.r.t. PS parameters - # Control IPW residuals from both periods + # R: M2 = colMeans(w_ipw * (y-mu)/mean_w * X) over ALL n obs (zeros for treated). + # In our split convention phi = psi/n_all, so M2_rc = R's M2 / n_all. + # R's M2 = sum(w_ct_norm * (y-mu) * X_ct) [the mean_w normalization cancels]. + # So M2_rc = sum(...) / n_all. Old code used np.mean → sum/n_ct (wrong). ipw_resid_ct = w_ct_norm * (y_ct - mu_ct_ipw) ipw_resid_cs = w_cs_norm * (y_cs - mu_cs_ipw) # Zero for treated observations M2_rc = np.zeros(X_all_int.shape[1]) - # Control-t contribution - M2_rc += np.mean( - ipw_resid_ct[:, None] * X_all_int[n_gt + n_gs : n_gt + n_gs + n_ct], - axis=0, + # Control-t contribution: sum / n_all (NOT np.mean which divides by n_ct) + M2_rc += ( + np.sum( + ipw_resid_ct[:, None] * X_all_int[n_gt + n_gs : n_gt + n_gs + n_ct], + axis=0, + ) + / n_all ) # Control-s contribution (opposite sign -- base period) - M2_rc -= np.mean( - ipw_resid_cs[:, None] * X_all_int[n_gt + n_gs + n_ct :], - axis=0, + M2_rc -= ( + np.sum( + ipw_resid_cs[:, None] * X_all_int[n_gt + n_gs + n_ct :], + axis=0, + ) + / n_all ) inf_all = inf_all + asy_lin_rep_ps @ M2_rc From c2f8fdc7a8a63a41d007f9ab51e35a361d22780d Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 09:22:00 -0400 Subject: [PATCH 08/19] Document RCS IF phi=psi/n convention, add analytical-vs-bootstrap SE convergence test REGISTRY.md: Document 
that RCS IFs use phi=psi/n convention
(SE = sqrt(sum(phi^2))), algebraically equivalent to R's sd(psi)/sqrt(n).
The 1/n_all denominator in gradient terms is the colMeans -> phi
conversion, not extra shrinkage.

Add test proving correctness: analytical SE within 10% of bootstrap SE
(499 iters) for RCS reg with covariates.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 docs/methodology/REGISTRY.md |  3 ++-
 tests/test_staggered_rc.py   | 38 ++++++++++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md
index 12685a95..560dcbad 100644
--- a/docs/methodology/REGISTRY.md
+++ b/docs/methodology/REGISTRY.md
@@ -420,7 +420,8 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1:
 - **Note (deviation from R):** CallawaySantAnna survey reg+covariates per-cell SE uses a conservative plug-in IF based on WLS residuals. The treated IF is `inf_treated_i = (sw_i/sum(sw_treated)) * (resid_i - ATT)` (normalized by treated weight sum, matching unweighted `(resid-ATT)/n_t`). The control IF is `inf_control_i = -(sw_i/sum(sw_control)) * wls_resid_i` (normalized by control weight sum, matching unweighted `-resid/n_c`). SE is computed as `sqrt(sum(sw_t_norm * (resid_t - ATT)^2) + sum(sw_c_norm * resid_c^2))`, the weighted analogue of the unweighted `sqrt(var_t/n_t + var_c/n_c)`. This omits the semiparametrically efficient nuisance correction from DRDID's `reg_did_panel` — WLS residuals are orthogonal to the weighted design matrix by construction, so the first-order IF term is asymptotically valid but may be conservative. SEs pass weight-scale-invariance tests. The efficient DRDID correction is deferred to future work.
 - **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization.
When strata/PSU/FPC are present, analytical aggregated SEs (`n_bootstrap=0`) use `compute_survey_if_variance()` on the combined IF/WIF; bootstrap aggregated SEs (`n_bootstrap>0`) use PSU-level multiplier weights. -- **Note:** Repeated cross-sections (`panel=False`, Phase 7b): supports surveys like BRFSS, ACS annual, and CPS monthly where units are not followed over time. Uses cross-sectional DRDID (Sant'Anna & Zhao 2020, Section 4): two outcome models (one per period) instead of one on ΔY, and per-observation influence functions instead of per-unit. All three estimation methods (reg, ipw, dr) supported with and without covariates. Aggregation and bootstrap use the "canonical index" abstraction where the index space is observations (not units). Survey weights are per-observation (no unit-level collapse). Data generated via `generate_staggered_data(panel=False)`. +- **Note:** Repeated cross-sections (`panel=False`, Phase 7b): supports surveys like BRFSS, ACS annual, and CPS monthly where units are not followed over time. Uses cross-sectional DRDID (Sant'Anna & Zhao 2020, Section 4): `reg` matches `DRDID::reg_did_rc` (Eq 2.2), `dr` matches `DRDID::drdid_rc` (locally efficient, Eq 3.3+3.4 with 4 OLS fits), `ipw` matches `DRDID::std_ipw_did_rc`. Per-observation influence functions instead of per-unit. All three estimation methods support covariates and survey weights. +- **Note (deviation from R):** RCS influence functions use `phi_i = psi_i / n` convention (SE = `sqrt(sum(phi^2))`), matching the library-wide IF convention where IFs are pre-scaled by `1/n`. R's DRDID uses `psi_i` directly with `SE = sd(psi) / sqrt(n)`. These are algebraically equivalent — `sqrt(sum(psi^2/n^2)) = sqrt(sum(psi^2))/n ≈ sd(psi)/sqrt(n)` — confirmed by analytical-vs-bootstrap SE convergence tests. The `1/n_all` denominator in gradient terms (`M1`, `M2`) is not "extra shrinkage" but the `colMeans` → phi convention conversion. 
 - **Note:** Non-survey DR path also includes nuisance IF corrections (PS + OR), matching the survey path structure (Phase 7a). Previously used plug-in IF only.
 
 **Reference implementation(s):**
diff --git a/tests/test_staggered_rc.py b/tests/test_staggered_rc.py
index 519d1404..dd71e5fd 100644
--- a/tests/test_staggered_rc.py
+++ b/tests/test_staggered_rc.py
@@ -405,6 +405,44 @@ def test_summary_labels_rcs(self, rc_data):
         assert "units:" not in summary.split("\n")[3]  # Treated line
 
 
+# =============================================================================
+# Analytical vs Bootstrap SE convergence (proves IF scaling is correct)
+# =============================================================================
+
+
+class TestAnalyticalBootstrapConvergence:
+    """Analytical SE should closely match bootstrap SE — proves IF magnitude is correct."""
+
+    def test_reg_se_matches_bootstrap(self, rc_data_with_covariates):
+        """Analytical reg SE should be within 10% of bootstrap SE."""
+        r_analytical = CallawaySantAnna(estimation_method="reg", panel=False).fit(
+            rc_data_with_covariates,
+            "outcome",
+            "unit",
+            "period",
+            "first_treat",
+            covariates=["x1"],
+        )
+        r_bootstrap = CallawaySantAnna(
+            estimation_method="reg", panel=False, n_bootstrap=499, seed=42
+        ).fit(
+            rc_data_with_covariates,
+            "outcome",
+            "unit",
+            "period",
+            "first_treat",
+            covariates=["x1"],
+        )
+        # ATTs should match (bootstrap doesn't change point estimate)
+        np.testing.assert_allclose(r_analytical.overall_att, r_bootstrap.overall_att, atol=1e-10)
+        # SEs should be within 10% (proves IF scaling is correct)
+        ratio = r_analytical.overall_se / r_bootstrap.overall_se
+        assert 0.9 < ratio < 1.1, (
+            f"Analytical/bootstrap SE ratio {ratio:.3f} outside [0.9, 1.1] — "
+            f"analytical={r_analytical.overall_se:.4f}, bootstrap={r_bootstrap.overall_se:.4f}"
+        )
+
+
 # =============================================================================
 # Unequal Cohort Counts Across Periods
 # 
============================================================================= From cb3f815c22cd29f4ad338cbb8d25fe3b3deaa577 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 09:46:46 -0400 Subject: [PATCH 09/19] Refactor RC IFs to R's psi convention, fix HonestDiD VCV subsetting Restructure _outcome_regression_rc, _ipw_estimation_rc, _doubly_robust_rc to compute leading IF terms in R's unnormalized psi convention (using mean_w_* = sum_w_*/n_all normalizers matching R's mean()), then convert to library phi = psi/n_all at the boundary. Makes DRDID correspondence explicit with R variable name comments. Fix HonestDiD event_study_vcov subsetting: when filtering NaN-SE event times, subset the VCV matrix to match the surviving rel_times (was using the full unfiltered matrix, causing dimension mismatch on interior drops). Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/honest_did.py | 12 +- diff_diff/staggered.py | 388 +++++++++++++++++++++++----------------- 2 files changed, 237 insertions(+), 163 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 72528467..824e0ce1 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -664,8 +664,16 @@ def _extract_event_study_params( # Use full event-study VCV if available (Phase 7d), # otherwise fall back to diagonal from SEs if hasattr(results, "event_study_vcov") and results.event_study_vcov is not None: - # event_study_vcov is indexed by sorted rel_times - sigma = results.event_study_vcov + vcov = results.event_study_vcov + # VCV is indexed by ALL event times from aggregation; + # rel_times may be a filtered subset (NaN-SE times dropped). + # Subset VCV to match the surviving rel_times. 
+ all_event_times = sorted(results.event_study_effects.keys()) + if vcov.shape[0] == len(all_event_times) and len(rel_times) < len(all_event_times): + idx = [all_event_times.index(t) for t in rel_times] + sigma = vcov[np.ix_(idx, idx)] + else: + sigma = vcov else: sigma = np.diag(np.array(ses) ** 2) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 52c334e4..49922c12 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -2827,6 +2827,14 @@ def _outcome_regression_rc( Predictions made for ALL treated (both periods). OR correction pools ALL treated observations across both periods. + IF convention + ------------- + Intermediate terms use R's unnormalized psi_i convention throughout. + R computes SE as ``sd(psi) / sqrt(n)``; with mean(psi) approx 0 this + equals ``sqrt(sum(psi^2)) / n``. At the end we convert to the + library's pre-scaled phi_i = psi_i / n convention where + ``se = sqrt(sum(phi^2))``, used by the aggregation/bootstrap layer. + Returns (att, se, inf_func_concat, idx_concat). 
""" n_gt = len(y_gt) @@ -2880,12 +2888,17 @@ def _outcome_regression_rc( sum_w_treat_pre = np.sum(w_treat_pre) sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs) # pool ALL treated + # R: mean(w.treat.post), mean(w.treat.pre), mean(w.cont) + mean_w_treat_post = sum_w_treat_post / n_all + mean_w_treat_pre = sum_w_treat_pre / n_all + mean_w_D = sum_w_D / n_all + # --- Treated means (period-specific Hajek means) --- eta_treat_post = np.sum(w_treat_post * y_gt) / sum_w_treat_post eta_treat_pre = np.sum(w_treat_pre * y_gs) / sum_w_treat_pre # --- OR correction: pools ALL treated --- - # out.y.post - out.y.pre for each treated obs + # R: out.y.post - out.y.pre for each treated obs or_diff_gt = mu_post_gt - mu_pre_gt # treated at t or_diff_gs = mu_post_gs - mu_pre_gs # treated at s eta_cont = (np.sum(w_D_gt * or_diff_gt) + np.sum(w_D_gs * or_diff_gs)) / sum_w_D @@ -2893,57 +2906,62 @@ def _outcome_regression_rc( # --- Point estimate --- att = float(eta_treat_post - eta_treat_pre - eta_cont) - # --- Influence function (matches R reg_did_rc.R) --- - # All IF components are n_all-length, nonzero only for their group. 
- - # Treated IF components (period-specific) - inf_treat_post = w_treat_post * (y_gt - eta_treat_post) / sum_w_treat_post - inf_treat_pre = w_treat_pre * (y_gs - eta_treat_pre) / sum_w_treat_pre - - # inf_treat = inf_treat_post - inf_treat_pre (across groups) - # inf_treat_post lives at gt positions, inf_treat_pre at gs positions - - # Control IF: leading term (nonzero only for treated obs) - inf_cont_1_gt = w_D_gt * (or_diff_gt - eta_cont) / sum_w_D - inf_cont_1_gs = w_D_gs * (or_diff_gs - eta_cont) / sum_w_D - - # Control IF: estimation effect from OLS - # bread_t = (X_ctrl_t' @ diag(W_ctrl_t) @ X_ctrl_t)^{-1} + # ================================================================= + # Influence function in R's unnormalized psi convention + # (R: reg_did_rc.R, psi = n * phi) + # ================================================================= + + # --- Treated psi (R: eta.treat.post, eta.treat.pre) --- + # R: w.treat.post * (y - eta.treat.post) / mean(w.treat.post) + psi_treat_post = w_treat_post * (y_gt - eta_treat_post) / mean_w_treat_post + # R: w.treat.pre * (y - eta.treat.pre) / mean(w.treat.pre) + psi_treat_pre = w_treat_pre * (y_gs - eta_treat_pre) / mean_w_treat_pre + + # --- Control psi: leading term (R: inf.cont.1) --- + # R: w.cont * (or_diff - eta.cont) [before /mean(w.cont)] + psi_cont_1_gt = w_D_gt * (or_diff_gt - eta_cont) + psi_cont_1_gs = w_D_gs * (or_diff_gs - eta_cont) + + # --- Control psi: estimation effect (R: inf.cont.2) --- + # R: bread = solve(crossprod(X_ctrl, W * X_ctrl) / n) + # Here bread is (X'WX)^{-1} (without /n), so asy_lin_rep already + # absorbs the 1/n that R puts in its bread. We compensate by using + # R's colMeans (= sum/n_all) for M1, matching the product exactly. 
W_ct = sw_ct if sw_ct is not None else np.ones(n_ct) W_cs = sw_cs if sw_cs is not None else np.ones(n_cs) bread_t = _safe_inv(X_ct_int.T @ (W_ct[:, None] * X_ct_int)) bread_s = _safe_inv(X_cs_int.T @ (W_cs[:, None] * X_cs_int)) # R: M1 = colMeans(w.cont * out.x) = sum(w_D * X) / n_all - # The final control IF divides by mean_w_D = sum_w_D / n_all (once). - # In our split convention phi = psi / n_all, the estimation effect is - # asy_lin_rep @ M1 / sum_w_D (where M1 uses n_all denominator). M1 = ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) ) / n_all - # asy_lin_rep_ols_t: nonzero only for control-t obs - # = W_i * (1-D_i) * 1{T=t} * (y_i - X_i'*beta_t) * X_i @ bread_t + # R: asy.lin.rep.ols (per-obs OLS score * bread) asy_lin_rep_ols_t = (W_ct * resid_ct)[:, None] * X_ct_int @ bread_t - # asy_lin_rep_ols_s: nonzero only for control-s obs asy_lin_rep_ols_s = (W_cs * resid_cs)[:, None] * X_cs_int @ bread_s - inf_cont_2_ct = asy_lin_rep_ols_t @ M1 # (n_ct,) - inf_cont_2_cs = asy_lin_rep_ols_s @ M1 # (n_cs,) - - # --- Assemble per-group IF --- - # R: inf_cont = (inf_cont_1 + inf_cont_2_post - inf_cont_2_pre) / mean(w_D) - # Our convention divides by sum (not mean), so estimation effects need / sum_w_D - inf_gt = inf_treat_post - inf_cont_1_gt - inf_gs = -inf_treat_pre - inf_cont_1_gs - inf_ct = -(inf_cont_2_ct / sum_w_D) - inf_cs = inf_cont_2_cs / sum_w_D - - # Concatenate: treated (t then s), control (t then s) - inf_treated = np.concatenate([inf_gt, inf_gs]) - inf_control = np.concatenate([inf_ct, inf_cs]) - inf_all = np.concatenate([inf_treated, inf_control]) - + # R: inf.cont.2.post = asy.lin.rep.ols_t %*% M1 + psi_cont_2_ct = asy_lin_rep_ols_t @ M1 # (n_ct,) + # R: inf.cont.2.pre = asy.lin.rep.ols_s %*% M1 + psi_cont_2_cs = asy_lin_rep_ols_s @ M1 # (n_cs,) + + # --- Assemble per-group psi --- + # R: inf.treat = inf.treat.post - inf.treat.pre (across groups) + # R: inf.cont = (inf.cont.1 + inf.cont.2.post - 
inf.cont.2.pre) / mean(w.cont) + # R: att.inf.func = inf.treat - inf.cont + psi_gt = psi_treat_post - psi_cont_1_gt / mean_w_D + psi_gs = -psi_treat_pre - psi_cont_1_gs / mean_w_D + psi_ct = -psi_cont_2_ct / mean_w_D + psi_cs = psi_cont_2_cs / mean_w_D + + psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) + + # ================================================================= + # Convert to library convention: phi = psi / n_all + # se = sqrt(sum(phi^2)) == sqrt(sum(psi^2)) / n_all + # ================================================================= + inf_all = psi_all / n_all se = float(np.sqrt(np.sum(inf_all**2))) idx_all = None # caller builds idx from masks @@ -2970,6 +2988,14 @@ def _ipw_estimation_rc( Propensity score P(G=g | X) estimated on pooled treated+control observations from both periods. Reweight controls in each period. + IF convention + ------------- + Intermediate terms use R's unnormalized psi_i convention throughout + (R: ``ipw_did_rc``). R computes SE as ``sd(psi) / sqrt(n)``. + At the end we convert to the library's pre-scaled phi_i = psi_i / n + convention where ``se = sqrt(sum(phi^2))``, used by the + aggregation/bootstrap layer. + Returns (att, se, inf_func_concat, idx_concat). 
""" n_gt = len(y_gt) @@ -3009,11 +3035,11 @@ def _ipw_estimation_rc( # Clip propensity scores pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim) - # Split propensity scores (treated ps not used — only control IPW weights) + # Split propensity scores (treated ps not used -- only control IPW weights) ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct] ps_cs = pscore[n_gt + n_gs + n_ct :] - # IPW weights for controls + # IPW weights for controls (R: w1.x = ps / (1 - ps)) w_ct = ps_ct / (1 - ps_ct) w_cs = ps_cs / (1 - ps_cs) @@ -3021,12 +3047,29 @@ def _ipw_estimation_rc( w_ct = sw_ct * w_ct w_cs = sw_cs * w_cs - w_ct_norm = w_ct / np.sum(w_ct) if np.sum(w_ct) > 0 else w_ct - w_cs_norm = w_cs / np.sum(w_cs) if np.sum(w_cs) > 0 else w_cs + # R: mean(w.treat.post), mean(w.treat.pre), mean(w.ipw.ct), mean(w.ipw.cs) + if sw_gt is not None: + sum_w_treat_post = np.sum(sw_gt) + sum_w_treat_pre = np.sum(sw_gs) + else: + sum_w_treat_post = float(n_gt) + sum_w_treat_pre = float(n_gs) + + mean_w_treat_post = sum_w_treat_post / n_all + mean_w_treat_pre = sum_w_treat_pre / n_all + + sum_w_ct = np.sum(w_ct) + sum_w_cs = np.sum(w_cs) + mean_w_ct = sum_w_ct / n_all + mean_w_cs = sum_w_cs / n_all + + # Hajek-normalized weights (R normalizes by sum for point estimate) + w_ct_norm = w_ct / sum_w_ct if sum_w_ct > 0 else w_ct + w_cs_norm = w_cs / sum_w_cs if sum_w_cs > 0 else w_cs if sw_gt is not None: - sw_gt_norm = sw_gt / np.sum(sw_gt) - sw_gs_norm = sw_gs / np.sum(sw_gs) + sw_gt_norm = sw_gt / sum_w_treat_post + sw_gs_norm = sw_gs / sum_w_treat_pre mu_gt = float(np.sum(sw_gt_norm * y_gt)) mu_gs = float(np.sum(sw_gs_norm * y_gs)) else: @@ -3038,64 +3081,69 @@ def _ipw_estimation_rc( att = (mu_gt - mu_ct_ipw) - (mu_gs - mu_cs_ipw) - # Influence function + # ================================================================= + # Influence function in R's unnormalized psi convention + # (R: ipw_did_rc.R, psi = n * phi) + # 
================================================================= + + # --- Treated psi (R: eta.treat.post, eta.treat.pre) --- + # R: w.treat.post * (y - eta.treat.post) / mean(w.treat.post) if sw_gt is not None: - inf_gt = sw_gt_norm * (y_gt - mu_gt) - inf_gs = -sw_gs_norm * (y_gs - mu_gs) + psi_gt = sw_gt * (y_gt - mu_gt) / mean_w_treat_post + psi_gs = -sw_gs * (y_gs - mu_gs) / mean_w_treat_pre else: - inf_gt = (y_gt - mu_gt) / n_gt - inf_gs = -(y_gs - mu_gs) / n_gs + psi_gt = (y_gt - mu_gt) / mean_w_treat_post + psi_gs = -(y_gs - mu_gs) / mean_w_treat_pre - inf_ct = -w_ct_norm * (y_ct - mu_ct_ipw) - inf_cs = w_cs_norm * (y_cs - mu_cs_ipw) + # --- Control psi (R: eta.cont.post, eta.cont.pre) --- + # R: w.ipw * (y - eta.cont) / mean(w.ipw) + psi_ct = -w_ct * (y_ct - mu_ct_ipw) / mean_w_ct if mean_w_ct > 0 else np.zeros(n_ct) + psi_cs = w_cs * (y_cs - mu_cs_ipw) / mean_w_cs if mean_w_cs > 0 else np.zeros(n_cs) - inf_treated = np.concatenate([inf_gt, inf_gs]) - inf_control = np.concatenate([inf_ct, inf_cs]) - inf_all = np.concatenate([inf_treated, inf_control]) + psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) - # PS IF correction for cross-sectional IPW - X_all_int = np.column_stack([np.ones(len(D_all)), X_all]) - pscore_all = pscore # already computed and clipped + # --- PS IF correction (R: asy.lin.rep.ps %*% M2) --- + X_all_int = np.column_stack([np.ones(n_all), X_all]) - W_ps = pscore_all * (1 - pscore_all) + W_ps = pscore * (1 - pscore) if sw_all is not None: W_ps = W_ps * sw_all H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) H_ps_inv = _safe_inv(H_ps) - score_ps = (D_all - pscore_all)[:, None] * X_all_int + score_ps = (D_all - pscore)[:, None] * X_all_int if sw_all is not None: score_ps = score_ps * sw_all[:, None] asy_lin_rep_ps = score_ps @ H_ps_inv # (n_all, p+1) - # M2: gradient of IPW ATT w.r.t. PS parameters - # R: M2 = colMeans(w_ipw * (y-mu)/mean_w * X) over ALL n obs (zeros for treated). 
- # In our split convention phi = psi/n_all, so M2_rc = R's M2 / n_all. - # R's M2 = sum(w_ct_norm * (y-mu) * X_ct) [the mean_w normalization cancels]. - # So M2_rc = sum(...) / n_all. Old code used np.mean → sum/n_ct (wrong). + # ================================================================= + # Convert leading psi to library phi convention: phi = psi / n_all + # ================================================================= + inf_all = psi_all / n_all + + # --- PS nuisance correction (added in phi convention) --- + # R: M2 = colMeans(w_ipw * (y-mu) * X) across ALL n obs. + # colMeans = sum/n_all; treated rows contribute zero. + # asy_lin_rep_ps @ M2 matches R's psi_ps but with our bread + # convention (no 1/n), so the product is already in phi scale. ipw_resid_ct = w_ct_norm * (y_ct - mu_ct_ipw) ipw_resid_cs = w_cs_norm * (y_cs - mu_cs_ipw) - # Zero for treated observations - M2_rc = np.zeros(X_all_int.shape[1]) - # Control-t contribution: sum / n_all (NOT np.mean which divides by n_ct) - M2_rc += ( - np.sum( - ipw_resid_ct[:, None] * X_all_int[n_gt + n_gs : n_gt + n_gs + n_ct], - axis=0, - ) - / n_all - ) - # Control-s contribution (opposite sign -- base period) - M2_rc -= ( - np.sum( - ipw_resid_cs[:, None] * X_all_int[n_gt + n_gs + n_ct :], - axis=0, - ) - / n_all - ) - inf_all = inf_all + asy_lin_rep_ps @ M2_rc + ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct) + cs_slice = slice(n_gt + n_gs + n_ct, None) + + # R: M2 = colMeans(...) = sum(...) / n + M2 = np.zeros(X_all_int.shape[1]) + M2 += np.sum(ipw_resid_ct[:, None] * X_all_int[ct_slice], axis=0) / n_all + M2 -= np.sum(ipw_resid_cs[:, None] * X_all_int[cs_slice], axis=0) / n_all + + # R: att.inf.func += asy.lin.rep.ps %*% M2 (phi-scale correction) + inf_all = inf_all + asy_lin_rep_ps @ M2 + # ================================================================= + # SE from phi: se = sqrt(sum(phi^2)) + # Equivalent to R's sqrt(sum(psi^2)) / n when mean(psi) approx 0. 
+ # ================================================================= se = float(np.sqrt(np.sum(inf_all**2))) idx_all = None @@ -3123,6 +3171,14 @@ def _doubly_robust_rc( Locally efficient DR estimator with 4 OLS fits (control pre/post, treated pre/post) plus propensity score. + IF convention + ------------- + Intermediate terms use R's unnormalized psi_i convention throughout + (R: ``drdid_rc``). R computes SE as ``sd(psi) / sqrt(n)``. + At the end we convert to the library's pre-scaled phi_i = psi_i / n + convention where ``se = sqrt(sum(phi^2))``, used by the + aggregation/bootstrap layer. + Returns (att, se, inf_func_concat, idx_concat). """ n_gt = len(y_gt) @@ -3191,11 +3247,10 @@ def _doubly_robust_rc( mu1_pre_gs = X_gs_int @ beta_gs # mu_{1,0}(X) for treated-pre # mu_{0,Y}(T_i, X_i): control OR evaluated at own period - # For post-period obs: mu_{0,1}(X), for pre-period obs: mu_{0,0}(X) - mu0Y_gt = mu0_post_gt # treated-post → use post control model - mu0Y_gs = mu0_pre_gs # treated-pre → use pre control model - mu0Y_ct = mu0_post_ct # control-post → use post control model - mu0Y_cs = mu0_pre_cs # control-pre → use pre control model + mu0Y_gt = mu0_post_gt # treated-post: use post control model + mu0Y_gs = mu0_pre_gs # treated-pre: use pre control model + mu0Y_ct = mu0_post_ct # control-post: use post control model + mu0Y_cs = mu0_pre_cs # control-pre: use pre control model # ===================================================================== # 2. Propensity score @@ -3235,7 +3290,7 @@ def _doubly_robust_rc( ps_cs = pscore[n_gt + n_gs + n_ct :] # ===================================================================== - # 3. Group weights + # 3. 
Group weights and R-convention means # ===================================================================== if sw_gt is not None: w_treat_post = sw_gt @@ -3252,6 +3307,11 @@ def _doubly_robust_rc( sum_w_treat_pre = np.sum(w_treat_pre) sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs) + # R: mean(w) = sum(w) / n -- used in psi normalizers + mean_w_treat_post = sum_w_treat_post / n_all + mean_w_treat_pre = sum_w_treat_pre / n_all + mean_w_D = sum_w_D / n_all + # IPW control weights: sw * ps/(1-ps) for controls w_ipw_ct = ps_ct / (1 - ps_ct) w_ipw_cs = ps_cs / (1 - ps_cs) @@ -3259,6 +3319,11 @@ def _doubly_robust_rc( w_ipw_ct = sw_ct * w_ipw_ct w_ipw_cs = sw_cs * w_ipw_cs + sum_w_ipw_ct = np.sum(w_ipw_ct) + sum_w_ipw_cs = np.sum(w_ipw_cs) + mean_w_ipw_ct = sum_w_ipw_ct / n_all + mean_w_ipw_cs = sum_w_ipw_cs / n_all + # ===================================================================== # 4. Point estimate: tau_1 (AIPW using control ORs) # ===================================================================== @@ -3266,8 +3331,6 @@ def _doubly_robust_rc( eta_treat_post = np.sum(w_treat_post * (y_gt - mu0Y_gt)) / sum_w_treat_post eta_treat_pre = np.sum(w_treat_pre * (y_gs - mu0Y_gs)) / sum_w_treat_pre - sum_w_ipw_ct = np.sum(w_ipw_ct) - sum_w_ipw_cs = np.sum(w_ipw_cs) eta_cont_post = ( np.sum(w_ipw_ct * (y_ct - mu0Y_ct)) / sum_w_ipw_ct if sum_w_ipw_ct > 0 else 0.0 ) @@ -3286,13 +3349,13 @@ def _doubly_robust_rc( or_diff_pre_gt = mu1_pre_gt - mu0_pre_gt # at treated-post or_diff_pre_gs = mu1_pre_gs - mu0_pre_gs # at treated-pre - # att_d_post = mean(w_D * (mu1_post - mu0_post)) / mean(w_D) — all treated + # att_d_post = mean(w_D * (mu1_post - mu0_post)) / mean(w_D) -- all treated att_d_post = (np.sum(w_D_gt * or_diff_post_gt) + np.sum(w_D_gs * or_diff_post_gs)) / sum_w_D - # att_dt1_post — treated-post only + # att_dt1_post -- treated-post only att_dt1_post = np.sum(w_treat_post * or_diff_post_gt) / sum_w_treat_post - # att_d_pre — all treated + # att_d_pre -- all treated 
att_d_pre = (np.sum(w_D_gt * or_diff_pre_gt) + np.sum(w_D_gs * or_diff_pre_gs)) / sum_w_D - # att_dt0_pre — treated-pre only + # att_dt0_pre -- treated-pre only att_dt0_pre = np.sum(w_treat_pre * or_diff_pre_gs) / sum_w_treat_pre tau_2 = (att_d_post - att_dt1_post) - (att_d_pre - att_dt0_pre) @@ -3300,63 +3363,71 @@ def _doubly_robust_rc( att = float(tau_1 + tau_2) # ===================================================================== - # 6. Influence function: tau_1 components + # 6. Influence function in R's unnormalized psi convention + # (R: drdid_rc.R, psi = n * phi) # ===================================================================== - # Treated IF (period-specific Hajek) - inf_treat_post = w_treat_post * (y_gt - mu0Y_gt - eta_treat_post) / sum_w_treat_post - inf_treat_pre = w_treat_pre * (y_gs - mu0Y_gs - eta_treat_pre) / sum_w_treat_pre - - # Control IF (IPW Hajek) - inf_cont_post_ct = ( - w_ipw_ct * (y_ct - mu0Y_ct - eta_cont_post) / sum_w_ipw_ct - if sum_w_ipw_ct > 0 + + # --- tau_1: treated psi (R: eta.treat.post / mean(w.treat.post)) --- + # R: w.treat.post * (y - mu0Y - eta.treat.post) / mean(w.treat.post) + psi_treat_post = w_treat_post * (y_gt - mu0Y_gt - eta_treat_post) / mean_w_treat_post + psi_treat_pre = w_treat_pre * (y_gs - mu0Y_gs - eta_treat_pre) / mean_w_treat_pre + + # --- tau_1: control psi (R: eta.cont.post / mean(w.ipw)) --- + # R: w.ipw * (y - mu0Y - eta.cont) / mean(w.ipw) + psi_cont_post_ct = ( + w_ipw_ct * (y_ct - mu0Y_ct - eta_cont_post) / mean_w_ipw_ct + if mean_w_ipw_ct > 0 else np.zeros(n_ct) ) - inf_cont_pre_cs = ( - w_ipw_cs * (y_cs - mu0Y_cs - eta_cont_pre) / sum_w_ipw_cs - if sum_w_ipw_cs > 0 + psi_cont_pre_cs = ( + w_ipw_cs * (y_cs - mu0Y_cs - eta_cont_pre) / mean_w_ipw_cs + if mean_w_ipw_cs > 0 else np.zeros(n_cs) ) - # tau_1 IF per group (plug-in, before nuisance corrections) - inf_gt_tau1 = inf_treat_post - inf_gs_tau1 = -inf_treat_pre - inf_ct_tau1 = -inf_cont_post_ct - inf_cs_tau1 = inf_cont_pre_cs + # tau_1 psi per 
group + psi_gt_tau1 = psi_treat_post + psi_gs_tau1 = -psi_treat_pre + psi_ct_tau1 = -psi_cont_post_ct + psi_cs_tau1 = psi_cont_pre_cs # ===================================================================== - # 7. Influence function: tau_2 leading terms + # 7. tau_2 leading terms (R: att.d.post, att.dt1.post, etc.) # ===================================================================== - # att_d_post IF: w_D*(or_diff_post - att_d_post) / sum_w_D - inf_d_post_gt = w_D_gt * (or_diff_post_gt - att_d_post) / sum_w_D - inf_d_post_gs = w_D_gs * (or_diff_post_gs - att_d_post) / sum_w_D - # att_dt1_post IF: w_treat_post*(or_diff_post - att_dt1_post) / sum_w_treat_post - inf_dt1_post = w_treat_post * (or_diff_post_gt - att_dt1_post) / sum_w_treat_post - # att_d_pre IF - inf_d_pre_gt = w_D_gt * (or_diff_pre_gt - att_d_pre) / sum_w_D - inf_d_pre_gs = w_D_gs * (or_diff_pre_gs - att_d_pre) / sum_w_D - # att_dt0_pre IF - inf_dt0_pre = w_treat_pre * (or_diff_pre_gs - att_dt0_pre) / sum_w_treat_pre - - # tau_2 IF per group - inf_gt_tau2 = (inf_d_post_gt - inf_dt1_post) - inf_d_pre_gt - inf_gs_tau2 = inf_d_post_gs - (-inf_dt0_pre + inf_d_pre_gs) - # Control obs don't contribute to tau_2 leading terms (w_D = 0 for controls) + # R: w.D * (or_diff - att.d.post) / mean(w.D) + psi_d_post_gt = w_D_gt * (or_diff_post_gt - att_d_post) / mean_w_D + psi_d_post_gs = w_D_gs * (or_diff_post_gs - att_d_post) / mean_w_D + # R: w.treat.post * (or_diff - att.dt1.post) / mean(w.treat.post) + psi_dt1_post = w_treat_post * (or_diff_post_gt - att_dt1_post) / mean_w_treat_post + # R: w.D * (or_diff_pre - att.d.pre) / mean(w.D) + psi_d_pre_gt = w_D_gt * (or_diff_pre_gt - att_d_pre) / mean_w_D + psi_d_pre_gs = w_D_gs * (or_diff_pre_gs - att_d_pre) / mean_w_D + # R: w.treat.pre * (or_diff_pre - att.dt0.pre) / mean(w.treat.pre) + psi_dt0_pre = w_treat_pre * (or_diff_pre_gs - att_dt0_pre) / mean_w_treat_pre + + # tau_2 psi per group (controls contribute zero) + psi_gt_tau2 = (psi_d_post_gt - psi_dt1_post) - 
psi_d_pre_gt + psi_gs_tau2 = psi_d_post_gs - (-psi_dt0_pre + psi_d_pre_gs) # ===================================================================== - # 8. Combined plug-in IF (before nuisance corrections) + # 8. Combined plug-in psi (before nuisance corrections) # ===================================================================== - inf_gt = inf_gt_tau1 + inf_gt_tau2 - inf_gs = inf_gs_tau1 + inf_gs_tau2 - inf_ct = inf_ct_tau1 - inf_cs = inf_cs_tau1 + psi_gt = psi_gt_tau1 + psi_gt_tau2 + psi_gs = psi_gs_tau1 + psi_gs_tau2 + psi_ct = psi_ct_tau1 + psi_cs = psi_cs_tau1 - inf_treated = np.concatenate([inf_gt, inf_gs]) - inf_control = np.concatenate([inf_ct, inf_cs]) - inf_all = np.concatenate([inf_treated, inf_control]) + psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) + + # ================================================================= + # Convert leading psi to library phi convention: phi = psi / n_all + # ================================================================= + inf_all = psi_all / n_all # ===================================================================== - # 9. PS IF correction + # 9. PS nuisance correction (phi-scale) + # asy_lin_rep_ps @ M2 with M2 = colMeans(...) is already in phi + # scale because our bread omits R's 1/n factor while M2 absorbs it. # ===================================================================== X_all_int = np.column_stack([np.ones(n_all), X_all]) @@ -3371,8 +3442,8 @@ def _doubly_robust_rc( score_ps = score_ps * sw_all[:, None] asy_lin_rep_ps = score_ps @ H_ps_inv # (n_all, p+1) - # M2: gradient of tau_1 control IPW w.r.t. PS parameters - # Only control obs contribute to M2 (through their IPW weights) + # R: M2 = colMeans(w_ipw * dr_resid / mean(w_ipw) * X) + # colMeans = sum / n_all; treated rows contribute zero. 
ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct) cs_slice = slice(n_gt + n_gs + n_ct, None) @@ -3400,55 +3471,46 @@ def _doubly_robust_rc( inf_all = inf_all + asy_lin_rep_ps @ M2 # ===================================================================== - # 10. Control OR IF corrections (tau_1 estimation effect) + # 10. Control OR nuisance corrections (phi-scale) # ===================================================================== - # bread = (X'WX)^{-1} for each control OLS W_ct_vals = sw_ct if sw_ct is not None else np.ones(n_ct) W_cs_vals = sw_cs if sw_cs is not None else np.ones(n_cs) bread_ct = _safe_inv(X_ct_int.T @ (W_ct_vals[:, None] * X_ct_int)) bread_cs = _safe_inv(X_cs_int.T @ (W_cs_vals[:, None] * X_cs_int)) - # ALR for control OLS + # R: asy.lin.rep.ols (per-obs OLS score * bread) asy_lin_rep_ct = (W_ct_vals * resid_ct)[:, None] * X_ct_int @ bread_ct asy_lin_rep_cs = (W_cs_vals * resid_cs)[:, None] * X_cs_int @ bread_cs - # M1 for control-post model (beta_ct): gradient from tau_1 - # Treated-post contributes -w_treat_post*X/sum_w_treat_post (via mu0Y_gt = X@beta_ct) - # Control-post contributes -w_ipw_ct*X/sum_w_ipw_ct (via mu0Y_ct = X@beta_ct) - # Also contributes from tau_2: att_d_post uses mu0_post, att_dt1_post uses mu0_post - # For tau_2: w_D*(-X)/sum_w_D from att_d_post + w_treat_post*X/sum_w_treat_post from att_dt1_post - M1_ct = np.zeros(X_all_int.shape[1] - 1 + 1) # p+1 (with intercept) - # From eta_treat_post (mu0Y_gt = X@beta_ct): + # M1 for control-post model (beta_ct): gradient from tau_1 + tau_2 + # tau_1: -w_treat_post*X/sum_w_treat_post (eta_treat_post via mu0Y_gt) + # +w_ipw_ct*X/sum_w_ipw_ct (eta_cont_post via mu0Y_ct) + # tau_2: -w_D*X/sum_w_D (att_d_post via mu0_post at all treated) + # +w_treat_post*X/sum_w_treat_post (att_dt1_post via mu0_post) + M1_ct = np.zeros(X_all_int.shape[1]) M1_ct -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post - # From eta_cont_post (mu0Y_ct = X@beta_ct): if sum_w_ipw_ct > 0: M1_ct 
+= np.sum(w_ipw_ct[:, None] * X_ct_int, axis=0) / sum_w_ipw_ct - # From tau_2 att_d_post: -w_D * X / sum_w_D (mu0_post at all treated) M1_ct -= ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) ) / sum_w_D - # From tau_2 att_dt1_post: +w_treat_post * X / sum_w_treat_post (mu0_post at treated-post) M1_ct += np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post - # M1 for control-pre model (beta_cs): + # M1 for control-pre model (beta_cs) M1_cs = np.zeros(X_all_int.shape[1]) - # From eta_treat_pre (mu0Y_gs = X@beta_cs): M1_cs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre - # From eta_cont_pre (mu0Y_cs = X@beta_cs): if sum_w_ipw_cs > 0: M1_cs -= np.sum(w_ipw_cs[:, None] * X_cs_int, axis=0) / sum_w_ipw_cs - # From tau_2 att_d_pre: +w_D * X / sum_w_D (mu0_pre at all treated) M1_cs += ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) ) / sum_w_D - # From tau_2 att_dt0_pre: -w_treat_pre * X / sum_w_treat_pre (mu0_pre at treated-pre) M1_cs -= np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_ct @ M1_ct inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_cs @ M1_cs # ===================================================================== - # 11. Treated OR IF corrections (tau_2 estimation effect) + # 11. 
Treated OR nuisance corrections (phi-scale) # ===================================================================== W_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt) W_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs) @@ -3459,8 +3521,8 @@ def _doubly_robust_rc( asy_lin_rep_gs = (W_gs_vals * resid_gs)[:, None] * X_gs_int @ bread_gs # M1 for treated-post model (beta_gt): mu_{1,1}(X) - # From att_d_post: +w_D*X/sum_w_D (mu1_post at all treated) - # From att_dt1_post: -w_treat_post*X/sum_w_treat_post (mu1_post at treated-post) + # From att_d_post: +w_D*X/sum_w_D (all treated) + # From att_dt1_post: -w_treat_post*X/sum_w_treat_post (treated-post) M1_gt = np.zeros(X_all_int.shape[1]) M1_gt += ( np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) @@ -3479,6 +3541,10 @@ def _doubly_robust_rc( inf_all[:n_gt] += asy_lin_rep_gt @ M1_gt inf_all[n_gt : n_gt + n_gs] += asy_lin_rep_gs @ M1_gs + # ================================================================= + # SE from phi: se = sqrt(sum(phi^2)) + # Equivalent to R's sqrt(sum(psi^2)) / n when mean(psi) approx 0. + # ================================================================= se = float(np.sqrt(np.sum(inf_all**2))) idx_all = None From eac680ee53021c9a5b1a8bf234456f925f16a15f Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 10:27:58 -0400 Subject: [PATCH 10/19] Match R's H/n, asy_rep/n, colMeans convention for panel PS corrections; fix VCV index subsetting Panel IPW/DR PS corrections: restructure to match R's std_ipw_did_panel / drdid_panel convention: H = X'WX/n, asy_lin_rep = score @ solve(H) / n, M2 = colMeans(). Algebraically equivalent but mirrors R source literally. HonestDiD VCV subsetting: store event_study_vcov_index (the exact event-time ordering matching VCV columns) so subsetting works correctly even when universal base period injects a reference row into event_study_effects. 
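The H/n rescaling is purely cosmetic: dividing the Hessian by n and then dividing the linear representation by n again cancels exactly. A minimal numpy sketch of the equivalence (toy data, hypothetical names, not library code):

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 50, 3
X = np.column_stack([np.ones(n), rng.normal(size=(n, p - 1))])
W = rng.uniform(0.1, 0.9, size=n)   # stands in for pscore * (1 - pscore)
score = rng.normal(size=(n, p))     # stands in for the PS score rows

# Previous convention: bread = inv(X'WX)
H_raw = X.T @ (W[:, None] * X)
rep_raw = score @ np.linalg.inv(H_raw)

# R's std_ipw_did_panel convention: H = X'WX / n, rep = score @ inv(H) / n
H_r = H_raw / n
rep_r = score @ np.linalg.inv(H_r) / n

assert np.allclose(rep_raw, rep_r)  # identical asymptotic linear representation
```

Mirroring R's scaling changes no estimate; it only makes line-by-line comparison against the R source easier.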
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/honest_did.py | 22 +++++++----- diff_diff/staggered.py | 56 +++++++++++++----------------- diff_diff/staggered_aggregation.py | 6 ++++ diff_diff/staggered_results.py | 3 +- 4 files changed, 47 insertions(+), 40 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 824e0ce1..7575da61 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -665,15 +665,21 @@ def _extract_event_study_params( # otherwise fall back to diagonal from SEs if hasattr(results, "event_study_vcov") and results.event_study_vcov is not None: vcov = results.event_study_vcov - # VCV is indexed by ALL event times from aggregation; - # rel_times may be a filtered subset (NaN-SE times dropped). - # Subset VCV to match the surviving rel_times. - all_event_times = sorted(results.event_study_effects.keys()) - if vcov.shape[0] == len(all_event_times) and len(rel_times) < len(all_event_times): - idx = [all_event_times.index(t) for t in rel_times] - sigma = vcov[np.ix_(idx, idx)] - else: + # VCV is indexed by the aggregated event times (stored in + # event_study_vcov_index), NOT by event_study_effects keys + # (which may include an injected reference period). + # Subset to match the surviving rel_times. 
+ vcov_index = getattr(results, "event_study_vcov_index", None) + if vcov_index is not None and len(rel_times) < len(vcov_index): + idx = [vcov_index.index(t) for t in rel_times if t in vcov_index] + if len(idx) == len(rel_times): + sigma = vcov[np.ix_(idx, idx)] + else: + sigma = np.diag(np.array(ses) ** 2) + elif vcov.shape[0] == len(rel_times): sigma = vcov + else: + sigma = np.diag(np.array(ses) ** 2) else: sigma = np.diag(np.array(ses) ** 2) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 79ca3157..047af3fb 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1775,8 +1775,10 @@ def fit( # Clear it when bootstrap overwrites event-study SEs to prevent # HonestDiD from mixing analytical VCV with bootstrap SEs. event_study_vcov = getattr(self, "_event_study_vcov", None) + event_study_vcov_index = getattr(self, "_event_study_vcov_index", None) if bootstrap_results is not None and event_study_vcov is not None: event_study_vcov = None + event_study_vcov_index = None self.results_ = CallawaySantAnnaResults( group_time_effects=group_time_effects, @@ -1800,6 +1802,7 @@ def fit( pscore_trim=self.pscore_trim, survey_metadata=survey_metadata, event_study_vcov=event_study_vcov, + event_study_vcov_index=event_study_vcov_index, panel=self.panel, ) @@ -2032,35 +2035,29 @@ def _ipw_estimation( X_all_int = np.column_stack([np.ones(n_t + n_c), X_all]) pscore_all = np.concatenate([pscore_treated, pscore_control]) - # Survey-weighted PS Hessian: sum(w_i * mu_i * (1-mu_i) * x_i * x_i') + # PS IF correction — matches R's std_ipw_did_panel convention: + # H = X'WX / n, asy_lin_rep = score @ solve(H) / n, M2 = colMeans + n_all_panel = n_t + n_c W_ps = pscore_all * (1 - pscore_all) if sw_all is not None: W_ps = W_ps * sw_all - H = X_all_int.T @ (W_ps[:, None] * X_all_int) - try: - H_inv = np.linalg.solve(H, np.eye(H.shape[0])) - except np.linalg.LinAlgError: - H_inv = np.linalg.lstsq(H, np.eye(H.shape[0]), rcond=None)[0] + H = X_all_int.T @ (W_ps[:, 
None] * X_all_int) / n_all_panel + H_inv = _safe_inv(H) - # PS score: w_i * (D_i - pi_i) * X_i D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) score_ps = (D_all - pscore_all)[:, None] * X_all_int if sw_all is not None: score_ps = score_ps * sw_all[:, None] - asy_lin_rep_ps = score_ps @ H_inv # shape (n_t + n_c, p) + asy_lin_rep_ps = score_ps @ H_inv / n_all_panel - # M2: gradient of ATT w.r.t. PS parameters - # R convention: colMeans over ALL n obs (zero for treated rows) att_control_weighted = np.sum(weights_control_norm * control_change) - M2 = np.sum( + M2 = np.mean( (weights_control_norm * (control_change - att_control_weighted))[:, None] * X_all_int[n_t:], axis=0, - ) / (n_t + n_c) + ) - # PS correction to influence function - inf_ps_correction = asy_lin_rep_ps @ M2 - inf_func = inf_func + inf_ps_correction + inf_func = inf_func + asy_lin_rep_ps @ M2 # SE from influence function variance var_psi = np.sum(inf_func**2) @@ -2295,29 +2292,26 @@ def _doubly_robust( ) pscore_all = np.concatenate([pscore_treated_clipped, pscore_control]) - # Survey-weighted PS Hessian + # PS IF correction — R convention: H/n, asy_rep/n, colMeans + n_all_panel = n_t + n_c W_ps = pscore_all * (1 - pscore_all) if sw_all is not None: W_ps = W_ps * sw_all - H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) + H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel H_ps_inv = _safe_inv(H_ps) - # PS score D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) score_ps = (D_all - pscore_all)[:, None] * X_all_int if sw_all is not None: score_ps = score_ps * sw_all[:, None] - asy_lin_rep_ps = score_ps @ H_ps_inv # (n_t+n_c, p+1) + asy_lin_rep_ps = score_ps @ H_ps_inv / n_all_panel - # M2_dr: dATT/dgamma — gradient of DR ATT w.r.t. 
PS parameters - # Only the control augmentation term depends on PS via w_ipw - # R convention: colMeans over ALL n obs (zero for treated rows) dr_resid_control = m_control - control_change - M2_dr = np.sum( + M2_dr = np.mean( ((weights_control / sw_t_sum) * dr_resid_control)[:, None] * X_all_int[n_t:], axis=0, - ) / (n_t + n_c) + ) inf_func = inf_func + asy_lin_rep_ps @ M2_dr # --- OR IF correction --- @@ -2358,27 +2352,27 @@ def _doubly_robust( inf_func = np.concatenate([psi_treated, psi_control]) if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: - # --- PS IF correction --- - X_all_int = np.column_stack([np.ones(n_t + n_c), X_all]) + # --- PS IF correction — R convention: H/n, asy_rep/n, colMeans --- + n_all_panel = n_t + n_c + X_all_int = np.column_stack([np.ones(n_all_panel), X_all]) pscore_treated_clipped = np.clip( pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim ) pscore_all = np.concatenate([pscore_treated_clipped, pscore_control]) W_ps = pscore_all * (1 - pscore_all) - H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) + H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel H_ps_inv = _safe_inv(H_ps) D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) score_ps = (D_all - pscore_all)[:, None] * X_all_int - asy_lin_rep_ps = score_ps @ H_ps_inv + asy_lin_rep_ps = score_ps @ H_ps_inv / n_all_panel - # R convention: colMeans over ALL n obs (zero for treated rows) dr_resid_control = m_control - control_change - M2_dr = np.sum( + M2_dr = np.mean( ((weights_control / n_t) * dr_resid_control)[:, None] * X_all_int[n_t:], axis=0, - ) / (n_t + n_c) + ) inf_func = inf_func + asy_lin_rep_ps @ M2_dr # --- OR IF correction --- diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 9bfb193b..d760db4d 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -751,6 +751,12 @@ def _aggregate_event_study( except (ValueError, np.linalg.LinAlgError): pass # Fall back 
to diagonal (None) + # Store the event-time index that matches VCV columns (for subsetting + # in HonestDiD when some event times are filtered out) + self._event_study_vcov_index = ( + [e for e, _ in sorted_periods] if event_study_vcov is not None else None + ) + # Attach VCV to self for CallawaySantAnna to pick up self._event_study_vcov = event_study_vcov diff --git a/diff_diff/staggered_results.py b/diff_diff/staggered_results.py index b21af0df..65132af3 100644 --- a/diff_diff/staggered_results.py +++ b/diff_diff/staggered_results.py @@ -115,8 +115,9 @@ class CallawaySantAnnaResults: event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None) group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None) influence_functions: Optional["np.ndarray"] = field(default=None, repr=False) - # Full event-study VCV matrix (Phase 7d): indexed by sorted relative times + # Full event-study VCV matrix (Phase 7d): indexed by event_study_vcov_index event_study_vcov: Optional["np.ndarray"] = field(default=None, repr=False) + event_study_vcov_index: Optional[list] = field(default=None, repr=False) bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False) cband_crit_value: Optional[float] = None pscore_trim: float = 0.01 From 9893454d3bedb6a973051436e589d0381a759d47 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 10:53:15 -0400 Subject: [PATCH 11/19] Fix VCV index alignment, add stationarity warning for panel=False VCV index: Build event_study_vcov_index from event times that actually contributed psi vectors (skipping NaN-only periods), not from all sorted_periods. Fixes misalignment when interior event times drop out. Stationarity: Document "stationary repeated cross-sections" in panel parameter docstring. Emit UserWarning on panel=False noting the stationarity assumption is not data-checkable. 
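The subsetting this index enables can be sketched on toy values (hypothetical event times and VCV entries, not library output): filtered horizons are looked up in the stored index rather than assumed to be positional, so a dropped interior event time cannot shift the alignment.

```python
import numpy as np

# Event times that actually contributed psi columns, in VCV column order
# (suppose one interior time produced only NaN SEs and was skipped).
vcov_index = [-3, -2, 0, 1, 2]
vcov = np.diag([0.4, 0.3, 0.5, 0.6, 0.7])   # toy 5x5 event-study VCV

# HonestDiD retains a filtered subset of horizons
rel_times = [-3, -2, 0, 1]
idx = [vcov_index.index(t) for t in rel_times if t in vcov_index]
if len(idx) == len(rel_times):
    sigma = vcov[np.ix_(idx, idx)]
else:
    sigma = np.diag(np.full(len(rel_times), np.nan))  # diagonal fallback

assert sigma.shape == (4, 4)
assert sigma[2, 2] == 0.5   # row/col for e = 0 follows vcov_index, not position
```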
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/staggered.py | 19 ++++++++++++++++--- diff_diff/staggered_aggregation.py | 6 +++--- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 047af3fb..c84f4d1a 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -183,9 +183,11 @@ class CallawaySantAnna( in IPW and DR estimation. Must be in ``(0, 0.5)``. panel : bool, default=True Whether the data is a balanced/unbalanced panel (units observed - across multiple time periods). Set to ``False`` for repeated - cross-sections where each observation has a unique unit ID and - units do not repeat across periods. Uses cross-sectional DRDID + across multiple time periods). Set to ``False`` for stationary + repeated cross-sections where each observation has a unique unit + ID and units do not repeat across periods. Requires that the + cross-sectional samples are drawn from the same population in + each period (stationarity). Uses cross-sectional DRDID (Sant'Anna & Zhao 2020, Section 4) with per-observation influence functions. @@ -1323,6 +1325,17 @@ def fit( # Reset stale state from prior fit (prevents leaking event-study VCV) self._event_study_vcov = None + if not self.panel: + warnings.warn( + "panel=False uses repeated cross-section DRDID estimators " + "(Sant'Anna & Zhao 2020, Section 4) which assume stationary " + "cross-sectional sampling: the population distribution of " + "(Y, X, G) must be stable across periods. 
This assumption " + "is not data-checkable.", + UserWarning, + stacklevel=2, + ) + # Normalize empty covariates list to None if covariates is not None and len(covariates) == 0: covariates = None diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index d760db4d..312e50ad 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -627,6 +627,7 @@ def _aggregate_event_study( agg_ses_list = [] agg_n_groups = [] _psi_vectors = [] # Per-event-time combined IF vectors for VCV + _psi_event_times = [] # Event times that contributed a psi column for e, effect_list in sorted_periods: gt_pairs = [x[0] for x in effect_list] effs = np.array([x[1] for x in effect_list]) @@ -666,6 +667,7 @@ def _aggregate_event_study( agg_ses_list.append(agg_se) agg_n_groups.append(len(effect_list)) _psi_vectors.append(psi_e) + _psi_event_times.append(e) # Batch inference for all relative periods if not agg_effects_list: @@ -753,9 +755,7 @@ def _aggregate_event_study( # Store the event-time index that matches VCV columns (for subsetting # in HonestDiD when some event times are filtered out) - self._event_study_vcov_index = ( - [e for e, _ in sorted_periods] if event_study_vcov is not None else None - ) + self._event_study_vcov_index = _psi_event_times if event_study_vcov is not None else None # Attach VCV to self for CallawaySantAnna to pick up self._event_study_vcov = event_study_vcov From 44150347e1674c785f857fb59d4604ae4fbffcda Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 11:54:59 -0400 Subject: [PATCH 12/19] Document panel DR control-augmentation normalization deviation from DRDID Panel DR normalizes control augmentation by treated mass (sw_t_sum/n_t) rather than control IPW mass (sum(w_cont)) as in DRDID::drdid_panel. Both are asymptotically equivalent; finite-sample difference documented as intentional deviation in REGISTRY.md. 
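Why the two normalizers agree asymptotically can be checked by simulation: under a correctly specified propensity score, E[(1-D) * ps/(1-ps)] = E[ps] = E[D], so control IPW mass and treated mass estimate the same limit. A minimal sketch with a toy logit model (all names hypothetical, not library code):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 200_000
x = rng.normal(size=n)
ps = 1 / (1 + np.exp(-(0.3 + 0.5 * x)))       # true propensity score
D = (rng.uniform(size=n) < ps).astype(float)  # treatment indicator

# Control IPW weight: ps/(1-ps) for controls, zero for treated
w_cont = (1 - D) * ps / (1 - ps)

# Control IPW mass per obs vs treated mass per obs: same population limit
assert abs(w_cont.mean() - D.mean()) < 0.02
```

In finite samples the two masses differ whenever reweighting does not perfectly balance, which is exactly the documented deviation.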
Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/methodology/REGISTRY.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 3e02ad33..9c002cb0 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -417,6 +417,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) - **Note:** CallawaySantAnna survey support: weights, strata, PSU, and FPC are all supported for all estimation methods (reg, ipw, dr) with or without covariates. Analytical (`n_bootstrap=0`): aggregated SEs use design-based variance via `compute_survey_if_variance()`. Bootstrap (`n_bootstrap>0`): PSU-level multiplier weights replace analytical SEs for aggregated quantities. IPW and DR with covariates use DRDID panel nuisance IF corrections (Phase 7a: PS IF correction via survey-weighted Hessian/score, OR IF correction via WLS bread and gradient; Sant'Anna & Zhao 2020, Theorem 3.1). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Per-unit survey weights are extracted via `groupby(unit).first()` from the panel-normalized pweight array; on unbalanced panels the pweight normalization (`w * n_obs / sum(w)`) preserves relative unit weights since all IF/WIF formulas use weight ratios (`sw_i / sum(sw)`) where the normalization constant cancels. Scale-invariance tests pass on both balanced and unbalanced panels. +- **Note (deviation from R):** Panel DR control augmentation is normalized by treated mass (`sw_t_sum` or `n_t`) rather than control IPW mass (`sum(w_cont)`). R's `DRDID::drdid_panel` uses `mean(w.cont)` as the control normalizer. 
Both are consistent asymptotically (under correct model specification, `E[w_cont] = E[D]` so the normalizers converge), but they differ in finite samples when IPW reweighting doesn't perfectly balance. The treated-mass normalization is simpler and matches the `did::att_gt` convention where ATT is defined per treated unit. Aligning to `DRDID::drdid_panel`'s exact `w.cont` normalization is deferred. - **Note (deviation from R):** CallawaySantAnna survey reg+covariates per-cell SE uses a conservative plug-in IF based on WLS residuals. The treated IF is `inf_treated_i = (sw_i/sum(sw_treated)) * (resid_i - ATT)` (normalized by treated weight sum, matching unweighted `(resid-ATT)/n_t`). The control IF is `inf_control_i = -(sw_i/sum(sw_control)) * wls_resid_i` (normalized by control weight sum, matching unweighted `-resid/n_c`). SE is computed as `sqrt(sum(sw_t_norm * (resid_t - ATT)^2) + sum(sw_c_norm * resid_c^2))`, the weighted analogue of the unweighted `sqrt(var_t/n_t + var_c/n_c)`. This omits the semiparametrically efficient nuisance correction from DRDID's `reg_did_panel` — WLS residuals are orthogonal to the weighted design matrix by construction, so the first-order IF term is asymptotically valid but may be conservative. SEs pass weight-scale-invariance tests. The efficient DRDID correction is deferred to future work. - **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization. When strata/PSU/FPC are present, analytical aggregated SEs (`n_bootstrap=0`) use `compute_survey_if_variance()` on the combined IF/WIF; bootstrap aggregated SEs (`n_bootstrap>0`) use PSU-level multiplier weights. 
From 1c35440b95c26f547ab7ad97a8d02107879e0604 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 12:47:30 -0400 Subject: [PATCH 13/19] Warn on non-universal base period in HonestDiD CS path, update tests Add UserWarning when CallawaySantAnna results with base_period != "universal" are passed to HonestDiD (R requires universal; we warn). Document as deviation in REGISTRY.md. Update Phase 7d HonestDiD tests to use base_period="universal" for methodologically valid bounds. Fix VCV diagonal test index for universal base reference period row. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/honest_did.py | 16 ++++++++++++++++ docs/methodology/REGISTRY.md | 1 + tests/test_honest_did.py | 19 ++++++++----------- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 7575da61..4749e756 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -640,6 +640,22 @@ def _extract_event_study_params( "event study effects." ) + # Warn if not using universal base period (R's HonestDiD requires it) + if getattr(results, "base_period", "universal") != "universal": + import warnings + + warnings.warn( + "HonestDiD sensitivity analysis on CallawaySantAnna results " + "requires base_period='universal' for valid interpretation. " + "With base_period='varying', pre-treatment coefficients use " + "consecutive comparisons (not a common reference period), " + "which changes the meaning of the parallel trends restriction. 
" + "Re-run with CallawaySantAnna(base_period='universal') for " + "methodologically valid HonestDiD bounds.", + UserWarning, + stacklevel=3, + ) + # Extract event study effects by relative time # Filter out normalization constraints (n_groups=0) and non-finite SEs event_effects = { diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 9c002cb0..828aa21d 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1637,6 +1637,7 @@ Confidence intervals: - M=0: reduces to standard parallel trends - Negative M: not valid (constraints become infeasible) - **Note:** Phase 7d: survey variance support. When input results carry `survey_metadata` with `df_survey`, HonestDiD uses t-distribution critical values (via `_get_critical_value(alpha, df)`) instead of normal. CallawaySantAnnaResults now stores `event_study_vcov` (full cross-event-time VCV from IF vectors), which HonestDiD uses instead of the diagonal fallback. For replicate-weight designs, the event-study VCV falls back to diagonal (multivariate replicate VCV deferred). +- **Note (deviation from R):** When CallawaySantAnna results are passed to HonestDiD, `base_period != "universal"` emits a warning but does not error. R's `honest_did::honest_did.AGGTEobj` requires universal base period. Our implementation warns because the varying-base pre-treatment coefficients use consecutive comparisons (not a common reference), which changes the parallel-trends restriction interpretation. 
**Reference implementation(s):** - R: `HonestDiD` package (Rambachan & Roth's official package) diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index 43c4b368..9fbb829b 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -1137,7 +1137,7 @@ def test_df_survey_extracted_from_cs_results(self): data["psu"] = (np.arange(n_units) // 5)[idx] sd = SurveyDesign(weights="weight", strata="stratum", psu="psu") - cs_result = CallawaySantAnna().fit( + cs_result = CallawaySantAnna(base_period="universal").fit( data, "outcome", "unit", @@ -1160,7 +1160,7 @@ def test_event_study_vcov_computed(self): from diff_diff import CallawaySantAnna, generate_staggered_data data = generate_staggered_data(n_units=100, n_periods=6, seed=42) - cs_result = CallawaySantAnna().fit( + cs_result = CallawaySantAnna(base_period="universal").fit( data, "outcome", "unit", @@ -1180,12 +1180,9 @@ def test_event_study_vcov_computed(self): ) assert cs_result.event_study_vcov.shape == (n_effects, n_effects) - # Diagonal should match squared SEs - for i, e in enumerate(sorted(cs_result.event_study_effects.keys())): - info = cs_result.event_study_effects[e] - if info.get("n_groups", 1) > 0 and np.isfinite(info.get("se", np.nan)): - # VCV diagonal should be close to SE^2 (not exact due to IF aggregation) - assert cs_result.event_study_vcov[i, i] > 0 + # Diagonal should be positive + for i in range(n_effects): + assert cs_result.event_study_vcov[i, i] > 0 def test_survey_df_widens_bounds(self): """Survey df (t-distribution) should give wider CIs than normal.""" @@ -1203,7 +1200,7 @@ def test_survey_df_widens_bounds(self): data["psu"] = (np.arange(n_units) // 25)[idx] sd = SurveyDesign(weights="weight", strata="stratum", psu="psu") - cs_result = CallawaySantAnna().fit( + cs_result = CallawaySantAnna(base_period="universal").fit( data, "outcome", "unit", @@ -1228,7 +1225,7 @@ def test_no_survey_gives_none_df(self): from diff_diff import CallawaySantAnna, generate_staggered_data 
data = generate_staggered_data(n_units=100, n_periods=5, seed=42) - cs_result = CallawaySantAnna().fit( + cs_result = CallawaySantAnna(base_period="universal").fit( data, "outcome", "unit", @@ -1285,7 +1282,7 @@ def test_bootstrap_fit_clears_analytical_vcov(self): from diff_diff import CallawaySantAnna, generate_staggered_data data = generate_staggered_data(n_units=100, n_periods=5, seed=42) - cs_result = CallawaySantAnna(n_bootstrap=49, seed=42).fit( + cs_result = CallawaySantAnna(n_bootstrap=49, seed=42, base_period="universal").fit( data, "outcome", "unit", From 867cd51ef469342b6b01381e7056bbf384a5e68d Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 13:04:11 -0400 Subject: [PATCH 14/19] Fix panel M2 full-sample colMeans, add HonestDiD consecutive event-time guard Panel IPW/DR M2 gradients: change np.mean(control_slice) to np.sum(control_slice) / n_all_panel, matching R's colMeans over ALL n observations (zero for treated rows). Previous np.mean divided by n_c instead of n_all. HonestDiD: warn when retained CS event-study horizons are not consecutive (interior gaps change smoothness/RM restriction geometry). R requires a consecutive grid. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/honest_did.py | 15 +++++++++++++++ diff_diff/staggered.py | 32 +++++++++++++++++++++----------- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 4749e756..555f0069 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -699,6 +699,21 @@ def _extract_event_study_params( else: sigma = np.diag(np.array(ses) ** 2) + # Warn if event times have interior gaps (R requires consecutive) + if len(rel_times) >= 2: + diffs = [rel_times[i + 1] - rel_times[i] for i in range(len(rel_times) - 1)] + if any(d != 1 for d in diffs): + import warnings + + warnings.warn( + "HonestDiD: retained event-study horizons are not consecutive " + f"({rel_times}). 
Interior gaps change the geometry of smoothness " + "and relative-magnitude restrictions. R's HonestDiD requires " + "a consecutive event-time grid.", + UserWarning, + stacklevel=3, + ) + # Extract survey df df_survey = None if hasattr(results, "survey_metadata") and results.survey_metadata is not None: diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index c84f4d1a..64dfbb67 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -2064,10 +2064,14 @@ def _ipw_estimation( asy_lin_rep_ps = score_ps @ H_inv / n_all_panel att_control_weighted = np.sum(weights_control_norm * control_change) - M2 = np.mean( - (weights_control_norm * (control_change - att_control_weighted))[:, None] - * X_all_int[n_t:], - axis=0, + # R colMeans: sum over control rows / n_all (not / n_c) + M2 = ( + np.sum( + (weights_control_norm * (control_change - att_control_weighted))[:, None] + * X_all_int[n_t:], + axis=0, + ) + / n_all_panel ) inf_func = inf_func + asy_lin_rep_ps @ M2 @@ -2320,10 +2324,13 @@ def _doubly_robust( asy_lin_rep_ps = score_ps @ H_ps_inv / n_all_panel dr_resid_control = m_control - control_change - M2_dr = np.mean( - ((weights_control / sw_t_sum) * dr_resid_control)[:, None] - * X_all_int[n_t:], - axis=0, + M2_dr = ( + np.sum( + ((weights_control / sw_t_sum) * dr_resid_control)[:, None] + * X_all_int[n_t:], + axis=0, + ) + / n_all_panel ) inf_func = inf_func + asy_lin_rep_ps @ M2_dr @@ -2382,9 +2389,12 @@ def _doubly_robust( asy_lin_rep_ps = score_ps @ H_ps_inv / n_all_panel dr_resid_control = m_control - control_change - M2_dr = np.mean( - ((weights_control / n_t) * dr_resid_control)[:, None] * X_all_int[n_t:], - axis=0, + M2_dr = ( + np.sum( + ((weights_control / n_t) * dr_resid_control)[:, None] * X_all_int[n_t:], + axis=0, + ) + / n_all_panel ) inf_func = inf_func + asy_lin_rep_ps @ M2_dr From 9f3cab4d4df2ae2f31860a80d62ade0104305183 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 15:46:59 -0400 Subject: [PATCH 15/19] Fix HonestDiD 
grid validator for reference-period gap, defensive bootstrap for zero-mass cells HonestDiD: validate pre and post blocks separately for consecutive horizons, allowing the expected reference-period gap between them. Only warns on true interior gaps within a block (e.g., missing -2 in [-3,-1,0,1]). Bootstrap: filter gt_pairs to only those with influence_func_info, skipping zero-mass (g,t) cells that have NaN ATT but no IF record. Prevents KeyError when RCS survey weights zero out a cell. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/honest_did.py | 36 +++++++++++++++++++------------- diff_diff/staggered_bootstrap.py | 5 +++-- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 555f0069..7f7c2ae0 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -699,20 +699,28 @@ def _extract_event_study_params( else: sigma = np.diag(np.array(ses) ** 2) - # Warn if event times have interior gaps (R requires consecutive) - if len(rel_times) >= 2: - diffs = [rel_times[i + 1] - rel_times[i] for i in range(len(rel_times) - 1)] - if any(d != 1 for d in diffs): - import warnings - - warnings.warn( - "HonestDiD: retained event-study horizons are not consecutive " - f"({rel_times}). Interior gaps change the geometry of smoothness " - "and relative-magnitude restrictions. R's HonestDiD requires " - "a consecutive event-time grid.", - UserWarning, - stacklevel=3, - ) + # Validate pre and post blocks are each consecutive + # (the gap between last pre and first post is the omitted + # reference period and is expected) + has_gap = False + for block in [pre_times, post_times]: + if len(block) >= 2: + for i in range(len(block) - 1): + if block[i + 1] - block[i] != 1: + has_gap = True + break + if has_gap: + import warnings + + warnings.warn( + "HonestDiD: retained event-study horizons have interior " + f"gaps within pre or post blocks (pre={pre_times}, " + f"post={post_times}). 
Interior gaps change the geometry " + "of smoothness and relative-magnitude restrictions. " + "R's HonestDiD requires a consecutive event-time grid.", + UserWarning, + stacklevel=3, + ) # Extract survey df df_survey = None diff --git a/diff_diff/staggered_bootstrap.py b/diff_diff/staggered_bootstrap.py index 2d94f773..2b9095ff 100644 --- a/diff_diff/staggered_bootstrap.py +++ b/diff_diff/staggered_bootstrap.py @@ -211,8 +211,9 @@ def _run_multiplier_bootstrap( ) unit_to_idx = {u: i for i, u in enumerate(all_units)} - # Get list of (g,t) pairs - gt_pairs = list(group_time_effects.keys()) + # Get list of (g,t) pairs that have influence function info + # (skip zero-mass cells that recorded NaN ATT without IF) + gt_pairs = [gt for gt in group_time_effects.keys() if gt in influence_func_info] n_gt = len(gt_pairs) # Identify post-treatment (g,t) pairs for overall ATT From c52905335a91f201786f9d039d4ae93a9c1d2ed6 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 16:04:35 -0400 Subject: [PATCH 16/19] HonestDiD: raise ValueError on non-consecutive event-time grid (was warning) Match R's honest_did.AGGTEobj behavior: refuse to construct bounds when retained event-study horizons have interior gaps within pre or post blocks. The single gap for the omitted reference period is still allowed. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/honest_did.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 7f7c2ae0..9e6c6e93 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -710,16 +710,13 @@ def _extract_event_study_params( has_gap = True break if has_gap: - import warnings - - warnings.warn( - "HonestDiD: retained event-study horizons have interior " - f"gaps within pre or post blocks (pre={pre_times}, " - f"post={post_times}). Interior gaps change the geometry " - "of smoothness and relative-magnitude restrictions. 
" - "R's HonestDiD requires a consecutive event-time grid.", - UserWarning, - stacklevel=3, + raise ValueError( + "HonestDiD requires a consecutive event-time grid. " + f"Retained pre-periods {pre_times} and/or post-periods " + f"{post_times} have interior gaps. This can happen when " + "some event-study horizons have non-finite SEs. Ensure " + "all event-study periods have valid estimates, or use " + "balance_e to restrict to a balanced subset." ) # Extract survey df From e9995ef6ef6118d904a0e4a25dac32c87e8bc313 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 16:15:44 -0400 Subject: [PATCH 17/19] HonestDiD: validate full grid around reference period, not just within-block Check that post_times[0] - pre_times[-1] == 2 (exactly one gap for the omitted reference period). Catches missing e=0 or missing boundary horizons that the within-block check missed. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/honest_did.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 9e6c6e93..15125f91 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -699,10 +699,21 @@ def _extract_event_study_params( else: sigma = np.diag(np.array(ses) ** 2) - # Validate pre and post blocks are each consecutive - # (the gap between last pre and first post is the omitted - # reference period and is expected) - has_gap = False + # Validate the full event-time grid is consecutive around + # the omitted reference period (exactly one gap allowed). + # R's HonestDiD refuses non-consecutive grids. 
+ if pre_times and post_times: + # Expected: pre_times[-1] + 1 = reference, reference + 1 = post_times[0] + # So post_times[0] - pre_times[-1] should be exactly 2 + ref_gap = post_times[0] - pre_times[-1] + has_gap = ref_gap != 2 + elif pre_times: + has_gap = False # only pre, no ref gap to check + elif post_times: + has_gap = False # only post, no ref gap to check + else: + has_gap = False + # Also check within-block consecutiveness for block in [pre_times, post_times]: if len(block) >= 2: for i in range(len(block) - 1): @@ -711,12 +722,13 @@ def _extract_event_study_params( break if has_gap: raise ValueError( - "HonestDiD requires a consecutive event-time grid. " - f"Retained pre-periods {pre_times} and/or post-periods " - f"{post_times} have interior gaps. This can happen when " - "some event-study horizons have non-finite SEs. Ensure " - "all event-study periods have valid estimates, or use " - "balance_e to restrict to a balanced subset." + "HonestDiD requires a consecutive event-time grid " + "around the omitted reference period. Retained " + f"pre-periods {pre_times} and post-periods " + f"{post_times} have gaps. This can happen when " + "some event-study horizons have non-finite SEs. " + "Ensure all event-study periods have valid estimates, " + "or use balance_e to restrict to a balanced subset." ) # Extract survey df From 1f8a5374f307e0d21ef0677ad9aabd7ae515d2c2 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 16:30:31 -0400 Subject: [PATCH 18/19] Fix HonestDiD: reference-aware pre/post split, replicate df=0 sentinel P0 fix: When survey_metadata has replicate_method but df_survey=None (rank-deficient replicate design), set df_survey=0 sentinel so _get_critical_value returns NaN. Prevents finite HonestDiD bounds when CS inference is NaN. P1 fix: Infer omitted reference period from n_groups=0 entry in event_study_effects (handles anticipation>0 where ref=-1-anticipation). Split pre/post relative to reference, not hardcoded at t<0/t>=0. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/honest_did.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 15125f91..291bfefa 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -612,10 +612,13 @@ def _extract_event_study_params( # Fallback: diagonal from SEs sigma = np.diag(np.array(ses) ** 2) - # Extract survey df if available + # Extract survey df. Replicate designs with undefined df → sentinel 0. df_survey = None if hasattr(results, "survey_metadata") and results.survey_metadata is not None: - df_survey = getattr(results.survey_metadata, "df_survey", None) + sm = results.survey_metadata + df_survey = getattr(sm, "df_survey", None) + if df_survey is None and getattr(sm, "replicate_method", None) is not None: + df_survey = 0 return ( beta_hat, @@ -665,9 +668,18 @@ def _extract_event_study_params( } rel_times = sorted(event_effects.keys()) - # Split into pre and post - pre_times = [t for t in rel_times if t < 0] - post_times = [t for t in rel_times if t >= 0] + # Infer the omitted reference period from the n_groups=0 entry + # (injected by _aggregate_event_study for universal base). + # Default to e=-1 if no reference found (varying base). + ref_period = -1 + for t, data in results.event_study_effects.items(): + if data.get("n_groups", 1) == 0: + ref_period = t + break + + # Split relative to the reference period, not hardcoded at 0 + pre_times = [t for t in rel_times if t < ref_period] + post_times = [t for t in rel_times if t > ref_period] effects = [] ses = [] @@ -731,10 +743,15 @@ def _extract_event_study_params( "or use balance_e to restrict to a balanced subset." ) - # Extract survey df + # Extract survey df. For replicate designs with undefined df + # (rank <= 1), use sentinel df=0 so _get_critical_value returns + # NaN, matching the safe_inference contract. 
df_survey = None if hasattr(results, "survey_metadata") and results.survey_metadata is not None: - df_survey = getattr(results.survey_metadata, "df_survey", None) + sm = results.survey_metadata + df_survey = getattr(sm, "df_survey", None) + if df_survey is None and getattr(sm, "replicate_method", None) is not None: + df_survey = 0 # undefined replicate df → NaN inference return ( beta_hat, From c5015c78e693fabae21b8bd05723d10935761daf Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 29 Mar 2026 16:57:38 -0400 Subject: [PATCH 19/19] Fix _estimate_max_pre_violation to use reference-aware pre_periods Use the pre_periods list from _extract_event_study_params() instead of hardcoded t < 0. With anticipation > 0, the reference period is at e = -1 - anticipation, so periods -anticipation through -1 are NOT pre-treatment and should not enter max_pre_violation. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/honest_did.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 291bfefa..826a6575 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -1443,11 +1443,12 @@ def _estimate_max_pre_violation(self, results: Any, pre_periods: List) -> float: if isinstance(results, CallawaySantAnnaResults): if results.event_study_effects: - # Filter out normalization constraints (n_groups=0, e.g. reference period) + # Use the reference-aware pre_periods from _extract_event_study_params + pre_set = set(pre_periods) if pre_periods else set() pre_effects = [ abs(results.event_study_effects[t]["effect"]) for t in results.event_study_effects - if t < 0 and results.event_study_effects[t].get("n_groups", 1) > 0 + if t in pre_set and results.event_study_effects[t].get("n_groups", 1) > 0 ] if pre_effects: return max(pre_effects)