diff --git a/ROADMAP.md b/ROADMAP.md index 07b9b9e..2c0f868 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -8,7 +8,7 @@ For past changes and release history, see [CHANGELOG.md](CHANGELOG.md). ## Current Status -diff-diff v2.6.0 is a **production-ready** DiD library with feature parity with R's `did` + `HonestDiD` + `synthdid` ecosystem for core DiD analysis: +diff-diff v2.7.5 is a **production-ready** DiD library with feature parity with R's `did` + `HonestDiD` + `synthdid` ecosystem for core DiD analysis, plus **unique survey support** — design-based variance estimation (Taylor linearization, replicate weights) integrated across all estimators. No R or Python package offers this combination: - **Core estimators**: Basic DiD, TWFE, MultiPeriod, Callaway-Sant'Anna, Sun-Abraham, Borusyak-Jaravel-Spiess Imputation, Synthetic DiD, Triple Difference (DDD), TROP, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing et al. 2024), Continuous DiD (Callaway, Goodman-Bacon & Sant'Anna 2024) - **Valid inference**: Robust SEs, cluster SEs, wild bootstrap, multiplier bootstrap, placebo-based variance @@ -16,11 +16,34 @@ diff-diff v2.6.0 is a **production-ready** DiD library with feature parity with - **Sensitivity analysis**: Honest DiD (Rambachan-Roth), Pre-trends power analysis (Roth 2022) - **Study design**: Power analysis tools - **Data utilities**: Real-world datasets (Card-Krueger, Castle Doctrine, Divorce Laws, MPDTA), DGP functions for all supported designs +- **Survey support**: Full `SurveyDesign` with strata, PSU, FPC, weight types, replicate weights (BRR/Fay/JK1/JKn), Taylor linearization, DEFF diagnostics, subpopulation analysis — integrated across all estimators (see [survey-roadmap.md](docs/survey-roadmap.md)) - **Performance**: Optional Rust backend for accelerated computation; faster than R at scale (see [CHANGELOG.md](CHANGELOG.md) for benchmarks) --- -## Near-Term Enhancements (v2.7) +## Near-Term Enhancements (v2.8) + +### Survey Phase 7: Completing the Survey Story 
+ +Close the remaining gaps for practitioners using major population surveys +(ACS, CPS, BRFSS, MEPS). See [survey-roadmap.md](docs/survey-roadmap.md) for +full details. + +- **CS Covariates + IPW/DR + Survey** *(High priority)*: Implement DRDID + nuisance IF corrections under survey weights. Currently the recommended DR + method raises `NotImplementedError` with covariates + survey. This is the + most commonly needed path in applied work (Medicaid expansion, minimum wage). +- **Repeated Cross-Sections** *(High priority)*: `panel=False` support for + CallawaySantAnna, enabling analysis of surveys that don't track units over + time (BRFSS, ACS annual, CPS monthly). Uses cross-sectional DRDID + (Sant'Anna & Zhao 2020, Section 4). +- **Survey-Aware DiD Tutorial** *(High priority)*: Jupyter notebook + demonstrating the full workflow with realistic survey data. diff-diff is + the only package (R or Python) with design-based variance for modern DiD + — this makes that capability discoverable. +- **HonestDiD + Survey Variance** *(Medium priority)*: Pass survey vcov + (TSL or replicate) into sensitivity analysis instead of cluster-robust vcov, + so sensitivity bounds respect the same variance structure as main estimates. ### Staggered Triple Difference (DDD) @@ -32,12 +55,6 @@ Extend the existing `TripleDifference` estimator to handle staggered adoption se **Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). *Working Paper*. R package: `triplediff`. -### Enhanced Visualization - -- Synthetic control weight visualization (bar chart of unit weights) -- Treatment adoption "staircase" plot for staggered designs -- Interactive plots with plotly backend option - --- ## Medium-Term Enhancements diff --git a/TODO.md b/TODO.md index 264e76a..780cc22 100644 --- a/TODO.md +++ b/TODO.md @@ -54,11 +54,17 @@ Deferred items from PR reviews that were not addressed before merge. 
|-------|----------|----|----------| | ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred — only triggers when sparse solver fails) | | Multi-absorb weighted demeaning needs iterative alternating projections for N > 1 absorbed FE with survey weights; unweighted multi-absorb also uses single-pass (pre-existing, exact only for balanced panels) | `estimators.py` | #218 | Medium | -| CallawaySantAnna survey + covariates + IPW/DR: DRDID panel nuisance-estimation IF corrections not implemented. Currently gated with NotImplementedError. Regression method with covariates works. | `staggered.py` | #233 | Medium — tracked in Survey Phase 7a | -| EfficientDiD `control_group="last_cohort"` trims at `last_g - anticipation` but REGISTRY says `t >= last_g`. With `anticipation=0` (default) these are identical. Needs design decision for `anticipation>0`. | `efficient_did.py` | #230 | Low | -| TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. | `prep_dgp.py`, `power.py` | #208 | Low | -| Survey design resolution/collapse patterns inconsistent across panel estimators — extract shared helpers for panel-to-unit collapse, post-filter re-resolution, metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low | -| TROP: `fit()` and `_fit_global()` share ~150 lines of near-identical data setup. Extract shared helpers to eliminate cross-file sync risk. | `trop.py`, `trop_global.py`, `trop_local.py` | — | Low | +| Replicate-weight survey df — **Resolved**. `df_survey = rank(replicate_weights) - 1` matching R's `survey::degf()`. For IF paths, `n_valid - 1` when dropped replicates reduce effective count. | `survey.py` | #238 | Resolved | +| CallawaySantAnna survey: strata/PSU/FPC — **Resolved**. Aggregated SEs (overall, event study, group) use `compute_survey_if_variance()`. 
Bootstrap uses PSU-level multiplier weights. | `staggered.py` | #237 | Resolved | +| CallawaySantAnna survey + covariates + IPW/DR — **Resolved**. DRDID panel nuisance IF corrections (PS + OR) implemented for both survey and non-survey DR paths (Phase 7a). IPW path unblocked. | `staggered.py` | #233 | Resolved | +| SyntheticDiD/TROP survey: strata/PSU/FPC — **Resolved**. Rao-Wu rescaled bootstrap implemented for both. TROP uses cross-classified pseudo-strata. Rust TROP remains pweight-only (Python fallback for full design). | `synthetic_did.py`, `trop.py` | — | Resolved | +| EfficientDiD hausman_pretest() clustered covariance stale `n_cl` — **Resolved**. Recompute `n_cl` and remap indices after `row_finite` filtering via `np.unique(return_inverse=True)`. | `efficient_did.py` | #230 | Resolved | +| EfficientDiD `control_group="last_cohort"` trims at `last_g - anticipation` but REGISTRY says `t >= last_g`. With `anticipation=0` (default) these are identical. With `anticipation>0`, code is arguably more conservative (excludes anticipation-contaminated periods). Either align REGISTRY with code or change code to `t < last_g` — needs design decision. | `efficient_did.py` | #230 | Low | +| TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. | `prep_dgp.py`, `power.py` | #208 | Low | +| ContinuousDiD event-study aggregation anticipation filter — **Resolved**. `_aggregate_event_study()` now filters `e < -anticipation` when `anticipation > 0`, matching CallawaySantAnna behavior. Bootstrap paths also filtered. 
| `continuous_did.py` | #226 | Resolved | +| Survey design resolution/collapse patterns are inconsistent across panel estimators — ContinuousDiD rebuilds unit-level design in SE code, EfficientDiD builds once in fit(), StackedDiD re-resolves on stacked data; extract shared helpers for panel-to-unit collapse, post-filter re-resolution, and metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low | +| Survey metadata formatting dedup — **Resolved**. Extracted `_format_survey_block()` helper in `results.py`, replaced 13 occurrences across 11 files. | `results.py` + 10 results files | — | Resolved | +| TROP: `fit()` and `_fit_global()` share ~150 lines of near-identical data setup (panel pivoting, absorbing-state validation, first-treatment detection, effective rank, NaN warnings). Both bootstrap methods also duplicate the stratified resampling loop. Extract shared helpers to eliminate cross-file sync risk. | `trop.py`, `trop_global.py`, `trop_local.py` | — | Low | #### Performance @@ -161,8 +167,8 @@ Features in R's `did` package that block porting additional tests: | Feature | R tests blocked | Priority | Status | |---------|----------------|----------|--------| -| Repeated cross-sections (`panel=FALSE`) | ~7 tests in test-att_gt.R + test-user_bug_fixes.R | High | Planned — Survey Phase 7b | -| Sampling/population weights | 7 tests incl. all JEL replication | Medium | Mostly resolved (Phases 1-6); CS IPW/DR + covariates + survey in Phase 7a | +| Repeated cross-sections (`panel=FALSE`) | ~7 tests in test-att_gt.R + test-user_bug_fixes.R | High | **Resolved** — Phase 7b: `panel=False` on CallawaySantAnna | +| Sampling/population weights | 7 tests incl. 
all JEL replication | Medium | **Resolved** (Phases 1-6 + 7a: CS IPW/DR + covariates + survey) | | Calendar time aggregation | 1 test in test-att_gt.R | Low | | --- diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index d2a5417..826a657 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -22,11 +22,12 @@ import numpy as np import pandas as pd -from scipy import optimize, stats +from scipy import optimize from diff_diff.results import ( MultiPeriodDiDResults, ) +from diff_diff.utils import _get_critical_value # ============================================================================= # Delta Restriction Classes @@ -193,6 +194,9 @@ class HonestDiDResults: original_results: Optional[Any] = field(default=None, repr=False) # Event study bounds (optional) event_study_bounds: Optional[Dict[Any, Dict[str, float]]] = field(default=None, repr=False) + # Survey design metadata (Phase 7d) + survey_metadata: Optional[Any] = field(default=None, repr=False) + df_survey: Optional[int] = field(default=None, repr=False) def __repr__(self) -> str: sig = "" if self.ci_lb <= 0 <= self.ci_ub else "*" @@ -534,7 +538,7 @@ def plot( def _extract_event_study_params( results: Union[MultiPeriodDiDResults, Any], -) -> Tuple[np.ndarray, np.ndarray, int, int, List[Any], List[Any]]: +) -> Tuple[np.ndarray, np.ndarray, int, int, List[Any], List[Any], Optional[int]]: """ Extract event study parameters from results objects. @@ -557,6 +561,8 @@ def _extract_event_study_params( Pre-period identifiers. post_periods : list Post-period identifiers. + df_survey : int or None + Survey degrees of freedom for t-distribution inference. """ if isinstance(results, MultiPeriodDiDResults): # Extract from MultiPeriodDiD @@ -606,7 +612,23 @@ def _extract_event_study_params( # Fallback: diagonal from SEs sigma = np.diag(np.array(ses) ** 2) - return beta_hat, sigma, num_pre_periods, num_post_periods, pre_periods, post_periods + # Extract survey df. 
Replicate designs with undefined df → sentinel 0. + df_survey = None + if hasattr(results, "survey_metadata") and results.survey_metadata is not None: + sm = results.survey_metadata + df_survey = getattr(sm, "df_survey", None) + if df_survey is None and getattr(sm, "replicate_method", None) is not None: + df_survey = 0 + + return ( + beta_hat, + sigma, + num_pre_periods, + num_post_periods, + pre_periods, + post_periods, + df_survey, + ) else: # Try CallawaySantAnnaResults @@ -621,6 +643,22 @@ def _extract_event_study_params( "event study effects." ) + # Warn if not using universal base period (R's HonestDiD requires it) + if getattr(results, "base_period", "universal") != "universal": + import warnings + + warnings.warn( + "HonestDiD sensitivity analysis on CallawaySantAnna results " + "requires base_period='universal' for valid interpretation. " + "With base_period='varying', pre-treatment coefficients use " + "consecutive comparisons (not a common reference period), " + "which changes the meaning of the parallel trends restriction. " + "Re-run with CallawaySantAnna(base_period='universal') for " + "methodologically valid HonestDiD bounds.", + UserWarning, + stacklevel=3, + ) + # Extract event study effects by relative time # Filter out normalization constraints (n_groups=0) and non-finite SEs event_effects = { @@ -630,9 +668,18 @@ def _extract_event_study_params( } rel_times = sorted(event_effects.keys()) - # Split into pre and post - pre_times = [t for t in rel_times if t < 0] - post_times = [t for t in rel_times if t >= 0] + # Infer the omitted reference period from the n_groups=0 entry + # (injected by _aggregate_event_study for universal base). + # Default to e=-1 if no reference found (varying base). 
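For readers outside the diff context, the reference-period inference that the comment above describes (and that the next hunk implements) can be sketched as a standalone snippet. The `event_study_effects` shape mirrors the results structure used in the source; the concrete numbers below are illustrative only, not library output.

```python
# Illustrative sketch (not library code): infer the omitted reference
# period from the injected n_groups == 0 entry, then split pre/post
# event times around it instead of hardcoding 0.
event_study_effects = {
    -3: {"effect": 0.10, "se": 0.05, "n_groups": 4},
    -2: {"effect": 0.00, "se": 0.00, "n_groups": 0},  # injected reference
    -1: {"effect": 0.20, "se": 0.06, "n_groups": 4},
    0:  {"effect": 1.10, "se": 0.07, "n_groups": 4},
}

ref_period = -1  # default when no injected reference exists (varying base)
for t, data in event_study_effects.items():
    if data.get("n_groups", 1) == 0:
        ref_period = t
        break

# Keep only real estimates, then split around the inferred reference
rel_times = sorted(
    t for t, d in event_study_effects.items() if d.get("n_groups", 1) > 0
)
pre_times = [t for t in rel_times if t < ref_period]
post_times = [t for t in rel_times if t > ref_period]
```

Splitting around the inferred reference rather than at 0 is what keeps the pre/post partition correct when the universal base period is not e = -1.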
+ ref_period = -1 + for t, data in results.event_study_effects.items(): + if data.get("n_groups", 1) == 0: + ref_period = t + break + + # Split relative to the reference period, not hardcoded at 0 + pre_times = [t for t in rel_times if t < ref_period] + post_times = [t for t in rel_times if t > ref_period] effects = [] ses = [] @@ -641,9 +688,80 @@ def _extract_event_study_params( ses.append(event_effects[t]["se"]) beta_hat = np.array(effects) - sigma = np.diag(np.array(ses) ** 2) - return (beta_hat, sigma, len(pre_times), len(post_times), pre_times, post_times) + # Use full event-study VCV if available (Phase 7d), + # otherwise fall back to diagonal from SEs + if hasattr(results, "event_study_vcov") and results.event_study_vcov is not None: + vcov = results.event_study_vcov + # VCV is indexed by the aggregated event times (stored in + # event_study_vcov_index), NOT by event_study_effects keys + # (which may include an injected reference period). + # Subset to match the surviving rel_times. + vcov_index = getattr(results, "event_study_vcov_index", None) + if vcov_index is not None and len(rel_times) < len(vcov_index): + idx = [vcov_index.index(t) for t in rel_times if t in vcov_index] + if len(idx) == len(rel_times): + sigma = vcov[np.ix_(idx, idx)] + else: + sigma = np.diag(np.array(ses) ** 2) + elif vcov.shape[0] == len(rel_times): + sigma = vcov + else: + sigma = np.diag(np.array(ses) ** 2) + else: + sigma = np.diag(np.array(ses) ** 2) + + # Validate the full event-time grid is consecutive around + # the omitted reference period (exactly one gap allowed). + # R's HonestDiD refuses non-consecutive grids. 
+ if pre_times and post_times: + # Expected: pre_times[-1] + 1 = reference, reference + 1 = post_times[0] + # So post_times[0] - pre_times[-1] should be exactly 2 + ref_gap = post_times[0] - pre_times[-1] + has_gap = ref_gap != 2 + elif pre_times: + has_gap = False # only pre, no ref gap to check + elif post_times: + has_gap = False # only post, no ref gap to check + else: + has_gap = False + # Also check within-block consecutiveness + for block in [pre_times, post_times]: + if len(block) >= 2: + for i in range(len(block) - 1): + if block[i + 1] - block[i] != 1: + has_gap = True + break + if has_gap: + raise ValueError( + "HonestDiD requires a consecutive event-time grid " + "around the omitted reference period. Retained " + f"pre-periods {pre_times} and post-periods " + f"{post_times} have gaps. This can happen when " + "some event-study horizons have non-finite SEs. " + "Ensure all event-study periods have valid estimates, " + "or use balance_e to restrict to a balanced subset." + ) + + # Extract survey df. For replicate designs with undefined df + # (rank <= 1), use sentinel df=0 so _get_critical_value returns + # NaN, matching the safe_inference contract. + df_survey = None + if hasattr(results, "survey_metadata") and results.survey_metadata is not None: + sm = results.survey_metadata + df_survey = getattr(sm, "df_survey", None) + if df_survey is None and getattr(sm, "replicate_method", None) is not None: + df_survey = 0 # undefined replicate df → NaN inference + + return ( + beta_hat, + sigma, + len(pre_times), + len(post_times), + pre_times, + post_times, + df_survey, + ) except ImportError: pass @@ -860,7 +978,13 @@ def _solve_bounds_lp( return lb, ub -def _compute_flci(lb: float, ub: float, se: float, alpha: float = 0.05) -> Tuple[float, float]: +def _compute_flci( + lb: float, + ub: float, + se: float, + alpha: float = 0.05, + df: Optional[int] = None, +) -> Tuple[float, float]: """ Compute Fixed Length Confidence Interval (FLCI). 
@@ -877,6 +1001,9 @@ def _compute_flci(lb: float, ub: float, se: float, alpha: float = 0.05) -> Tuple Standard error of the estimator. alpha : float Significance level. + df : int, optional + Degrees of freedom. If provided, uses t-distribution critical value + instead of normal (for survey designs with df = n_PSU - n_strata). Returns ------- @@ -895,7 +1022,7 @@ def _compute_flci(lb: float, ub: float, se: float, alpha: float = 0.05) -> Tuple if not (0 < alpha < 1): raise ValueError(f"alpha must be between 0 and 1, got alpha={alpha}") - z = stats.norm.ppf(1 - alpha / 2) + z = _get_critical_value(alpha, df) ci_lb = lb - z * se ci_ub = ub + z * se return ci_lb, ci_ub @@ -909,6 +1036,7 @@ def _compute_clf_ci( max_pre_violation: float, alpha: float = 0.05, n_draws: int = 1000, + df: Optional[int] = None, ) -> Tuple[float, float, float, float]: """ Compute Conditional Least Favorable (C-LF) confidence interval. @@ -931,6 +1059,8 @@ def _compute_clf_ci( Significance level. n_draws : int Number of Monte Carlo draws for conditional CI. + df : int, optional + Degrees of freedom for t-distribution critical value. 
Returns ------- @@ -956,7 +1086,7 @@ def _compute_clf_ci( ub = theta + bound # CI with estimation uncertainty - z = stats.norm.ppf(1 - alpha / 2) + z = _get_critical_value(alpha, df) ci_lb = lb - z * se ci_ub = ub + z * se @@ -1086,7 +1216,7 @@ def fit( M = M if M is not None else self.M # Extract event study parameters - (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = ( + (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods, df_survey) = ( _extract_event_study_params(results) ) @@ -1137,22 +1267,41 @@ def fit( # Compute bounds based on method if self.method == "smoothness": lb, ub, ci_lb, ci_ub = self._compute_smoothness_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M + beta_post, sigma_post, l_vec, num_pre, num_post, M, df=df_survey ) ci_method = "FLCI" elif self.method == "relative_magnitude": lb, ub, ci_lb, ci_ub = self._compute_rm_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results + beta_post, + sigma_post, + l_vec, + num_pre, + num_post, + M, + pre_periods, + results, + df=df_survey, ) ci_method = "C-LF" else: # combined lb, ub, ci_lb, ci_ub = self._compute_combined_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results + beta_post, + sigma_post, + l_vec, + num_pre, + num_post, + M, + pre_periods, + results, + df=df_survey, ) ci_method = "FLCI" + # Extract survey_metadata for storage on results + survey_metadata = getattr(results, "survey_metadata", None) + return HonestDiDResults( lb=lb, ub=ub, @@ -1165,6 +1314,8 @@ def fit( alpha=self.alpha, ci_method=ci_method, original_results=results, + survey_metadata=survey_metadata, + df_survey=df_survey, ) def _compute_smoothness_bounds( @@ -1175,6 +1326,7 @@ def _compute_smoothness_bounds( num_pre: int, num_post: int, M: float, + df: Optional[int] = None, ) -> Tuple[float, float, float, float]: """Compute bounds under smoothness restriction.""" # Construct constraints @@ -1185,7 +1337,7 @@ def 
_compute_smoothness_bounds( # Compute FLCI se = np.sqrt(l_vec @ sigma_post @ l_vec) - ci_lb, ci_ub = _compute_flci(lb, ub, se, self.alpha) + ci_lb, ci_ub = _compute_flci(lb, ub, se, self.alpha, df=df) return lb, ub, ci_lb, ci_ub @@ -1199,6 +1351,7 @@ def _compute_rm_bounds( Mbar: float, pre_periods: List, results: Any, + df: Optional[int] = None, ) -> Tuple[float, float, float, float]: """Compute bounds under relative magnitudes restriction.""" # Estimate max pre-period violation from pre-trends @@ -1209,12 +1362,18 @@ def _compute_rm_bounds( # No pre-period violations detected - use point estimate theta = np.dot(l_vec, beta_post) se = np.sqrt(l_vec @ sigma_post @ l_vec) - z = stats.norm.ppf(1 - self.alpha / 2) + z = _get_critical_value(self.alpha, df) return theta, theta, theta - z * se, theta + z * se # Compute bounds lb, ub, ci_lb, ci_ub = _compute_clf_ci( - beta_post, sigma_post, l_vec, Mbar, max_pre_violation, self.alpha + beta_post, + sigma_post, + l_vec, + Mbar, + max_pre_violation, + self.alpha, + df=df, ) return lb, ub, ci_lb, ci_ub @@ -1229,16 +1388,17 @@ def _compute_combined_bounds( M: float, pre_periods: List, results: Any, + df: Optional[int] = None, ) -> Tuple[float, float, float, float]: """Compute bounds under combined smoothness + RM restriction.""" # Get smoothness bounds lb_sd, ub_sd, _, _ = self._compute_smoothness_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M + beta_post, sigma_post, l_vec, num_pre, num_post, M, df=df ) # Get RM bounds (use M as Mbar for combined) lb_rm, ub_rm, _, _ = self._compute_rm_bounds( - beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results + beta_post, sigma_post, l_vec, num_pre, num_post, M, pre_periods, results, df=df ) # Combined bounds are intersection @@ -1252,7 +1412,7 @@ def _compute_combined_bounds( # Compute FLCI on combined bounds se = np.sqrt(l_vec @ sigma_post @ l_vec) - ci_lb, ci_ub = _compute_flci(lb, ub, se, self.alpha) + ci_lb, ci_ub = _compute_flci(lb, ub, se, 
self.alpha, df=df) return lb, ub, ci_lb, ci_ub @@ -1283,11 +1443,12 @@ def _estimate_max_pre_violation(self, results: Any, pre_periods: List) -> float: if isinstance(results, CallawaySantAnnaResults): if results.event_study_effects: - # Filter out normalization constraints (n_groups=0, e.g. reference period) + # Use the reference-aware pre_periods from _extract_event_study_params + pre_set = set(pre_periods) if pre_periods else set() pre_effects = [ abs(results.event_study_effects[t]["effect"]) for t in results.event_study_effects - if t < 0 and results.event_study_effects[t].get("n_groups", 1) > 0 + if t in pre_set and results.event_study_effects[t].get("n_groups", 1) > 0 ] if pre_effects: return max(pre_effects) diff --git a/diff_diff/prep_dgp.py b/diff_diff/prep_dgp.py index 2aab32c..5fd42a2 100644 --- a/diff_diff/prep_dgp.py +++ b/diff_diff/prep_dgp.py @@ -136,6 +136,7 @@ def generate_staggered_data( time_trend: float = 0.1, noise_sd: float = 0.5, seed: Optional[int] = None, + panel: bool = True, ) -> pd.DataFrame: """ Generate synthetic data for staggered adoption DiD analysis. @@ -170,6 +171,10 @@ def generate_staggered_data( Standard deviation of idiosyncratic noise. seed : int, optional Random seed for reproducibility. + panel : bool, default=True + If True (default), generate balanced panel data (same units across + all periods). If False, generate repeated cross-section data where + each period draws independent observations with globally unique IDs. Returns ------- @@ -219,6 +224,56 @@ def generate_staggered_data( n_never = int(n_units * never_treated_frac) n_treated = n_units - n_never + if not panel: + # --- Repeated cross-section mode --- + # Each period draws n_units independent observations with unique IDs. + # Cohorts are assigned from the same distribution as panel. 
+ records = [] + for period in range(n_periods): + # For each period, draw fresh cohort assignments + ft_period = np.zeros(n_units, dtype=int) + if n_treated > 0: + cohort_assignments = rng.choice(len(cohort_periods), size=n_treated) + ft_period[n_never:] = [cohort_periods[c] for c in cohort_assignments] + + # Unique unit IDs per period + for i in range(n_units): + uid = f"u{period}_{i}" + unit_first_treat = ft_period[i] + is_ever_treated = unit_first_treat > 0 + + is_treated = is_ever_treated and period >= unit_first_treat + + # Outcome: unit_fe_proxy (drawn fresh) + time trend + treatment + noise + unit_fe_proxy = rng.normal(0, unit_fe_sd) + y = 10.0 + unit_fe_proxy + time_trend * period + + effect = 0.0 + if is_treated: + time_since_treatment = period - unit_first_treat + if dynamic_effects: + effect = treatment_effect * (1 + effect_growth * time_since_treatment) + else: + effect = treatment_effect + y += effect + + y += rng.normal(0, noise_sd) + + records.append( + { + "unit": uid, + "period": period, + "outcome": y, + "first_treat": unit_first_treat, + "treated": int(is_treated), + "treat": int(is_ever_treated), + "true_effect": effect, + } + ) + + return pd.DataFrame(records) + + # --- Panel mode (default) --- # Assign treatment cohorts first_treat = np.zeros(n_units, dtype=int) if n_treated > 0: diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 718dc93..64dfbb6 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -92,6 +92,14 @@ def _linear_regression( return beta, residuals +def _safe_inv(A: np.ndarray) -> np.ndarray: + """Invert a square matrix with lstsq fallback for near-singular cases.""" + try: + return np.linalg.solve(A, np.eye(A.shape[0])) + except np.linalg.LinAlgError: + return np.linalg.lstsq(A, np.eye(A.shape[0]), rcond=None)[0] + + class CallawaySantAnna( CallawaySantAnnaBootstrapMixin, CallawaySantAnnaAggregationMixin, @@ -173,6 +181,15 @@ class CallawaySantAnna( Trimming bound for propensity scores. 
Scores are clipped to ``[pscore_trim, 1 - pscore_trim]`` before weight computation in IPW and DR estimation. Must be in ``(0, 0.5)``. + panel : bool, default=True + Whether the data is a balanced/unbalanced panel (units observed + across multiple time periods). Set to ``False`` for stationary + repeated cross-sections where each observation has a unique unit + ID and units do not repeat across periods. Requires that the + cross-sectional samples are drawn from the same population in + each period (stationarity). Uses cross-sectional DRDID + (Sant'Anna & Zhao 2020, Section 4) with per-observation influence + functions. Attributes ---------- @@ -262,6 +279,7 @@ def __init__( base_period: str = "varying", cband: bool = True, pscore_trim: float = 0.01, + panel: bool = True, ): import warnings @@ -324,6 +342,7 @@ def __init__( self.cband = cband self.pscore_trim = pscore_trim + self.panel = panel self.is_fitted_ = False self.results_: Optional[CallawaySantAnnaResults] = None @@ -501,6 +520,8 @@ def _precompute_structures( "covariate_by_period": covariate_by_period, "time_periods": time_periods, "is_balanced": is_balanced, + "is_panel": True, + "canonical_size": len(all_units), "survey_weights": survey_weights_arr, "resolved_survey": resolved_survey, "resolved_survey_unit": resolved_survey_unit, @@ -875,10 +896,12 @@ def _compute_all_att_gt_vectorized( if task_keys: df_survey_val = precomputed.get("df_survey") # Guard: replicate design with undefined df → NaN inference - if (df_survey_val is None - and precomputed.get("resolved_survey_unit") is not None - and hasattr(precomputed["resolved_survey_unit"], 'uses_replicate_variance') - and precomputed["resolved_survey_unit"].uses_replicate_variance): + if ( + df_survey_val is None + and precomputed.get("resolved_survey_unit") is not None + and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") + and precomputed["resolved_survey_unit"].uses_replicate_variance + ): df_survey_val = 0 t_stats, p_values, 
ci_lowers, ci_uppers = safe_inference_batch( np.array(atts), @@ -1217,10 +1240,13 @@ def _compute_all_att_gt_covariate_reg( # Use survey df for replicate designs (propagated from precomputed) _ipw_dr_df = precomputed.get("df_survey") if precomputed is not None else None # Guard: replicate design with undefined df → NaN inference - if (_ipw_dr_df is None and precomputed is not None - and precomputed.get("resolved_survey_unit") is not None - and hasattr(precomputed["resolved_survey_unit"], 'uses_replicate_variance') - and precomputed["resolved_survey_unit"].uses_replicate_variance): + if ( + _ipw_dr_df is None + and precomputed is not None + and precomputed.get("resolved_survey_unit") is not None + and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") + and precomputed["resolved_survey_unit"].uses_replicate_variance + ): _ipw_dr_df = 0 t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( np.array(atts), np.array(ses), alpha=self.alpha, df=_ipw_dr_df @@ -1250,7 +1276,9 @@ def fit( Parameters ---------- data : pd.DataFrame - Panel data with unit and time identifiers. + Panel data with unit and time identifiers. For repeated + cross-sections (``panel=False``), each observation should + have a unique unit ID — units do not repeat across periods. outcome : str Name of outcome variable column. unit : str @@ -1275,8 +1303,10 @@ def fit( survey_design : SurveyDesign, optional Survey design specification. Supports pweight with strata/PSU/FPC. Aggregated SEs (overall, event study, group) use design-based - variance via compute_survey_if_variance(). - Covariates + IPW/DR + survey raises NotImplementedError. + variance via compute_survey_if_variance(). All estimation methods + (reg, ipw, dr) support covariates + survey. For repeated + cross-sections (``panel=False``), survey weights are + per-observation (no unit-level collapse). 
Returns ------- @@ -1292,6 +1322,20 @@ def fit( if not (0 < self.pscore_trim < 0.5): raise ValueError(f"pscore_trim must be in (0, 0.5), got {self.pscore_trim}") + # Reset stale state from prior fit (prevents leaking event-study VCV) + self._event_study_vcov = None + + if not self.panel: + warnings.warn( + "panel=False uses repeated cross-section DRDID estimators " + "(Sant'Anna & Zhao 2020, Section 4) which assume stationary " + "cross-sectional sampling: the population distribution of " + "(Y, X, G) must be stable across periods. This assumption " + "is not data-checkable.", + UserWarning, + stacklevel=2, + ) + # Normalize empty covariates list to None if covariates is not None and len(covariates) == 0: covariates = None @@ -1308,7 +1352,8 @@ def fit( # Validate within-unit constancy for panel survey designs if resolved_survey is not None: - _validate_unit_constant_survey(data, unit, survey_design) + if self.panel: + _validate_unit_constant_survey(data, unit, survey_design) if resolved_survey.weight_type != "pweight": raise ValueError( f"CallawaySantAnna survey support requires weight_type='pweight', " @@ -1320,22 +1365,6 @@ def fit( # Bootstrap + survey is now supported via PSU-level multiplier bootstrap. - # Guard covariates + survey + IPW/DR (nuisance IF corrections not yet - # implemented to match DRDID panel formula) - if ( - resolved_survey is not None - and covariates is not None - and len(covariates) > 0 - and self.estimation_method in ("ipw", "dr") - ): - raise NotImplementedError( - f"Survey weights with covariates and estimation_method=" - f"'{self.estimation_method}' is not yet supported for " - f"CallawaySantAnna. The DRDID panel nuisance-estimation IF " - f"corrections are not yet implemented. Use estimation_method='reg' " - f"with covariates, or use any method without covariates." 
- ) - # Validate inputs required_cols = [outcome, unit, time, first_treat] if covariates: @@ -1365,13 +1394,19 @@ def fit( time_periods = sorted(df[time].unique()) treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0]) - # Get unique units - unit_info = ( - df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index() - ) - - n_treated_units = (unit_info[first_treat] > 0).sum() - n_control_units = (unit_info["_never_treated"]).sum() + if self.panel: + # Panel: count unique units + unit_info = ( + df.groupby(unit) + .agg({first_treat: "first", "_never_treated": "first"}) + .reset_index() + ) + n_treated_units = (unit_info[first_treat] > 0).sum() + n_control_units = (unit_info["_never_treated"]).sum() + else: + # RCS: count observations per cohort (no unit tracking) + n_treated_units = int((df[first_treat] > 0).sum()) + n_control_units = int(df["_never_treated"].sum()) if n_control_units == 0 and self.control_group == "never_treated": raise ValueError( @@ -1392,17 +1427,30 @@ def fit( # per-cell SEs use IF-based variance, not TSL. The user's cluster= # parameter is handled by the existing non-survey clustering path. 
# Pre-compute data structures for efficient ATT(g,t) computation - precomputed = self._precompute_structures( - df, - outcome, - unit, - time, - first_treat, - covariates, - time_periods, - treatment_groups, - resolved_survey=resolved_survey, - ) + if self.panel: + precomputed = self._precompute_structures( + df, + outcome, + unit, + time, + first_treat, + covariates, + time_periods, + treatment_groups, + resolved_survey=resolved_survey, + ) + else: + precomputed = self._precompute_structures_rc( + df, + outcome, + unit, + time, + first_treat, + covariates, + time_periods, + treatment_groups, + resolved_survey=resolved_survey, + ) # Recompute survey metadata from the unit-level resolved survey so # that n_psu and df_survey reflect the actual survey design (explicit @@ -1419,16 +1467,70 @@ def fit( # survey df computed in _precompute_structures for consistency. df_survey = precomputed.get("df_survey") # Guard: replicate design with undefined df (rank <= 1) → NaN inference - if (df_survey is None and resolved_survey is not None - and hasattr(resolved_survey, 'uses_replicate_variance') - and resolved_survey.uses_replicate_variance): + if ( + df_survey is None + and resolved_survey is not None + and hasattr(resolved_survey, "uses_replicate_variance") + and resolved_survey.uses_replicate_variance + ): df_survey = 0 # Compute ATT(g,t) for each group-time combination min_period = min(time_periods) has_survey = resolved_survey is not None - if covariates is None and self.estimation_method == "reg": + if not self.panel: + # --- Repeated cross-section path --- + # No vectorized/Cholesky fast paths (panel-only optimizations). + # Loop using _compute_att_gt_rc() for each (g,t). 
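The RCS loop below reuses the panel path's base-period rules. A standalone sketch of the rule applied per (g, t) cell — `base_period` is an illustrative helper under the stated conventions, not the library API:

```python
def base_period(g, t, anticipation=0, base="varying"):
    """Return the comparison period s for the cell ATT(g, t)."""
    if base == "universal":
        # One fixed pre-period for every cell of cohort g
        return g - 1 - anticipation
    # varying: pre-treatment cells compare adjacent periods,
    # post-treatment cells anchor at the last untreated period
    if t < g - anticipation:
        return t - 1
    return g - 1 - anticipation

# Cohort first treated in 2004, no anticipation:
print(base_period(2004, 2002))                    # pre-period cell -> 2001
print(base_period(2004, 2006))                    # post-period cell -> 2003
print(base_period(2004, 2006, base="universal"))  # -> 2003
```

With `base="universal"` the base period itself is excluded from the valid periods, which is why the loop below skips `g - 1 - anticipation` in that mode.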
+ group_time_effects = {} + influence_func_info = {} + + for g in treatment_groups: + if self.base_period == "universal": + universal_base = g - 1 - self.anticipation + valid_periods = [t for t in time_periods if t != universal_base] + else: + valid_periods = [ + t for t in time_periods if t >= g - self.anticipation or t > min_period + ] + + for t in valid_periods: + rc_result = self._compute_att_gt_rc( + precomputed, + g, + t, + covariates, + ) + att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = rc_result[:6] + agg_w = rc_result[6] if len(rc_result) > 6 else n_treat + + if att_gt is not None: + t_stat, p_val, ci = safe_inference( + att_gt, + se_gt, + alpha=self.alpha, + df=df_survey, + ) + + gte_entry = { + "effect": att_gt, + "se": se_gt, + "t_stat": t_stat, + "p_value": p_val, + "conf_int": ci, + "n_treated": n_treat, + "n_control": n_ctrl, + "agg_weight": agg_w, + } + if sw_sum is not None: + gte_entry["survey_weight_sum"] = sw_sum + group_time_effects[(g, t)] = gte_entry + + if inf_info is not None: + influence_func_info[(g, t)] = inf_info + + elif covariates is None and self.estimation_method == "reg": # Fast vectorized path for the common no-covariates regression case group_time_effects, influence_func_info = self._compute_all_att_gt_vectorized( precomputed, treatment_groups, time_periods, min_period @@ -1523,9 +1625,12 @@ def fit( if survey_metadata.df_survey != df_survey: survey_metadata.df_survey = df_survey # Guard: replicate design with undefined df (rank <= 1) → NaN inference - if (df_survey is None and resolved_survey is not None - and hasattr(resolved_survey, 'uses_replicate_variance') - and resolved_survey.uses_replicate_variance): + if ( + df_survey is None + and resolved_survey is not None + and hasattr(resolved_survey, "uses_replicate_variance") + and resolved_survey.uses_replicate_variance + ): df_survey = 0 overall_t, overall_p, overall_ci = safe_inference( overall_att, @@ -1679,6 +1784,15 @@ def fit( ) # Store results + # Retrieve event-study 
VCV from aggregation mixin (Phase 7d). + # Clear it when bootstrap overwrites event-study SEs to prevent + # HonestDiD from mixing analytical VCV with bootstrap SEs. + event_study_vcov = getattr(self, "_event_study_vcov", None) + event_study_vcov_index = getattr(self, "_event_study_vcov_index", None) + if bootstrap_results is not None and event_study_vcov is not None: + event_study_vcov = None + event_study_vcov_index = None + self.results_ = CallawaySantAnnaResults( group_time_effects=group_time_effects, overall_att=overall_att, @@ -1700,6 +1814,9 @@ def fit( cband_crit_value=cband_crit_value, pscore_trim=self.pscore_trim, survey_metadata=survey_metadata, + event_study_vcov=event_study_vcov, + event_study_vcov_index=event_study_vcov_index, + panel=self.panel, ) self.is_fitted_ = True @@ -1931,34 +2048,33 @@ def _ipw_estimation( X_all_int = np.column_stack([np.ones(n_t + n_c), X_all]) pscore_all = np.concatenate([pscore_treated, pscore_control]) - # Survey-weighted PS Hessian: sum(w_i * mu_i * (1-mu_i) * x_i * x_i') + # PS IF correction — matches R's std_ipw_did_panel convention: + # H = X'WX / n, asy_lin_rep = score @ solve(H) / n, M2 = colMeans + n_all_panel = n_t + n_c W_ps = pscore_all * (1 - pscore_all) if sw_all is not None: W_ps = W_ps * sw_all - H = X_all_int.T @ (W_ps[:, None] * X_all_int) - try: - H_inv = np.linalg.solve(H, np.eye(H.shape[0])) - except np.linalg.LinAlgError: - H_inv = np.linalg.lstsq(H, np.eye(H.shape[0]), rcond=None)[0] + H = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel + H_inv = _safe_inv(H) - # PS score: w_i * (D_i - pi_i) * X_i D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) score_ps = (D_all - pscore_all)[:, None] * X_all_int if sw_all is not None: score_ps = score_ps * sw_all[:, None] - asy_lin_rep_ps = score_ps @ H_inv # shape (n_t + n_c, p) + asy_lin_rep_ps = score_ps @ H_inv / n_all_panel - # M2: gradient of ATT w.r.t. 
PS parameters att_control_weighted = np.sum(weights_control_norm * control_change) - M2 = np.mean( - (weights_control_norm * (control_change - att_control_weighted))[:, None] - * X_all_int[n_t:], - axis=0, + # R colMeans: sum over control rows / n_all (not / n_c) + M2 = ( + np.sum( + (weights_control_norm * (control_change - att_control_weighted))[:, None] + * X_all_int[n_t:], + axis=0, + ) + / n_all_panel ) - # PS correction to influence function - inf_ps_correction = asy_lin_rep_ps @ M2 - inf_func = inf_func + inf_ps_correction + inf_func = inf_func + asy_lin_rep_ps @ M2 # SE from influence function variance var_psi = np.sum(inf_func**2) @@ -2178,13 +2294,69 @@ def _doubly_robust( att = att_treated_part + augmentation # Step 4: Influence function (survey-weighted DR) + # Start with plug-in IF, then add nuisance parameter corrections + # (Sant'Anna & Zhao 2020, Theorem 3.1) psi_treated = (sw_treated / sw_t_sum) * (treated_change - m_treated - att) psi_control = (weights_control / sw_t_sum) * (m_control - control_change) + inf_func = np.concatenate([psi_treated, psi_control]) - var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) + if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: + # --- PS IF correction (mirrors IPW L1929-1961) --- + # Accounts for propensity score estimation uncertainty + X_all_int = np.column_stack([np.ones(n_t + n_c), X_all]) + pscore_treated_clipped = np.clip( + pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim + ) + pscore_all = np.concatenate([pscore_treated_clipped, pscore_control]) + + # PS IF correction — R convention: H/n, asy_rep/n, colMeans + n_all_panel = n_t + n_c + W_ps = pscore_all * (1 - pscore_all) + if sw_all is not None: + W_ps = W_ps * sw_all + H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel + H_ps_inv = _safe_inv(H_ps) + + D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) + score_ps = (D_all - pscore_all)[:, None] * X_all_int + if sw_all is not None: + score_ps = 
score_ps * sw_all[:, None] + asy_lin_rep_ps = score_ps @ H_ps_inv / n_all_panel + + dr_resid_control = m_control - control_change + M2_dr = ( + np.sum( + ((weights_control / sw_t_sum) * dr_resid_control)[:, None] + * X_all_int[n_t:], + axis=0, + ) + / n_all_panel + ) + inf_func = inf_func + asy_lin_rep_ps @ M2_dr + + # --- OR IF correction --- + # Accounts for outcome regression estimation uncertainty + X_c_int = X_control_with_intercept + W_diag = sw_control if sw_control is not None else np.ones(n_c) + XtWX = X_c_int.T @ (W_diag[:, None] * X_c_int) + bread = _safe_inv(XtWX) + + # M1: dATT/dbeta — gradient of DR ATT w.r.t. OR parameters + X_t_int = X_treated_with_intercept + M1 = ( + -np.sum(sw_treated[:, None] * X_t_int, axis=0) + + np.sum(weights_control[:, None] * X_c_int, axis=0) + ) / sw_t_sum + + # OR asymptotic linear representation (control-only) + resid_c = control_change - m_control + asy_lin_rep_or = (W_diag * resid_c)[:, None] * X_c_int @ bread + # Apply to control portion only (treated contribute zero) + inf_func[n_t:] += asy_lin_rep_or @ M1 + + # Recompute SE from corrected IF + var_psi = np.sum(inf_func**2) se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 - - inf_func = np.concatenate([psi_treated, psi_control]) else: # IPW weights for control: p(X) / (1 - p(X)) weights_control = pscore_control / (1 - pscore_control) @@ -2194,14 +2366,56 @@ def _doubly_robust( augmentation = float(np.sum(weights_control * (m_control - control_change)) / n_t) att = att_treated_part + augmentation - # Step 4: Standard error using influence function + # Step 4: Influence function with nuisance IF corrections psi_treated = (treated_change - m_treated - att) / n_t psi_control = (weights_control * (m_control - control_change)) / n_t + inf_func = np.concatenate([psi_treated, psi_control]) - var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) - se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 + if X_treated is not None and X_control is not None and 
X_treated.shape[1] > 0: + # --- PS IF correction — R convention: H/n, asy_rep/n, colMeans --- + n_all_panel = n_t + n_c + X_all_int = np.column_stack([np.ones(n_all_panel), X_all]) + pscore_treated_clipped = np.clip( + pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim + ) + pscore_all = np.concatenate([pscore_treated_clipped, pscore_control]) - inf_func = np.concatenate([psi_treated, psi_control]) + W_ps = pscore_all * (1 - pscore_all) + H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel + H_ps_inv = _safe_inv(H_ps) + + D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) + score_ps = (D_all - pscore_all)[:, None] * X_all_int + asy_lin_rep_ps = score_ps @ H_ps_inv / n_all_panel + + dr_resid_control = m_control - control_change + M2_dr = ( + np.sum( + ((weights_control / n_t) * dr_resid_control)[:, None] * X_all_int[n_t:], + axis=0, + ) + / n_all_panel + ) + inf_func = inf_func + asy_lin_rep_ps @ M2_dr + + # --- OR IF correction --- + X_c_int = X_control_with_intercept + XtX = X_c_int.T @ X_c_int + bread = _safe_inv(XtX) + + X_t_int = X_treated_with_intercept + M1 = ( + -np.sum(X_t_int, axis=0) + + np.sum(weights_control[:, None] * X_c_int, axis=0) + ) / n_t + + resid_c = control_change - m_control + asy_lin_rep_or = resid_c[:, None] * X_c_int @ bread + inf_func[n_t:] += asy_lin_rep_or @ M1 + + # Recompute SE from corrected IF + var_psi = np.sum(inf_func**2) + se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 else: # Without covariates, DR simplifies to difference in means if sw_treated is not None: @@ -2235,6 +2449,1127 @@ def _doubly_robust( return att, se, inf_func + # ========================================================================= + # Repeated Cross-Section (RCS) methods + # ========================================================================= + + def _precompute_structures_rc( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + time_periods: List[Any], + 
treatment_groups: List[Any], + resolved_survey=None, + ) -> PrecomputedData: + """ + Pre-compute observation-level structures for repeated cross-section. + + Unlike the panel path, RCS does not pivot to wide format. Each + observation is treated independently (no within-unit differencing). + + Returns + ------- + PrecomputedData + Dictionary with pre-computed structures (observation-level). + """ + n_obs = len(df) + + # Observation-level arrays (no pivot) + obs_time = df[time].values + obs_outcome = df[outcome].values + unit_cohorts = df[first_treat].values + + # "all_units" key holds integer observation indices for backward + # compatibility with aggregation code + all_units = np.arange(n_obs) + + # Pre-compute cohort masks (boolean arrays, observation-level) + cohort_masks = {} + for g in treatment_groups: + cohort_masks[g] = unit_cohorts == g + + # Never-treated mask + never_treated_mask = (unit_cohorts == 0) | (unit_cohorts == np.inf) + + # Period-to-column mapping (identity for RCS — used for base period checks) + period_to_col = {t: i for i, t in enumerate(sorted(time_periods))} + + # Covariates (observation-level, not per-period) + obs_covariates = None + if covariates: + obs_covariates = df[covariates].values + + # Survey weights (already per-observation for RCS) + if resolved_survey is not None: + survey_weights_arr = resolved_survey.weights.copy() + else: + survey_weights_arr = None + + # For RCS, the resolved survey is already per-observation + resolved_survey_rc = resolved_survey + + # Fixed cohort masses: total observations per cohort across all periods. + # Used as aggregation weights so that n_treated is consistent with WIF. 
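The fixed cohort masses built here are consumed later by the aggregation layer. A toy sketch of how such masses weight per-cohort effects into an overall ATT (the ATT values and cohort labels are hypothetical; the real aggregation code lives in the mixin, not here):

```python
import numpy as np

unit_cohorts = np.array([2004, 2004, 2004, 2006, 2006, 0, 0, 0])  # 0 = never treated
treatment_groups = [2004, 2006]

# Mass = observation count per cohort, pooled over every period
masses = {g: int(np.sum(unit_cohorts == g)) for g in treatment_groups}
print(masses)  # {2004: 3, 2006: 2}

# Hypothetical per-cohort ATTs aggregated with mass weights
att_by_cohort = {2004: 1.0, 2006: 2.0}
w = np.array([masses[g] for g in treatment_groups], dtype=float)
atts = np.array([att_by_cohort[g] for g in treatment_groups])
overall = float(np.sum(w / w.sum() * atts))
print(overall)  # (3*1.0 + 2*2.0) / 5 = 1.4
```

Holding the masses fixed across periods keeps the aggregation weights consistent with the weight-influence-function (WIF) correction, as the comment above notes.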
+        rcs_cohort_masses = {}
+        for g in treatment_groups:
+            rcs_cohort_masses[g] = int(np.sum(unit_cohorts == g))
+
+        return {
+            "all_units": all_units,
+            "unit_to_idx": None,  # RCS: obs indices are positions
+            "unit_cohorts": unit_cohorts,
+            "canonical_size": n_obs,
+            "is_panel": False,
+            "obs_time": obs_time,
+            "obs_outcome": obs_outcome,
+            "obs_covariates": obs_covariates,
+            "cohort_masks": cohort_masks,
+            "never_treated_mask": never_treated_mask,
+            "time_periods": time_periods,
+            "period_to_col": period_to_col,
+            "is_balanced": False,
+            "survey_weights": survey_weights_arr,
+            "resolved_survey": resolved_survey,
+            "resolved_survey_unit": resolved_survey_rc,
+            "df_survey": (
+                resolved_survey_rc.df_survey
+                if resolved_survey_rc is not None and hasattr(resolved_survey_rc, "df_survey")
+                else None
+            ),
+            "rcs_cohort_masses": rcs_cohort_masses,
+        }
+
+    def _compute_att_gt_rc(
+        self,
+        precomputed: PrecomputedData,
+        g: Any,
+        t: Any,
+        covariates: Optional[List[str]],
+    ) -> Tuple[Any, ...]:
+        """
+        Compute ATT(g,t) for repeated cross-section data.
+
+        For RCS, the 2x2 DiD compares outcomes across two independent
+        cross-sections (periods t and base period s) rather than
+        within-unit changes. On success this returns a 7-tuple whose
+        trailing element is the cohort aggregation mass
+        (``rcs_cohort_masses[g]``); skipped cells return a 6-tuple
+        sentinel, which is why the caller checks ``len()`` before
+        reading the seventh element.

+ + Returns + ------- + att_gt : float or None + se_gt : float + n_treated : int (treated obs at period t) + n_control : int (control obs at period t) + inf_func_info : dict or None + survey_weight_sum : float or None + """ + cohort_masks = precomputed["cohort_masks"] + never_treated_mask = precomputed["never_treated_mask"] + unit_cohorts = precomputed["unit_cohorts"] + obs_time = precomputed["obs_time"] + obs_outcome = precomputed["obs_outcome"] + period_to_col = precomputed["period_to_col"] + + # Base period selection (same logic as panel) + if self.base_period == "universal": + base_period_val = g - 1 - self.anticipation + else: # varying + if t < g - self.anticipation: + base_period_val = t - 1 + else: + base_period_val = g - 1 - self.anticipation + + if base_period_val not in period_to_col or t not in period_to_col: + return None, 0.0, 0, 0, None, None + + # Treated mask = cohort g + treated_mask = cohort_masks[g] + + # Control mask (same logic as panel) + if self.control_group == "never_treated": + control_mask = never_treated_mask + else: # not_yet_treated + nyt_threshold = max(t, base_period_val) + self.anticipation + control_mask = never_treated_mask | ( + (unit_cohorts > nyt_threshold) & (unit_cohorts != g) + ) + + # Period masks + at_t = obs_time == t + at_s = obs_time == base_period_val + + # 4 groups of observations + treated_t = treated_mask & at_t + treated_s = treated_mask & at_s + control_t = control_mask & at_t + control_s = control_mask & at_s + + n_gt = int(np.sum(treated_t)) + n_gs = int(np.sum(treated_s)) + n_ct = int(np.sum(control_t)) + n_cs = int(np.sum(control_s)) + + if n_gt == 0 or n_ct == 0 or n_gs == 0 or n_cs == 0: + return None, 0.0, 0, 0, None, None + + # Extract outcomes for each group + y_gt = obs_outcome[treated_t] + y_gs = obs_outcome[treated_s] + y_ct = obs_outcome[control_t] + y_cs = obs_outcome[control_s] + + # Survey weights + survey_w = precomputed.get("survey_weights") + sw_gt = survey_w[treated_t] if survey_w is not None 
else None + sw_gs = survey_w[treated_s] if survey_w is not None else None + sw_ct = survey_w[control_t] if survey_w is not None else None + sw_cs = survey_w[control_s] if survey_w is not None else None + + # Guard against zero effective mass + if sw_gt is not None: + if np.sum(sw_gt) <= 0 or np.sum(sw_gs) <= 0: + return np.nan, np.nan, 0, 0, None, None + if np.sum(sw_ct) <= 0 or np.sum(sw_cs) <= 0: + return np.nan, np.nan, 0, 0, None, None + + # Get covariates if specified + obs_covariates = precomputed.get("obs_covariates") + has_covariates = covariates is not None and obs_covariates is not None + + if has_covariates: + X_gt = obs_covariates[treated_t] + X_gs = obs_covariates[treated_s] + X_ct = obs_covariates[control_t] + X_cs = obs_covariates[control_s] + + # Check for NaN in covariates + if ( + np.any(np.isnan(X_gt)) + or np.any(np.isnan(X_gs)) + or np.any(np.isnan(X_ct)) + or np.any(np.isnan(X_cs)) + ): + warnings.warn( + f"Missing values in covariates for group {g}, time {t} (RCS). 
" + "Falling back to unconditional estimation.", + UserWarning, + stacklevel=3, + ) + has_covariates = False + + if has_covariates and self.estimation_method == "reg": + att, se, inf_func_all, idx_all = self._outcome_regression_rc( + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=sw_gt, + sw_gs=sw_gs, + sw_ct=sw_ct, + sw_cs=sw_cs, + ) + elif has_covariates and self.estimation_method == "ipw": + att, se, inf_func_all, idx_all = self._ipw_estimation_rc( + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=sw_gt, + sw_gs=sw_gs, + sw_ct=sw_ct, + sw_cs=sw_cs, + ) + elif has_covariates and self.estimation_method == "dr": + att, se, inf_func_all, idx_all = self._doubly_robust_rc( + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=sw_gt, + sw_gs=sw_gs, + sw_ct=sw_ct, + sw_cs=sw_cs, + ) + else: + # No-covariates 2x2 DiD (all methods reduce to same) + att, se, inf_func_all, idx_all = self._rc_2x2_did( + y_gt, + y_gs, + y_ct, + y_cs, + treated_t, + treated_s, + control_t, + control_s, + sw_gt=sw_gt, + sw_gs=sw_gs, + sw_ct=sw_ct, + sw_cs=sw_cs, + ) + + # Build influence function info + # For RCS, treated_idx/control_idx combine obs from BOTH periods + treated_idx = np.concatenate([np.where(treated_t)[0], np.where(treated_s)[0]]) + control_idx = np.concatenate([np.where(control_t)[0], np.where(control_s)[0]]) + + n_treated_combined = len(treated_idx) + inf_func_info = { + "treated_idx": treated_idx, + "control_idx": control_idx, + "treated_units": treated_idx, # For RCS, obs indices = "units" + "control_units": control_idx, + "treated_inf": inf_func_all[:n_treated_combined], + "control_inf": inf_func_all[n_treated_combined:], + } + + sw_sum = float(np.sum(sw_gt)) if sw_gt is not None else None + # n_treated = per-cell treated count at period t (for display). + # cohort_mass = total treated across all periods (for aggregation weights). 
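The no-covariates branch dispatched above reduces to a four-sample 2x2 DiD. A self-contained numerical sketch of that estimator with its IF-based standard error (simulated data; the group sizes and means are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
# Four independent cross-sectional samples
y_gt = rng.normal(3.0, 1.0, 200)  # treated, period t
y_gs = rng.normal(1.0, 1.0, 180)  # treated, base period s
y_ct = rng.normal(1.5, 1.0, 220)  # control, period t
y_cs = rng.normal(1.0, 1.0, 210)  # control, base period s

att = (y_gt.mean() - y_ct.mean()) - (y_gs.mean() - y_cs.mean())

# Pre-scaled influence functions: each group's phi_i is demeaned and
# divided by its own n, with signs matching the DiD contrast
inf_all = np.concatenate([
    (y_gt - y_gt.mean()) / len(y_gt),
    -(y_gs - y_gs.mean()) / len(y_gs),
    -(y_ct - y_ct.mean()) / len(y_ct),
    (y_cs - y_cs.mean()) / len(y_cs),
])
se = float(np.sqrt(np.sum(inf_all**2)))
print(att, se)  # ATT near the true 1.5
```

Because the four samples are independent, the summed squared influence terms recover the usual variance of a difference of four means.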
+ cohort_mass = precomputed.get("rcs_cohort_masses", {}).get(g, n_gt) + return att, se, n_gt, n_ct, inf_func_info, sw_sum, cohort_mass + + def _rc_2x2_did( + self, + y_gt, + y_gs, + y_ct, + y_cs, + mask_gt, + mask_gs, + mask_ct, + mask_cs, + sw_gt=None, + sw_gs=None, + sw_ct=None, + sw_cs=None, + ): + """ + Compute the basic 2x2 DiD for RCS (no covariates). + + ATT = (mean(Y_treated_t) - mean(Y_control_t)) + - (mean(Y_treated_s) - mean(Y_control_s)) + + Returns (att, se, inf_func_concat, idx_concat) where inf_func_concat + has treated obs (both periods) first, then control obs (both periods). + """ + n_gt = len(y_gt) + n_gs = len(y_gs) + n_ct = len(y_ct) + n_cs = len(y_cs) + + if sw_gt is not None: + sw_gt_norm = sw_gt / np.sum(sw_gt) + sw_gs_norm = sw_gs / np.sum(sw_gs) + sw_ct_norm = sw_ct / np.sum(sw_ct) + sw_cs_norm = sw_cs / np.sum(sw_cs) + + mu_gt = float(np.sum(sw_gt_norm * y_gt)) + mu_gs = float(np.sum(sw_gs_norm * y_gs)) + mu_ct = float(np.sum(sw_ct_norm * y_ct)) + mu_cs = float(np.sum(sw_cs_norm * y_cs)) + + att = (mu_gt - mu_ct) - (mu_gs - mu_cs) + + # Influence function for 4 groups (survey-weighted) + inf_gt = sw_gt_norm * (y_gt - mu_gt) + inf_ct = -sw_ct_norm * (y_ct - mu_ct) + inf_gs = -sw_gs_norm * (y_gs - mu_gs) + inf_cs = sw_cs_norm * (y_cs - mu_cs) + else: + mu_gt = float(np.mean(y_gt)) + mu_gs = float(np.mean(y_gs)) + mu_ct = float(np.mean(y_ct)) + mu_cs = float(np.mean(y_cs)) + + att = (mu_gt - mu_ct) - (mu_gs - mu_cs) + + # Influence function for 4 groups + inf_gt = (y_gt - mu_gt) / n_gt + inf_ct = -(y_ct - mu_ct) / n_ct + inf_gs = -(y_gs - mu_gs) / n_gs + inf_cs = (y_cs - mu_cs) / n_cs + + # Concatenate: treated (t then s), control (t then s) + inf_treated = np.concatenate([inf_gt, inf_gs]) + inf_control = np.concatenate([inf_ct, inf_cs]) + inf_all = np.concatenate([inf_treated, inf_control]) + + # SE from influence function + se = float(np.sqrt(np.sum(inf_all**2))) + + idx_all = np.concatenate( + [ + np.where(mask_gt)[0], + 
np.where(mask_gs)[0], + np.where(mask_ct)[0], + np.where(mask_cs)[0], + ] + ) + + return att, se, inf_all, idx_all + + def _outcome_regression_rc( + self, + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=None, + sw_gs=None, + sw_ct=None, + sw_cs=None, + ): + """ + Cross-sectional outcome regression for ATT(g,t). + + Matches R DRDID::reg_did_rc (Sant'Anna & Zhao 2020, Eq 2.2). + + Two OLS models fit on controls (period t and base period s). + Predictions made for ALL treated (both periods). + OR correction pools ALL treated observations across both periods. + + IF convention + ------------- + Intermediate terms use R's unnormalized psi_i convention throughout. + R computes SE as ``sd(psi) / sqrt(n)``; with mean(psi) approx 0 this + equals ``sqrt(sum(psi^2)) / n``. At the end we convert to the + library's pre-scaled phi_i = psi_i / n convention where + ``se = sqrt(sum(phi^2))``, used by the aggregation/bootstrap layer. + + Returns (att, se, inf_func_concat, idx_concat). 
+ """ + n_gt = len(y_gt) + n_gs = len(y_gs) + n_ct = len(y_ct) + n_cs = len(y_cs) + n_all = n_gt + n_gs + n_ct + n_cs + + # --- Fit 2 OLS on control groups (period t and s separately) --- + beta_t, resid_ct = _linear_regression( + X_ct, + y_ct, + rank_deficient_action=self.rank_deficient_action, + weights=sw_ct, + ) + beta_t = np.where(np.isfinite(beta_t), beta_t, 0.0) + + beta_s, resid_cs = _linear_regression( + X_cs, + y_cs, + rank_deficient_action=self.rank_deficient_action, + weights=sw_cs, + ) + beta_s = np.where(np.isfinite(beta_s), beta_s, 0.0) + + # --- Predict counterfactual for ALL treated (both periods) --- + X_gt_int = np.column_stack([np.ones(n_gt), X_gt]) + X_gs_int = np.column_stack([np.ones(n_gs), X_gs]) + X_ct_int = np.column_stack([np.ones(n_ct), X_ct]) + X_cs_int = np.column_stack([np.ones(n_cs), X_cs]) + + # mu_hat_{0,t}(X) and mu_hat_{0,s}(X) for each treated obs + mu_post_gt = X_gt_int @ beta_t # treated-post predicted at post model + mu_pre_gt = X_gt_int @ beta_s # treated-post predicted at pre model + mu_post_gs = X_gs_int @ beta_t # treated-pre predicted at post model + mu_pre_gs = X_gs_int @ beta_s # treated-pre predicted at pre model + + # --- Group weights (R: w.treat.pre, w.treat.post, w.cont = w.D) --- + if sw_gt is not None: + w_treat_post = sw_gt # treated at t + w_treat_pre = sw_gs # treated at s + w_D_gt = sw_gt # ALL treated: t portion + w_D_gs = sw_gs # ALL treated: s portion + else: + w_treat_post = np.ones(n_gt) + w_treat_pre = np.ones(n_gs) + w_D_gt = np.ones(n_gt) + w_D_gs = np.ones(n_gs) + + sum_w_treat_post = np.sum(w_treat_post) + sum_w_treat_pre = np.sum(w_treat_pre) + sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs) # pool ALL treated + + # R: mean(w.treat.post), mean(w.treat.pre), mean(w.cont) + mean_w_treat_post = sum_w_treat_post / n_all + mean_w_treat_pre = sum_w_treat_pre / n_all + mean_w_D = sum_w_D / n_all + + # --- Treated means (period-specific Hajek means) --- + eta_treat_post = np.sum(w_treat_post * y_gt) / 
sum_w_treat_post + eta_treat_pre = np.sum(w_treat_pre * y_gs) / sum_w_treat_pre + + # --- OR correction: pools ALL treated --- + # R: out.y.post - out.y.pre for each treated obs + or_diff_gt = mu_post_gt - mu_pre_gt # treated at t + or_diff_gs = mu_post_gs - mu_pre_gs # treated at s + eta_cont = (np.sum(w_D_gt * or_diff_gt) + np.sum(w_D_gs * or_diff_gs)) / sum_w_D + + # --- Point estimate --- + att = float(eta_treat_post - eta_treat_pre - eta_cont) + + # ================================================================= + # Influence function in R's unnormalized psi convention + # (R: reg_did_rc.R, psi = n * phi) + # ================================================================= + + # --- Treated psi (R: eta.treat.post, eta.treat.pre) --- + # R: w.treat.post * (y - eta.treat.post) / mean(w.treat.post) + psi_treat_post = w_treat_post * (y_gt - eta_treat_post) / mean_w_treat_post + # R: w.treat.pre * (y - eta.treat.pre) / mean(w.treat.pre) + psi_treat_pre = w_treat_pre * (y_gs - eta_treat_pre) / mean_w_treat_pre + + # --- Control psi: leading term (R: inf.cont.1) --- + # R: w.cont * (or_diff - eta.cont) [before /mean(w.cont)] + psi_cont_1_gt = w_D_gt * (or_diff_gt - eta_cont) + psi_cont_1_gs = w_D_gs * (or_diff_gs - eta_cont) + + # --- Control psi: estimation effect (R: inf.cont.2) --- + # R: bread = solve(crossprod(X_ctrl, W * X_ctrl) / n) + # Here bread is (X'WX)^{-1} (without /n), so asy_lin_rep already + # absorbs the 1/n that R puts in its bread. We compensate by using + # R's colMeans (= sum/n_all) for M1, matching the product exactly. 
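The bread-scaling comment above rests on a linear-algebra identity: since inv(A/n) = n * inv(A), a 1/n factor can move between the bread and the gradient term without changing the product. A quick numerical check with arbitrary stand-in matrices (nothing here is library code):

```python
import numpy as np

rng = np.random.default_rng(1)
n, p = 50, 3
X = np.column_stack([np.ones(n), rng.normal(size=(n, p - 1))])
w = rng.uniform(0.5, 2.0, n)
score = rng.normal(size=(n, p))  # stand-in per-obs score rows
v = rng.normal(size=p)           # stand-in sum-form gradient

A = X.T @ (w[:, None] * X)       # unnormalized X'WX

# Normalized bread (R-style X'WX/n) paired with a colMeans-style gradient
r_style = score @ np.linalg.inv(A / n) @ (v / n)
# Unnormalized bread paired with the raw sum
ours = score @ np.linalg.inv(A) @ v

print(np.allclose(r_style, ours))  # True
```

In other words, the two conventions agree as long as exactly one 1/n appears somewhere in the product — which is the bookkeeping the comments in this hunk are tracking.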
+ W_ct = sw_ct if sw_ct is not None else np.ones(n_ct) + W_cs = sw_cs if sw_cs is not None else np.ones(n_cs) + bread_t = _safe_inv(X_ct_int.T @ (W_ct[:, None] * X_ct_int)) + bread_s = _safe_inv(X_cs_int.T @ (W_cs[:, None] * X_cs_int)) + + # R: M1 = colMeans(w.cont * out.x) = sum(w_D * X) / n_all + M1 = ( + np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) + ) / n_all + + # R: asy.lin.rep.ols (per-obs OLS score * bread) + asy_lin_rep_ols_t = (W_ct * resid_ct)[:, None] * X_ct_int @ bread_t + asy_lin_rep_ols_s = (W_cs * resid_cs)[:, None] * X_cs_int @ bread_s + + # R: inf.cont.2.post = asy.lin.rep.ols_t %*% M1 + psi_cont_2_ct = asy_lin_rep_ols_t @ M1 # (n_ct,) + # R: inf.cont.2.pre = asy.lin.rep.ols_s %*% M1 + psi_cont_2_cs = asy_lin_rep_ols_s @ M1 # (n_cs,) + + # --- Assemble per-group psi --- + # R: inf.treat = inf.treat.post - inf.treat.pre (across groups) + # R: inf.cont = (inf.cont.1 + inf.cont.2.post - inf.cont.2.pre) / mean(w.cont) + # R: att.inf.func = inf.treat - inf.cont + psi_gt = psi_treat_post - psi_cont_1_gt / mean_w_D + psi_gs = -psi_treat_pre - psi_cont_1_gs / mean_w_D + psi_ct = -psi_cont_2_ct / mean_w_D + psi_cs = psi_cont_2_cs / mean_w_D + + psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) + + # ================================================================= + # Convert to library convention: phi = psi / n_all + # se = sqrt(sum(phi^2)) == sqrt(sum(psi^2)) / n_all + # ================================================================= + inf_all = psi_all / n_all + se = float(np.sqrt(np.sum(inf_all**2))) + + idx_all = None # caller builds idx from masks + return att, se, inf_all, idx_all + + def _ipw_estimation_rc( + self, + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=None, + sw_gs=None, + sw_ct=None, + sw_cs=None, + ): + """ + Cross-sectional IPW estimation for ATT(g,t). + + Propensity score P(G=g | X) estimated on pooled treated+control + observations from both periods. 
Reweight controls in each period. + + IF convention + ------------- + Intermediate terms use R's unnormalized psi_i convention throughout + (R: ``ipw_did_rc``). R computes SE as ``sd(psi) / sqrt(n)``. + At the end we convert to the library's pre-scaled phi_i = psi_i / n + convention where ``se = sqrt(sum(phi^2))``, used by the + aggregation/bootstrap layer. + + Returns (att, se, inf_func_concat, idx_concat). + """ + n_gt = len(y_gt) + n_gs = len(y_gs) + n_ct = len(y_ct) + n_cs = len(y_cs) + n_all = n_gt + n_gs + n_ct + n_cs + + # Pool treated and control for propensity score + X_all = np.vstack([X_gt, X_gs, X_ct, X_cs]) + D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)]) + + sw_all = None + if sw_gt is not None: + sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs]) + + try: + beta_logistic, pscore = solve_logit( + X_all, + D_all, + rank_deficient_action=self.rank_deficient_action, + weights=sw_all, + ) + _check_propensity_diagnostics(pscore, self.pscore_trim) + except (np.linalg.LinAlgError, ValueError): + if self.rank_deficient_action == "error": + raise + warnings.warn( + "Propensity score estimation failed (RCS IPW). 
" + "Falling back to unconditional estimation.", + UserWarning, + stacklevel=4, + ) + p_treat = (n_gt + n_gs) / len(D_all) + pscore = np.full(len(D_all), p_treat) + + # Clip propensity scores + pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim) + + # Split propensity scores (treated ps not used -- only control IPW weights) + ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct] + ps_cs = pscore[n_gt + n_gs + n_ct :] + + # IPW weights for controls (R: w1.x = ps / (1 - ps)) + w_ct = ps_ct / (1 - ps_ct) + w_cs = ps_cs / (1 - ps_cs) + + if sw_gt is not None: + w_ct = sw_ct * w_ct + w_cs = sw_cs * w_cs + + # R: mean(w.treat.post), mean(w.treat.pre), mean(w.ipw.ct), mean(w.ipw.cs) + if sw_gt is not None: + sum_w_treat_post = np.sum(sw_gt) + sum_w_treat_pre = np.sum(sw_gs) + else: + sum_w_treat_post = float(n_gt) + sum_w_treat_pre = float(n_gs) + + mean_w_treat_post = sum_w_treat_post / n_all + mean_w_treat_pre = sum_w_treat_pre / n_all + + sum_w_ct = np.sum(w_ct) + sum_w_cs = np.sum(w_cs) + mean_w_ct = sum_w_ct / n_all + mean_w_cs = sum_w_cs / n_all + + # Hajek-normalized weights (R normalizes by sum for point estimate) + w_ct_norm = w_ct / sum_w_ct if sum_w_ct > 0 else w_ct + w_cs_norm = w_cs / sum_w_cs if sum_w_cs > 0 else w_cs + + if sw_gt is not None: + sw_gt_norm = sw_gt / sum_w_treat_post + sw_gs_norm = sw_gs / sum_w_treat_pre + mu_gt = float(np.sum(sw_gt_norm * y_gt)) + mu_gs = float(np.sum(sw_gs_norm * y_gs)) + else: + mu_gt = float(np.mean(y_gt)) + mu_gs = float(np.mean(y_gs)) + + mu_ct_ipw = float(np.sum(w_ct_norm * y_ct)) + mu_cs_ipw = float(np.sum(w_cs_norm * y_cs)) + + att = (mu_gt - mu_ct_ipw) - (mu_gs - mu_cs_ipw) + + # ================================================================= + # Influence function in R's unnormalized psi convention + # (R: ipw_did_rc.R, psi = n * phi) + # ================================================================= + + # --- Treated psi (R: eta.treat.post, eta.treat.pre) --- + # R: w.treat.post * (y - 
eta.treat.post) / mean(w.treat.post) + if sw_gt is not None: + psi_gt = sw_gt * (y_gt - mu_gt) / mean_w_treat_post + psi_gs = -sw_gs * (y_gs - mu_gs) / mean_w_treat_pre + else: + psi_gt = (y_gt - mu_gt) / mean_w_treat_post + psi_gs = -(y_gs - mu_gs) / mean_w_treat_pre + + # --- Control psi (R: eta.cont.post, eta.cont.pre) --- + # R: w.ipw * (y - eta.cont) / mean(w.ipw) + psi_ct = -w_ct * (y_ct - mu_ct_ipw) / mean_w_ct if mean_w_ct > 0 else np.zeros(n_ct) + psi_cs = w_cs * (y_cs - mu_cs_ipw) / mean_w_cs if mean_w_cs > 0 else np.zeros(n_cs) + + psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) + + # --- PS IF correction (R: asy.lin.rep.ps %*% M2) --- + X_all_int = np.column_stack([np.ones(n_all), X_all]) + + W_ps = pscore * (1 - pscore) + if sw_all is not None: + W_ps = W_ps * sw_all + H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) + H_ps_inv = _safe_inv(H_ps) + + score_ps = (D_all - pscore)[:, None] * X_all_int + if sw_all is not None: + score_ps = score_ps * sw_all[:, None] + asy_lin_rep_ps = score_ps @ H_ps_inv # (n_all, p+1) + + # ================================================================= + # Convert leading psi to library phi convention: phi = psi / n_all + # ================================================================= + inf_all = psi_all / n_all + + # --- PS nuisance correction (added in phi convention) --- + # R: M2 = colMeans(w_ipw * (y-mu) * X) across ALL n obs. + # colMeans = sum/n_all; treated rows contribute zero. + # asy_lin_rep_ps @ M2 matches R's psi_ps but with our bread + # convention (no 1/n), so the product is already in phi scale. + ipw_resid_ct = w_ct_norm * (y_ct - mu_ct_ipw) + ipw_resid_cs = w_cs_norm * (y_cs - mu_cs_ipw) + + ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct) + cs_slice = slice(n_gt + n_gs + n_ct, None) + + # R: M2 = colMeans(...) = sum(...) 
/ n + M2 = np.zeros(X_all_int.shape[1]) + M2 += np.sum(ipw_resid_ct[:, None] * X_all_int[ct_slice], axis=0) / n_all + M2 -= np.sum(ipw_resid_cs[:, None] * X_all_int[cs_slice], axis=0) / n_all + + # R: att.inf.func += asy.lin.rep.ps %*% M2 (phi-scale correction) + inf_all = inf_all + asy_lin_rep_ps @ M2 + + # ================================================================= + # SE from phi: se = sqrt(sum(phi^2)) + # Equivalent to R's sqrt(sum(psi^2)) / n when mean(psi) approx 0. + # ================================================================= + se = float(np.sqrt(np.sum(inf_all**2))) + + idx_all = None + return att, se, inf_all, idx_all + + def _doubly_robust_rc( + self, + y_gt, + y_gs, + y_ct, + y_cs, + X_gt, + X_gs, + X_ct, + X_cs, + sw_gt=None, + sw_gs=None, + sw_ct=None, + sw_cs=None, + ): + """ + Cross-sectional doubly robust estimation for ATT(g,t). + + Matches R DRDID::drdid_rc (Sant'Anna & Zhao 2020, Eq 3.1). + Locally efficient DR estimator with 4 OLS fits (control pre/post, + treated pre/post) plus propensity score. + + IF convention + ------------- + Intermediate terms use R's unnormalized psi_i convention throughout + (R: ``drdid_rc``). R computes SE as ``sd(psi) / sqrt(n)``. + At the end we convert to the library's pre-scaled phi_i = psi_i / n + convention where ``se = sqrt(sum(phi^2))``, used by the + aggregation/bootstrap layer. + + Returns (att, se, inf_func_concat, idx_concat). + """ + n_gt = len(y_gt) + n_gs = len(y_gs) + n_ct = len(y_ct) + n_cs = len(y_cs) + n_all = n_gt + n_gs + n_ct + n_cs + + # ===================================================================== + # 1. 
Outcome regression: 4 OLS fits + # ===================================================================== + # Control OLS: E[Y|X, D=0, T=t] and E[Y|X, D=0, T=s] + beta_ct, resid_ct = _linear_regression( + X_ct, + y_ct, + rank_deficient_action=self.rank_deficient_action, + weights=sw_ct, + ) + beta_ct = np.where(np.isfinite(beta_ct), beta_ct, 0.0) + + beta_cs, resid_cs = _linear_regression( + X_cs, + y_cs, + rank_deficient_action=self.rank_deficient_action, + weights=sw_cs, + ) + beta_cs = np.where(np.isfinite(beta_cs), beta_cs, 0.0) + + # Treated OLS: E[Y|X, D=1, T=t] and E[Y|X, D=1, T=s] + beta_gt, resid_gt = _linear_regression( + X_gt, + y_gt, + rank_deficient_action=self.rank_deficient_action, + weights=sw_gt, + ) + beta_gt = np.where(np.isfinite(beta_gt), beta_gt, 0.0) + + beta_gs, resid_gs = _linear_regression( + X_gs, + y_gs, + rank_deficient_action=self.rank_deficient_action, + weights=sw_gs, + ) + beta_gs = np.where(np.isfinite(beta_gs), beta_gs, 0.0) + + # Intercept-augmented design matrices + X_gt_int = np.column_stack([np.ones(n_gt), X_gt]) + X_gs_int = np.column_stack([np.ones(n_gs), X_gs]) + X_ct_int = np.column_stack([np.ones(n_ct), X_ct]) + X_cs_int = np.column_stack([np.ones(n_cs), X_cs]) + + # Control OR predictions for all groups + mu0_post_gt = X_gt_int @ beta_ct # mu_{0,1}(X) for treated-post + mu0_pre_gt = X_gt_int @ beta_cs # mu_{0,0}(X) for treated-post + mu0_post_gs = X_gs_int @ beta_ct # mu_{0,1}(X) for treated-pre + mu0_pre_gs = X_gs_int @ beta_cs # mu_{0,0}(X) for treated-pre + mu0_post_ct = X_ct_int @ beta_ct # mu_{0,1}(X) for control-post + mu0_pre_ct = X_ct_int @ beta_cs # mu_{0,0}(X) for control-post + mu0_post_cs = X_cs_int @ beta_ct # mu_{0,1}(X) for control-pre + mu0_pre_cs = X_cs_int @ beta_cs # mu_{0,0}(X) for control-pre + + # Treated OR predictions for all groups (for local efficiency adjustment) + mu1_post_gt = X_gt_int @ beta_gt # mu_{1,1}(X) for treated-post + mu1_pre_gt = X_gt_int @ beta_gs # mu_{1,0}(X) for treated-post + 
mu1_post_gs = X_gs_int @ beta_gt # mu_{1,1}(X) for treated-pre + mu1_pre_gs = X_gs_int @ beta_gs # mu_{1,0}(X) for treated-pre + + # mu_{0,Y}(T_i, X_i): control OR evaluated at own period + mu0Y_gt = mu0_post_gt # treated-post: use post control model + mu0Y_gs = mu0_pre_gs # treated-pre: use pre control model + mu0Y_ct = mu0_post_ct # control-post: use post control model + mu0Y_cs = mu0_pre_cs # control-pre: use pre control model + + # ===================================================================== + # 2. Propensity score + # ===================================================================== + X_all = np.vstack([X_gt, X_gs, X_ct, X_cs]) + D_all = np.concatenate([np.ones(n_gt + n_gs), np.zeros(n_ct + n_cs)]) + sw_all = None + if sw_gt is not None: + sw_all = np.concatenate([sw_gt, sw_gs, sw_ct, sw_cs]) + + try: + beta_logistic, pscore = solve_logit( + X_all, + D_all, + rank_deficient_action=self.rank_deficient_action, + weights=sw_all, + ) + _check_propensity_diagnostics(pscore, self.pscore_trim) + except (np.linalg.LinAlgError, ValueError): + if self.rank_deficient_action == "error": + raise + warnings.warn( + "Propensity score estimation failed (RCS DR). " + "Falling back to unconditional propensity.", + UserWarning, + stacklevel=4, + ) + p_treat = (n_gt + n_gs) / len(D_all) + pscore = np.full(len(D_all), p_treat) + + pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim) + + # Split propensity scores per group + ps_gt = pscore[:n_gt] + ps_gs = pscore[n_gt : n_gt + n_gs] + ps_ct = pscore[n_gt + n_gs : n_gt + n_gs + n_ct] + ps_cs = pscore[n_gt + n_gs + n_ct :] + + # ===================================================================== + # 3. 
Group weights and R-convention means + # ===================================================================== + if sw_gt is not None: + w_treat_post = sw_gt + w_treat_pre = sw_gs + w_D_gt = sw_gt + w_D_gs = sw_gs + else: + w_treat_post = np.ones(n_gt) + w_treat_pre = np.ones(n_gs) + w_D_gt = np.ones(n_gt) + w_D_gs = np.ones(n_gs) + + sum_w_treat_post = np.sum(w_treat_post) + sum_w_treat_pre = np.sum(w_treat_pre) + sum_w_D = np.sum(w_D_gt) + np.sum(w_D_gs) + + # R: mean(w) = sum(w) / n -- used in psi normalizers + mean_w_treat_post = sum_w_treat_post / n_all + mean_w_treat_pre = sum_w_treat_pre / n_all + mean_w_D = sum_w_D / n_all + + # IPW control weights: sw * ps/(1-ps) for controls + w_ipw_ct = ps_ct / (1 - ps_ct) + w_ipw_cs = ps_cs / (1 - ps_cs) + if sw_ct is not None: + w_ipw_ct = sw_ct * w_ipw_ct + w_ipw_cs = sw_cs * w_ipw_cs + + sum_w_ipw_ct = np.sum(w_ipw_ct) + sum_w_ipw_cs = np.sum(w_ipw_cs) + mean_w_ipw_ct = sum_w_ipw_ct / n_all + mean_w_ipw_cs = sum_w_ipw_cs / n_all + + # ===================================================================== + # 4. Point estimate: tau_1 (AIPW using control ORs) + # ===================================================================== + # Hajek-normalized means of (y - mu0Y) per group + eta_treat_post = np.sum(w_treat_post * (y_gt - mu0Y_gt)) / sum_w_treat_post + eta_treat_pre = np.sum(w_treat_pre * (y_gs - mu0Y_gs)) / sum_w_treat_pre + + eta_cont_post = ( + np.sum(w_ipw_ct * (y_ct - mu0Y_ct)) / sum_w_ipw_ct if sum_w_ipw_ct > 0 else 0.0 + ) + eta_cont_pre = ( + np.sum(w_ipw_cs * (y_cs - mu0Y_cs)) / sum_w_ipw_cs if sum_w_ipw_cs > 0 else 0.0 + ) + + tau_1 = (eta_treat_post - eta_cont_post) - (eta_treat_pre - eta_cont_pre) + + # ===================================================================== + # 5. 
Point estimate: local efficiency adjustment (tau_2) + # ===================================================================== + # Differences mu_{1,t}(X) - mu_{0,t}(X) for treated obs + or_diff_post_gt = mu1_post_gt - mu0_post_gt # at treated-post + or_diff_post_gs = mu1_post_gs - mu0_post_gs # at treated-pre + or_diff_pre_gt = mu1_pre_gt - mu0_pre_gt # at treated-post + or_diff_pre_gs = mu1_pre_gs - mu0_pre_gs # at treated-pre + + # att_d_post = mean(w_D * (mu1_post - mu0_post)) / mean(w_D) -- all treated + att_d_post = (np.sum(w_D_gt * or_diff_post_gt) + np.sum(w_D_gs * or_diff_post_gs)) / sum_w_D + # att_dt1_post -- treated-post only + att_dt1_post = np.sum(w_treat_post * or_diff_post_gt) / sum_w_treat_post + # att_d_pre -- all treated + att_d_pre = (np.sum(w_D_gt * or_diff_pre_gt) + np.sum(w_D_gs * or_diff_pre_gs)) / sum_w_D + # att_dt0_pre -- treated-pre only + att_dt0_pre = np.sum(w_treat_pre * or_diff_pre_gs) / sum_w_treat_pre + + tau_2 = (att_d_post - att_dt1_post) - (att_d_pre - att_dt0_pre) + + att = float(tau_1 + tau_2) + + # ===================================================================== + # 6. 
Influence function in R's unnormalized psi convention + # (R: drdid_rc.R, psi = n * phi) + # ===================================================================== + + # --- tau_1: treated psi (R: eta.treat.post / mean(w.treat.post)) --- + # R: w.treat.post * (y - mu0Y - eta.treat.post) / mean(w.treat.post) + psi_treat_post = w_treat_post * (y_gt - mu0Y_gt - eta_treat_post) / mean_w_treat_post + psi_treat_pre = w_treat_pre * (y_gs - mu0Y_gs - eta_treat_pre) / mean_w_treat_pre + + # --- tau_1: control psi (R: eta.cont.post / mean(w.ipw)) --- + # R: w.ipw * (y - mu0Y - eta.cont) / mean(w.ipw) + psi_cont_post_ct = ( + w_ipw_ct * (y_ct - mu0Y_ct - eta_cont_post) / mean_w_ipw_ct + if mean_w_ipw_ct > 0 + else np.zeros(n_ct) + ) + psi_cont_pre_cs = ( + w_ipw_cs * (y_cs - mu0Y_cs - eta_cont_pre) / mean_w_ipw_cs + if mean_w_ipw_cs > 0 + else np.zeros(n_cs) + ) + + # tau_1 psi per group + psi_gt_tau1 = psi_treat_post + psi_gs_tau1 = -psi_treat_pre + psi_ct_tau1 = -psi_cont_post_ct + psi_cs_tau1 = psi_cont_pre_cs + + # ===================================================================== + # 7. tau_2 leading terms (R: att.d.post, att.dt1.post, etc.) 
+ # ===================================================================== + # R: w.D * (or_diff - att.d.post) / mean(w.D) + psi_d_post_gt = w_D_gt * (or_diff_post_gt - att_d_post) / mean_w_D + psi_d_post_gs = w_D_gs * (or_diff_post_gs - att_d_post) / mean_w_D + # R: w.treat.post * (or_diff - att.dt1.post) / mean(w.treat.post) + psi_dt1_post = w_treat_post * (or_diff_post_gt - att_dt1_post) / mean_w_treat_post + # R: w.D * (or_diff_pre - att.d.pre) / mean(w.D) + psi_d_pre_gt = w_D_gt * (or_diff_pre_gt - att_d_pre) / mean_w_D + psi_d_pre_gs = w_D_gs * (or_diff_pre_gs - att_d_pre) / mean_w_D + # R: w.treat.pre * (or_diff_pre - att.dt0.pre) / mean(w.treat.pre) + psi_dt0_pre = w_treat_pre * (or_diff_pre_gs - att_dt0_pre) / mean_w_treat_pre + + # tau_2 psi per group (controls contribute zero) + psi_gt_tau2 = (psi_d_post_gt - psi_dt1_post) - psi_d_pre_gt + psi_gs_tau2 = psi_d_post_gs - (-psi_dt0_pre + psi_d_pre_gs) + + # ===================================================================== + # 8. Combined plug-in psi (before nuisance corrections) + # ===================================================================== + psi_gt = psi_gt_tau1 + psi_gt_tau2 + psi_gs = psi_gs_tau1 + psi_gs_tau2 + psi_ct = psi_ct_tau1 + psi_cs = psi_cs_tau1 + + psi_all = np.concatenate([psi_gt, psi_gs, psi_ct, psi_cs]) + + # ================================================================= + # Convert leading psi to library phi convention: phi = psi / n_all + # ================================================================= + inf_all = psi_all / n_all + + # ===================================================================== + # 9. PS nuisance correction (phi-scale) + # asy_lin_rep_ps @ M2 with M2 = colMeans(...) is already in phi + # scale because our bread omits R's 1/n factor while M2 absorbs it. 
+ # ===================================================================== + X_all_int = np.column_stack([np.ones(n_all), X_all]) + + W_ps = pscore * (1 - pscore) + if sw_all is not None: + W_ps = W_ps * sw_all + H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) + H_ps_inv = _safe_inv(H_ps) + + score_ps = (D_all - pscore)[:, None] * X_all_int + if sw_all is not None: + score_ps = score_ps * sw_all[:, None] + asy_lin_rep_ps = score_ps @ H_ps_inv # (n_all, p+1) + + # R: M2 = colMeans(w_ipw * dr_resid / mean(w_ipw) * X) + # colMeans = sum / n_all; treated rows contribute zero. + ct_slice = slice(n_gt + n_gs, n_gt + n_gs + n_ct) + cs_slice = slice(n_gt + n_gs + n_ct, None) + + dr_resid_ct = y_ct - mu0Y_ct - eta_cont_post + dr_resid_cs = y_cs - mu0Y_cs - eta_cont_pre + + M2 = np.zeros(X_all_int.shape[1]) + if sum_w_ipw_ct > 0: + M2 -= ( + np.sum( + ((w_ipw_ct * dr_resid_ct / sum_w_ipw_ct)[:, None] * X_all_int[ct_slice]), + axis=0, + ) + / n_all + ) + if sum_w_ipw_cs > 0: + M2 += ( + np.sum( + ((w_ipw_cs * dr_resid_cs / sum_w_ipw_cs)[:, None] * X_all_int[cs_slice]), + axis=0, + ) + / n_all + ) + + inf_all = inf_all + asy_lin_rep_ps @ M2 + + # ===================================================================== + # 10. 
Control OR nuisance corrections (phi-scale) + # ===================================================================== + W_ct_vals = sw_ct if sw_ct is not None else np.ones(n_ct) + W_cs_vals = sw_cs if sw_cs is not None else np.ones(n_cs) + bread_ct = _safe_inv(X_ct_int.T @ (W_ct_vals[:, None] * X_ct_int)) + bread_cs = _safe_inv(X_cs_int.T @ (W_cs_vals[:, None] * X_cs_int)) + + # R: asy.lin.rep.ols (per-obs OLS score * bread) + asy_lin_rep_ct = (W_ct_vals * resid_ct)[:, None] * X_ct_int @ bread_ct + asy_lin_rep_cs = (W_cs_vals * resid_cs)[:, None] * X_cs_int @ bread_cs + + # M1 for control-post model (beta_ct): gradient from tau_1 + tau_2 + # tau_1: -w_treat_post*X/sum_w_treat_post (eta_treat_post via mu0Y_gt) + # +w_ipw_ct*X/sum_w_ipw_ct (eta_cont_post via mu0Y_ct) + # tau_2: -w_D*X/sum_w_D (att_d_post via mu0_post at all treated) + # +w_treat_post*X/sum_w_treat_post (att_dt1_post via mu0_post) + M1_ct = np.zeros(X_all_int.shape[1]) + M1_ct -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post + if sum_w_ipw_ct > 0: + M1_ct += np.sum(w_ipw_ct[:, None] * X_ct_int, axis=0) / sum_w_ipw_ct + M1_ct -= ( + np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) + ) / sum_w_D + M1_ct += np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post + + # M1 for control-pre model (beta_cs) + M1_cs = np.zeros(X_all_int.shape[1]) + M1_cs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre + if sum_w_ipw_cs > 0: + M1_cs -= np.sum(w_ipw_cs[:, None] * X_cs_int, axis=0) / sum_w_ipw_cs + M1_cs += ( + np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) + ) / sum_w_D + M1_cs -= np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre + + inf_all[n_gt + n_gs : n_gt + n_gs + n_ct] += asy_lin_rep_ct @ M1_ct + inf_all[n_gt + n_gs + n_ct :] += asy_lin_rep_cs @ M1_cs + + # ===================================================================== + # 11. 
Treated OR nuisance corrections (phi-scale) + # ===================================================================== + W_gt_vals = sw_gt if sw_gt is not None else np.ones(n_gt) + W_gs_vals = sw_gs if sw_gs is not None else np.ones(n_gs) + bread_gt = _safe_inv(X_gt_int.T @ (W_gt_vals[:, None] * X_gt_int)) + bread_gs = _safe_inv(X_gs_int.T @ (W_gs_vals[:, None] * X_gs_int)) + + asy_lin_rep_gt = (W_gt_vals * resid_gt)[:, None] * X_gt_int @ bread_gt + asy_lin_rep_gs = (W_gs_vals * resid_gs)[:, None] * X_gs_int @ bread_gs + + # M1 for treated-post model (beta_gt): mu_{1,1}(X) + # From att_d_post: +w_D*X/sum_w_D (all treated) + # From att_dt1_post: -w_treat_post*X/sum_w_treat_post (treated-post) + M1_gt = np.zeros(X_all_int.shape[1]) + M1_gt += ( + np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) + ) / sum_w_D + M1_gt -= np.sum(w_treat_post[:, None] * X_gt_int, axis=0) / sum_w_treat_post + + # M1 for treated-pre model (beta_gs): mu_{1,0}(X) + # From att_d_pre: -w_D*X/sum_w_D + # From att_dt0_pre: +w_treat_pre*X/sum_w_treat_pre + M1_gs = np.zeros(X_all_int.shape[1]) + M1_gs -= ( + np.sum(w_D_gt[:, None] * X_gt_int, axis=0) + np.sum(w_D_gs[:, None] * X_gs_int, axis=0) + ) / sum_w_D + M1_gs += np.sum(w_treat_pre[:, None] * X_gs_int, axis=0) / sum_w_treat_pre + + inf_all[:n_gt] += asy_lin_rep_gt @ M1_gt + inf_all[n_gt : n_gt + n_gs] += asy_lin_rep_gs @ M1_gs + + # ================================================================= + # SE from phi: se = sqrt(sum(phi^2)) + # Equivalent to R's sqrt(sum(psi^2)) / n when mean(psi) approx 0. 
+ # ================================================================= + se = float(np.sqrt(np.sum(inf_all**2))) + + idx_all = None + return att, se, inf_all, idx_all + def get_params(self) -> Dict[str, Any]: """Get estimator parameters (sklearn-compatible).""" return { @@ -2252,6 +3587,7 @@ def get_params(self) -> Dict[str, Any]: "base_period": self.base_period, "cband": self.cband, "pscore_trim": self.pscore_trim, + "panel": self.panel, } def set_params(self, **params) -> "CallawaySantAnna": diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 3b75bd8..312e50a 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -5,7 +5,7 @@ group-time average treatment effects into summary measures. """ -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union import numpy as np import pandas as pd @@ -79,11 +79,13 @@ def _aggregate_simple( if t < g - self.anticipation: continue effects.append(data["effect"]) - # Use fixed cohort-level survey weight sum for aggregation + # Use fixed cohort-level survey weight sum for aggregation. + # For RCS, data["agg_weight"] holds the fixed cohort mass; + # for panel, fallback to data["n_treated"]. if survey_cohort_weights is not None and g in survey_cohort_weights: weights_list.append(survey_cohort_weights[g]) else: - weights_list.append(data["n_treated"]) + weights_list.append(data.get("agg_weight", data["n_treated"])) gt_pairs.append((g, t)) groups_for_gt.append(g) @@ -250,8 +252,15 @@ def _compute_combined_influence_function( return np.zeros(n_global_units), None return np.zeros(0), None + # Detect RCS mode via explicit flag. In RCS, obs indices ARE array positions. 
+ _is_rcs = precomputed is not None and not precomputed.get("is_panel", True) + # Build unit index mapping (local or global) - if global_unit_to_idx is not None and n_global_units is not None: + if _is_rcs and n_global_units is not None: + # RCS: direct indexing — obs indices are the array positions + n_units = n_global_units + all_units = None + elif global_unit_to_idx is not None and n_global_units is not None: n_units = n_global_units all_units = None # caller already has the unit list else: @@ -287,6 +296,12 @@ def _compute_combined_influence_function( mask_g = precomputed_cohorts == g group_sizes[g] = float(np.sum(survey_w[mask_g])) total_weight = float(np.sum(survey_w)) + elif _is_rcs: + # RCS without survey: count observations per cohort + precomputed_cohorts = precomputed["unit_cohorts"] + for g in unique_groups: + group_sizes[g] = int(np.sum(precomputed_cohorts == g)) + total_weight = float(n_units) else: for g in unique_groups: treated_in_g = df[df["first_treat"] == g][unit].nunique() @@ -325,21 +340,31 @@ def _compute_combined_influence_function( # Build unit-group array: normalize iterator to (idx, uid) pairs unit_groups_array = np.full(n_units, -1, dtype=np.float64) - idx_uid_pairs = ( - [(idx, uid) for uid, idx in global_unit_to_idx.items()] - if global_unit_to_idx is not None - else list(enumerate(all_units)) - ) - if precomputed is not None: + if _is_rcs: + # RCS: direct vectorized assignment — obs indices are positions precomputed_cohorts = precomputed["unit_cohorts"] - precomputed_unit_to_idx = precomputed["unit_to_idx"] - for idx, uid in idx_uid_pairs: - if uid in precomputed_unit_to_idx: - cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]] - if cohort in unique_groups_set: - unit_groups_array[idx] = cohort + for g in unique_groups: + mask_g = precomputed_cohorts == g + unit_groups_array[mask_g] = g + elif global_unit_to_idx is not None: + idx_uid_pairs = [(idx, uid) for uid, idx in global_unit_to_idx.items()] + + if precomputed is not 
None: + precomputed_cohorts = precomputed["unit_cohorts"] + precomputed_unit_to_idx = precomputed["unit_to_idx"] + for idx, uid in idx_uid_pairs: + if uid in precomputed_unit_to_idx: + cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]] + if cohort in unique_groups_set: + unit_groups_array[idx] = cohort + else: + for idx, uid in idx_uid_pairs: + unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0] + if unit_first_treat in unique_groups_set: + unit_groups_array[idx] = unit_first_treat else: + idx_uid_pairs = list(enumerate(all_units)) for idx, uid in idx_uid_pairs: unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0] if unit_first_treat in unique_groups_set: @@ -357,10 +382,14 @@ def _compute_combined_influence_function( # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT s_i * (1{G_i=g} - pg_k). # The pg subtraction is NOT weighted by s_i because pg is already # the population-level expected value of w_i * 1{G_i=g}. - if global_unit_to_idx is not None and precomputed is not None: + if _is_rcs and precomputed is not None: + # RCS: survey weights are already per-observation, direct indexing + unit_sw = survey_w + elif global_unit_to_idx is not None and precomputed is not None: unit_sw = np.zeros(n_units) precomputed_unit_to_idx_local = precomputed["unit_to_idx"] - for idx, uid in idx_uid_pairs: + idx_uid_pairs_sw = [(idx, uid) for uid, idx in global_unit_to_idx.items()] + for idx, uid in idx_uid_pairs_sw: if uid in precomputed_unit_to_idx_local: pc_idx = precomputed_unit_to_idx_local[uid] unit_sw[idx] = survey_w[pc_idx] @@ -420,7 +449,8 @@ def _compute_aggregated_se_with_wif( df: pd.DataFrame, unit: str, precomputed: Optional["PrecomputedData"] = None, - ) -> float: + return_psi: bool = False, + ) -> "Union[float, Tuple[float, np.ndarray]]": """ Compute SE with weight influence function (wif) adjustment. 
@@ -443,8 +473,10 @@ def _compute_aggregated_se_with_wif( global_unit_to_idx = None n_global_units = None if precomputed is not None: - global_unit_to_idx = precomputed["unit_to_idx"] - n_global_units = len(precomputed["all_units"]) + global_unit_to_idx = precomputed["unit_to_idx"] # None for RCS + n_global_units = precomputed.get( + "canonical_size", len(precomputed.get("all_units", [])) + ) elif df is not None and unit is not None: n_global_units = df[unit].nunique() @@ -462,18 +494,22 @@ def _compute_aggregated_se_with_wif( ) if len(psi_total) == 0: - return 0.0 + return (0.0, psi_total) if return_psi else 0.0 # Check for NaN propagation from non-finite WIF if not np.all(np.isfinite(psi_total)): - return np.nan + return (np.nan, psi_total) if return_psi else np.nan # Use design-based variance when full survey design is available # Use unit-level resolved survey (panel IF is indexed by unit, not obs) resolved_survey = ( precomputed.get("resolved_survey_unit") if precomputed is not None else None ) - if resolved_survey is not None and hasattr(resolved_survey, "uses_replicate_variance") and resolved_survey.uses_replicate_variance: + if ( + resolved_survey is not None + and hasattr(resolved_survey, "uses_replicate_variance") + and resolved_survey.uses_replicate_variance + ): from diff_diff.survey import compute_replicate_if_variance variance, n_valid_rep = compute_replicate_if_variance(psi_total, resolved_survey) @@ -481,8 +517,10 @@ def _compute_aggregated_se_with_wif( if precomputed is not None and n_valid_rep < resolved_survey.n_replicates: precomputed["df_survey"] = n_valid_rep - 1 if n_valid_rep > 1 else None if np.isnan(variance): - return np.nan - return np.sqrt(max(variance, 0.0)) + se = np.nan + else: + se = np.sqrt(max(variance, 0.0)) + return (se, psi_total) if return_psi else se if resolved_survey is not None and ( resolved_survey.strata is not None @@ -493,11 +531,14 @@ def _compute_aggregated_se_with_wif( variance = 
compute_survey_if_variance(psi_total, resolved_survey) if np.isnan(variance): - return np.nan - return np.sqrt(max(variance, 0.0)) + se = np.nan + else: + se = np.sqrt(max(variance, 0.0)) + return (se, psi_total) if return_psi else se variance = np.sum(psi_total**2) - return np.sqrt(variance) + se = np.sqrt(variance) + return (se, psi_total) if return_psi else se def _aggregate_event_study( self, @@ -536,10 +577,12 @@ def _aggregate_event_study( e = t - g # Relative time if e not in effects_by_e: effects_by_e[e] = [] + # For RCS, data["agg_weight"] holds the fixed cohort mass; + # for panel, fallback to data["n_treated"]. w = ( survey_cohort_weights[g] if survey_cohort_weights is not None and g in survey_cohort_weights - else data["n_treated"] + else data.get("agg_weight", data["n_treated"]) ) effects_by_e[e].append( ( @@ -567,7 +610,7 @@ def _aggregate_event_study( w = ( survey_cohort_weights[g] if survey_cohort_weights is not None and g in survey_cohort_weights - else data["n_treated"] + else data.get("agg_weight", data["n_treated"]) ) balanced_effects[e].append( ( @@ -583,6 +626,8 @@ def _aggregate_event_study( agg_effects_list = [] agg_ses_list = [] agg_n_groups = [] + _psi_vectors = [] # Per-event-time combined IF vectors for VCV + _psi_event_times = [] # Event times that contributed a psi column for e, effect_list in sorted_periods: gt_pairs = [x[0] for x in effect_list] effs = np.array([x[1] for x in effect_list]) @@ -605,23 +650,37 @@ def _aggregate_event_study( # Compute SE with WIF adjustment (matching R's did::aggte) groups_for_gt = np.array([g for (g, t) in gt_pairs]) - agg_se = self._compute_aggregated_se_with_wif( - gt_pairs, weights, effs, groups_for_gt, influence_func_info, df, unit, precomputed + result = self._compute_aggregated_se_with_wif( + gt_pairs, + weights, + effs, + groups_for_gt, + influence_func_info, + df, + unit, + precomputed, + return_psi=True, ) + agg_se, psi_e = result agg_effects_list.append(agg_effect) agg_ses_list.append(agg_se) 
agg_n_groups.append(len(effect_list)) + _psi_vectors.append(psi_e) + _psi_event_times.append(e) # Batch inference for all relative periods if not agg_effects_list: return {} df_survey_val = precomputed.get("df_survey") if precomputed is not None else None # Guard: replicate design with undefined df → NaN inference - if (df_survey_val is None and precomputed is not None - and precomputed.get("resolved_survey_unit") is not None - and hasattr(precomputed["resolved_survey_unit"], 'uses_replicate_variance') - and precomputed["resolved_survey_unit"].uses_replicate_variance): + if ( + df_survey_val is None + and precomputed is not None + and precomputed.get("resolved_survey_unit") is not None + and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") + and precomputed["resolved_survey_unit"].uses_replicate_variance + ): df_survey_val = 0 t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( np.array(agg_effects_list), @@ -654,6 +713,53 @@ def _aggregate_event_study( "n_groups": 0, } + # Compute full event-study VCV from per-event-time IF vectors (Phase 7d) + # This enables HonestDiD to use the full covariance structure + event_study_vcov = None + valid_psi = [p for p in _psi_vectors if len(p) > 0] + if valid_psi: + try: + Psi = np.column_stack(valid_psi) # (n_units, n_event_times) + resolved_survey = ( + precomputed.get("resolved_survey_unit") if precomputed is not None else None + ) + if ( + resolved_survey is not None + and not ( + hasattr(resolved_survey, "uses_replicate_variance") + and resolved_survey.uses_replicate_variance + ) + and ( + resolved_survey.strata is not None + or resolved_survey.psu is not None + or resolved_survey.fpc is not None + ) + ): + from diff_diff.survey import _compute_stratified_psu_meat + + meat, _, _ = _compute_stratified_psu_meat(Psi, resolved_survey) + event_study_vcov = meat + elif ( + resolved_survey is not None + and hasattr(resolved_survey, "uses_replicate_variance") + and 
resolved_survey.uses_replicate_variance + ): + # Replicate-weight: fall back to None (diagonal in HonestDiD) + # until multivariate replicate VCV is implemented + event_study_vcov = None + else: + # No survey: simple sum-of-outer-products + event_study_vcov = Psi.T @ Psi + except (ValueError, np.linalg.LinAlgError): + pass # Fall back to diagonal (None) + + # Store the event-time index that matches VCV columns (for subsetting + # in HonestDiD when some event times are filtered out) + self._event_study_vcov_index = _psi_event_times if event_study_vcov is not None else None + + # Attach VCV to self for CallawaySantAnna to pick up + self._event_study_vcov = event_study_vcov + return event_study_effects def _aggregate_by_group( @@ -704,8 +810,7 @@ def _aggregate_by_group( # Use WIF-adjusted SE (with survey design support) groups_for_gt = np.array([gg for (gg, t) in gt_pairs]) agg_se = self._compute_aggregated_se_with_wif( - gt_pairs, weights, effs, groups_for_gt, - influence_func_info, df, unit, precomputed + gt_pairs, weights, effs, groups_for_gt, influence_func_info, df, unit, precomputed ) group_data_list.append((g, agg_effect, agg_se, len(g_effects))) @@ -717,10 +822,13 @@ def _aggregate_by_group( agg_ses = np.array([x[2] for x in group_data_list]) df_survey_val = precomputed.get("df_survey") if precomputed is not None else None # Guard: replicate design with undefined df → NaN inference - if (df_survey_val is None and precomputed is not None - and precomputed.get("resolved_survey_unit") is not None - and hasattr(precomputed["resolved_survey_unit"], 'uses_replicate_variance') - and precomputed["resolved_survey_unit"].uses_replicate_variance): + if ( + df_survey_val is None + and precomputed is not None + and precomputed.get("resolved_survey_unit") is not None + and hasattr(precomputed["resolved_survey_unit"], "uses_replicate_variance") + and precomputed["resolved_survey_unit"].uses_replicate_variance + ): df_survey_val = 0 t_stats, p_values, ci_lowers, ci_uppers = 
safe_inference_batch( agg_effects, diff --git a/diff_diff/staggered_bootstrap.py b/diff_diff/staggered_bootstrap.py index fc54cc4..2b9095f 100644 --- a/diff_diff/staggered_bootstrap.py +++ b/diff_diff/staggered_bootstrap.py @@ -196,8 +196,8 @@ def _run_multiplier_bootstrap( # units don't appear in any influence function. if precomputed is not None: all_units = precomputed["all_units"] - n_units = len(all_units) - unit_to_idx = precomputed["unit_to_idx"] + n_units = precomputed.get("canonical_size", len(all_units)) + unit_to_idx = precomputed["unit_to_idx"] # None for RCS else: # Fallback: collect units from influence functions all_units_set = set() @@ -211,8 +211,9 @@ def _run_multiplier_bootstrap( ) unit_to_idx = {u: i for i, u in enumerate(all_units)} - # Get list of (g,t) pairs - gt_pairs = list(group_time_effects.keys()) + # Get list of (g,t) pairs that have influence function info + # (skip zero-mass cells that recorded NaN ATT without IF) + gt_pairs = [gt for gt in group_time_effects.keys() if gt in influence_func_info] n_gt = len(gt_pairs) # Identify post-treatment (g,t) pairs for overall ATT @@ -235,12 +236,16 @@ def _run_multiplier_bootstrap( g = gt[0] if g not in _cohort_mass_cache: _cohort_mass_cache[g] = float(np.sum(survey_w[unit_cohorts == g])) - all_n_treated = np.array( - [_cohort_mass_cache[gt[0]] for gt in gt_pairs], dtype=float - ) + all_n_treated = np.array([_cohort_mass_cache[gt[0]] for gt in gt_pairs], dtype=float) else: + # Use agg_weight if available (RCS: fixed cohort mass); + # fall back to n_treated for panel data all_n_treated = np.array( - [group_time_effects[gt]["n_treated"] for gt in gt_pairs], dtype=float + [ + group_time_effects[gt].get("agg_weight", group_time_effects[gt]["n_treated"]) + for gt in gt_pairs + ], + dtype=float, ) post_n_treated = all_n_treated[post_treatment_mask] @@ -426,8 +431,9 @@ def _run_multiplier_bootstrap( # Batch compute bootstrap statistics for ATT(g,t) batch_ses, batch_ci_lo, batch_ci_hi, batch_pv = 
_compute_effect_bootstrap_stats_batch_func( - original_atts, bootstrap_atts_gt, alpha=self.alpha, - + original_atts, + bootstrap_atts_gt, + alpha=self.alpha, ) gt_ses = {} gt_cis = {} @@ -444,8 +450,10 @@ def _run_multiplier_bootstrap( overall_p_value = np.nan else: overall_se, overall_ci, overall_p_value = _compute_effect_bootstrap_stats_func( - original_overall, bootstrap_overall, alpha=self.alpha, context="overall ATT", - + original_overall, + bootstrap_overall, + alpha=self.alpha, + context="overall ATT", ) # Batch compute bootstrap statistics for event study effects @@ -457,8 +465,9 @@ def _run_multiplier_bootstrap( es_effects = np.array([event_study_info[e]["effect"] for e in rel_periods]) es_boot_matrix = np.column_stack([bootstrap_event_study[e] for e in rel_periods]) es_ses, es_ci_lo, es_ci_hi, es_pv = _compute_effect_bootstrap_stats_batch_func( - es_effects, es_boot_matrix, alpha=self.alpha, - + es_effects, + es_boot_matrix, + alpha=self.alpha, ) event_study_ses = {e: float(es_ses[i]) for i, e in enumerate(rel_periods)} event_study_cis = { @@ -475,8 +484,9 @@ def _run_multiplier_bootstrap( grp_effects = np.array([group_agg_info[g]["effect"] for g in group_list]) grp_boot_matrix = np.column_stack([bootstrap_group[g] for g in group_list]) grp_ses, grp_ci_lo, grp_ci_hi, grp_pv = _compute_effect_bootstrap_stats_batch_func( - grp_effects, grp_boot_matrix, alpha=self.alpha, - + grp_effects, + grp_boot_matrix, + alpha=self.alpha, ) group_effect_ses = {g: float(grp_ses[i]) for i, g in enumerate(group_list)} group_effect_cis = { @@ -569,7 +579,10 @@ def _agg_weight(g: Any, t: Any) -> float: if g not in _cohort_mass: _cohort_mass[g] = float(np.sum(survey_w[unit_cohorts == g])) return _cohort_mass[g] - return group_time_effects[(g, t)]["n_treated"] + # Use agg_weight if available (RCS: fixed cohort mass) + return group_time_effects[(g, t)].get( + "agg_weight", group_time_effects[(g, t)]["n_treated"] + ) # Organize by relative time effects_by_e: Dict[int, 
List[Tuple[int, float, float]]] = {} diff --git a/diff_diff/staggered_results.py b/diff_diff/staggered_results.py index 3fea9cc..65132af 100644 --- a/diff_diff/staggered_results.py +++ b/diff_diff/staggered_results.py @@ -111,9 +111,13 @@ class CallawaySantAnnaResults: alpha: float = 0.05 control_group: str = "never_treated" base_period: str = "varying" + panel: bool = True event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None) group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None) influence_functions: Optional["np.ndarray"] = field(default=None, repr=False) + # Full event-study VCV matrix (Phase 7d): indexed by event_study_vcov_index + event_study_vcov: Optional["np.ndarray"] = field(default=None, repr=False) + event_study_vcov_index: Optional[list] = field(default=None, repr=False) bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False) cband_crit_value: Optional[float] = None pscore_trim: float = 0.01 @@ -153,8 +157,8 @@ def summary(self, alpha: Optional[float] = None) -> str: "=" * 85, "", f"{'Total observations:':<30} {self.n_obs:>10}", - f"{'Treated units:':<30} {self.n_treated_units:>10}", - f"{'Never-treated units:':<30} {self.n_control_units:>10}", + f"{'Treated ' + ('obs:' if not self.panel else 'units:'):<30} {self.n_treated_units:>10}", + f"{'Control ' + ('obs:' if not self.panel else 'units:'):<30} {self.n_control_units:>10}", f"{'Treatment cohorts:':<30} {len(self.groups):>10}", f"{'Time periods:':<30} {len(self.time_periods):>10}", f"{'Control group:':<30} {self.control_group:>10}", diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 9e2c23c..828aa21 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -416,10 +416,15 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: a base period later than `t` (matching R's `did::att_gt()`) - Does not require never-treated units: when all units are eventually 
treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) -- **Note:** CallawaySantAnna survey support: weights, strata, PSU, and FPC are all supported. Analytical (`n_bootstrap=0`): aggregated SEs use design-based variance via `compute_survey_if_variance()`. Bootstrap (`n_bootstrap>0`): PSU-level multiplier weights replace analytical SEs for aggregated quantities. Regression method supports covariates; IPW/DR support no-covariate only (covariates+IPW/DR raises NotImplementedError — DRDID nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Per-unit survey weights are extracted via `groupby(unit).first()` from the panel-normalized pweight array; on unbalanced panels the pweight normalization (`w * n_obs / sum(w)`) preserves relative unit weights since all IF/WIF formulas use weight ratios (`sw_i / sum(sw)`) where the normalization constant cancels. Scale-invariance tests pass on both balanced and unbalanced panels. +- **Note:** CallawaySantAnna survey support: weights, strata, PSU, and FPC are all supported for all estimation methods (reg, ipw, dr) with or without covariates. Analytical (`n_bootstrap=0`): aggregated SEs use design-based variance via `compute_survey_if_variance()`. Bootstrap (`n_bootstrap>0`): PSU-level multiplier weights replace analytical SEs for aggregated quantities. IPW and DR with covariates use DRDID panel nuisance IF corrections (Phase 7a: PS IF correction via survey-weighted Hessian/score, OR IF correction via WLS bread and gradient; Sant'Anna & Zhao 2020, Theorem 3.1). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. 
Per-unit survey weights are extracted via `groupby(unit).first()` from the panel-normalized pweight array; on unbalanced panels the pweight normalization (`w * n_obs / sum(w)`) preserves relative unit weights since all IF/WIF formulas use weight ratios (`sw_i / sum(sw)`) where the normalization constant cancels. Scale-invariance tests pass on both balanced and unbalanced panels. +- **Note (deviation from R):** Panel DR control augmentation is normalized by treated mass (`sw_t_sum` or `n_t`) rather than control IPW mass (`sum(w_cont)`). R's `DRDID::drdid_panel` uses `mean(w.cont)` as the control normalizer. Both are consistent asymptotically (under correct model specification, `E[w_cont] = E[D]` so the normalizers converge), but they differ in finite samples when IPW reweighting doesn't perfectly balance. The treated-mass normalization is simpler and matches the `did::att_gt` convention where ATT is defined per treated unit. Aligning to `DRDID::drdid_panel`'s exact `w.cont` normalization is deferred. - **Note (deviation from R):** CallawaySantAnna survey reg+covariates per-cell SE uses a conservative plug-in IF based on WLS residuals. The treated IF is `inf_treated_i = (sw_i/sum(sw_treated)) * (resid_i - ATT)` (normalized by treated weight sum, matching unweighted `(resid-ATT)/n_t`). The control IF is `inf_control_i = -(sw_i/sum(sw_control)) * wls_resid_i` (normalized by control weight sum, matching unweighted `-resid/n_c`). SE is computed as `sqrt(sum(sw_t_norm * (resid_t - ATT)^2) + sum(sw_c_norm * resid_c^2))`, the weighted analogue of the unweighted `sqrt(var_t/n_t + var_c/n_c)`. This omits the semiparametrically efficient nuisance correction from DRDID's `reg_did_panel` — WLS residuals are orthogonal to the weighted design matrix by construction, so the first-order IF term is asymptotically valid but may be conservative. SEs pass weight-scale-invariance tests. The efficient DRDID correction is deferred to future work. 
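The weight-ratio cancellation claimed in the note above can be checked directly. This is an illustrative sketch only (`weighted_stat` is a hypothetical helper, not diff-diff's internal code): any statistic built from the ratios `sw_i / sum(sw)` is unchanged by a common rescaling of the raw weights, which is why the pweight normalization constant drops out of the IF/WIF formulas.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.uniform(0.5, 2.0, 100)   # raw survey weights
y = rng.normal(size=100)         # any per-unit quantity (e.g. an IF contribution)

def weighted_stat(w, y):
    # Every IF/WIF formula uses the ratio sw_i / sum(sw); a common factor c
    # in the weights cancels: (c * w_i) / sum(c * w) == w_i / sum(w).
    sw = w / w.sum()
    return float(np.sum(sw * y))

# Rescaling all weights by any positive constant leaves the statistic unchanged.
assert np.isclose(weighted_stat(w, y), weighted_stat(123.0 * w, y))
```

This is the same invariance the scale-invariance tests exercise on balanced and unbalanced panels.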
- **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization. When strata/PSU/FPC are present, analytical aggregated SEs (`n_bootstrap=0`) use `compute_survey_if_variance()` on the combined IF/WIF; bootstrap aggregated SEs (`n_bootstrap>0`) use PSU-level multiplier weights. +- **Note:** Repeated cross-sections (`panel=False`, Phase 7b): supports surveys like BRFSS, ACS annual, and CPS monthly where units are not followed over time. Uses cross-sectional DRDID (Sant'Anna & Zhao 2020, Section 4): `reg` matches `DRDID::reg_did_rc` (Eq 2.2), `dr` matches `DRDID::drdid_rc` (locally efficient, Eq 3.3+3.4 with 4 OLS fits), `ipw` matches `DRDID::std_ipw_did_rc`. Per-observation influence functions instead of per-unit. All three estimation methods support covariates and survey weights. +- **Note (deviation from R):** RCS influence functions use `phi_i = psi_i / n` convention (SE = `sqrt(sum(phi^2))`), matching the library-wide IF convention where IFs are pre-scaled by `1/n`. R's DRDID uses `psi_i` directly with `SE = sd(psi) / sqrt(n)`. These are algebraically equivalent — `sqrt(sum(psi^2/n^2)) = sqrt(sum(psi^2))/n ≈ sd(psi)/sqrt(n)` — confirmed by analytical-vs-bootstrap SE convergence tests. The `1/n_all` denominator in gradient terms (`M1`, `M2`) is not "extra shrinkage" but the `colMeans` → phi convention conversion. +- **Note:** Non-survey DR path also includes nuisance IF corrections (PS + OR), matching the survey path structure (Phase 7a). Previously used plug-in IF only. 
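The `phi_i = psi_i / n` convention described in the note above is easy to verify numerically. A minimal sketch assuming only mean-zero R-style IF values `psi` (no diff-diff internals involved):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 500
psi = rng.normal(size=n)   # R-style influence function values
psi -= psi.mean()          # IFs are mean-zero by construction
phi = psi / n              # library-wide convention: IFs pre-scaled by 1/n

se_library = np.sqrt(np.sum(phi**2))            # SE = sqrt(sum(phi^2))
se_r_style = np.std(psi, ddof=1) / np.sqrt(n)   # R's DRDID: sd(psi) / sqrt(n)

# sqrt(sum(psi^2 / n^2)) = sqrt(sum(psi^2)) / n ≈ sd(psi) / sqrt(n); the
# only gap is the vanishing n/(n-1) Bessel correction.
assert np.isclose(se_library, se_r_style, rtol=1e-2)
```

The two conventions differ only by the finite-sample `n/(n-1)` factor, consistent with the analytical-vs-bootstrap SE convergence tests cited above.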
+ **Reference implementation(s):** - R: `did::att_gt()` (Callaway & Sant'Anna's official package) - Stata: `csdid` @@ -430,6 +435,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: - [ ] Aggregations: simple, event_study, group all implemented - [ ] Doubly robust estimation when covariates provided - [ ] Multiplier bootstrap preserves panel structure +- [x] Repeated cross-sections (`panel=False`) for non-panel surveys (Phase 7b) --- @@ -1630,6 +1636,8 @@ Confidence intervals: - Breakdown point: smallest M where CI includes zero - M=0: reduces to standard parallel trends - Negative M: not valid (constraints become infeasible) +- **Note:** Phase 7d: survey variance support. When input results carry `survey_metadata` with `df_survey`, HonestDiD uses t-distribution critical values (via `_get_critical_value(alpha, df)`) instead of normal. CallawaySantAnnaResults now stores `event_study_vcov` (full cross-event-time VCV from IF vectors), which HonestDiD uses instead of the diagonal fallback. For replicate-weight designs, the event-study VCV falls back to diagonal (multivariate replicate VCV deferred). +- **Note (deviation from R):** When CallawaySantAnna results are passed to HonestDiD, `base_period != "universal"` emits a warning but does not error. R's `honest_did::honest_did.AGGTEobj` requires universal base period. Our implementation warns because the varying-base pre-treatment coefficients use consecutive comparisons (not a common reference), which changes the parallel-trends restriction interpretation. **Reference implementation(s):** - R: `HonestDiD` package (Rambachan & Roth's official package) diff --git a/docs/survey-roadmap.md b/docs/survey-roadmap.md index a521d4e..41f1f35 100644 --- a/docs/survey-roadmap.md +++ b/docs/survey-roadmap.md @@ -72,10 +72,10 @@ TripleDifference IPW/DR from Phase 3 deferred work. 
### Phase 5 Deferred Work -| Estimator | Deferred Capability | Blocker | -|-----------|-------------------|---------| -| SyntheticDiD | strata/PSU/FPC + survey-aware bootstrap | Bootstrap + Survey Interaction | -| TROP | strata/PSU/FPC + survey-aware bootstrap | Bootstrap + Survey Interaction | +| Estimator | Deferred Capability | Status | +|-----------|-------------------|--------| +| SyntheticDiD | strata/PSU/FPC + survey-aware bootstrap | Resolved in Phase 6 | +| TROP | strata/PSU/FPC + survey-aware bootstrap | Resolved in Phase 6 | ## Phase 6: Advanced Features @@ -104,7 +104,8 @@ JKn requires explicit `replicate_strata` (per-replicate stratum assignment). TripleDifference (analytical only, no bootstrap). Rejected with `NotImplementedError` in DifferenceInDifferences, TwoWayFixedEffects, MultiPeriodDiD, StackedDiD, SunAbraham, ImputationDiD, TwoStageDiD, - SyntheticDiD, TROP. + SyntheticDiD, TROP. Expansion to regression-based estimators (SA, + Imputation, TwoStage, Stacked) is straightforward but deferred. ### DEFF Diagnostics ✅ (2026-03-26) Per-coefficient design effects comparing survey vcov to SRS (HC1) vcov. @@ -120,3 +121,115 @@ estimation (unlike simple subsetting, which would drop design information). - Mask: bool array/Series, column name, or callable - Returns (SurveyDesign, DataFrame) pair with synthetic `_subpop_weight` column - Weight validation relaxed: zero weights allowed (negative still rejected) + +--- + +## Phase 7: Completing the Survey Story + +These items close the remaining gaps that matter for practitioners using major +population surveys (ACS, CPS, BRFSS, MEPS) with modern DiD methods. Together +they make diff-diff the only package — R or Python — with full design-based +variance estimation for heterogeneity-robust DiD estimators. + +### 7a. CallawaySantAnna Covariates + IPW/DR + Survey ✅ + +**Priority: High.** This is the single highest-impact gap. 
The Callaway-Sant'Anna +`reg` method with covariates already works under survey designs, but the +recommended IPW and DR methods raise `NotImplementedError`. Most applied work +(Medicaid expansion, minimum wage studies) uses DR with covariates. + +**What's needed:** +- Implement DRDID panel nuisance-estimation influence function corrections + under survey weights (Sant'Anna & Zhao 2020, Theorem 3.1) +- Survey-weighted propensity score estimation via `solve_logit()` (already + available from Phase 4) +- Survey-weighted outcome regression for imputation step +- Correct IF that accounts for nuisance parameter estimation uncertainty + under the survey design +- Thread `ResolvedSurveyDesign` through the IPW and DR paths in + `_estimate_att_gt()` + +**Reference:** Sant'Anna, P.H.C. & Zhao, J. (2020). "Doubly Robust +Difference-in-Differences Estimators." *Journal of Econometrics* 219(1). + +**Current gate:** `staggered.py` — `NotImplementedError` when +`estimation_method in ('ipw', 'dr')` and covariates are provided with +`survey_design`. + +### 7b. Repeated Cross-Sections ✅ + +**Priority: High.** Many major surveys (BRFSS, ACS annual cross-sections, +CPS monthly) are not panels — units are not followed over time. The R `did` +package supports `panel=FALSE` for these settings. diff-diff currently +requires panel data for all staggered estimators. 
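The cross-sectional outcome-regression estimand this phase implements (the `reg` flavor of Sant'Anna & Zhao 2020) can be sketched in a few lines. `reg_did_rc` below is an illustrative re-derivation on synthetic two-period data, not the library's code path: fit `E[Y | X]` on controls separately pre/post, predict the counterfactual trend for treated observations, and difference it out of the observed treated trend.

```python
import numpy as np

def reg_did_rc(y, d, post, X):
    # Outcome-regression DiD for repeated cross-sections (sketch):
    # no unit is observed twice, so everything is per-observation.
    Z = np.column_stack([np.ones(len(y)), X])

    def control_fit(mask):
        # OLS of Y on X among controls in one period, predicted for everyone.
        beta, *_ = np.linalg.lstsq(Z[mask], y[mask], rcond=None)
        return Z @ beta

    mu0_pre = control_fit((d == 0) & (post == 0))
    mu0_post = control_fit((d == 0) & (post == 1))
    t = d == 1
    obs_diff = y[t & (post == 1)].mean() - y[t & (post == 0)].mean()
    cf_trend = mu0_post[t].mean() - mu0_pre[t].mean()  # counterfactual trend
    return obs_diff - cf_trend

rng = np.random.default_rng(3)
n = 4000
d = rng.integers(0, 2, n)
post = rng.integers(0, 2, n)
x = rng.normal(size=n)
# DGP with true ATT = 2.0 and parallel trends by construction
y = 1.0 + 0.5 * x + 0.8 * d + 0.4 * post + 2.0 * d * post + rng.normal(0, 0.5, n)

att = reg_did_rc(y, d, post, x.reshape(-1, 1))
assert abs(att - 2.0) < 0.2
```

The panel version would instead regress the first-difference ΔY on X; the per-observation structure here is what makes BRFSS/ACS/CPS-style data workable.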
+ +**What's needed:** +- `panel` parameter on `CallawaySantAnna` (default `True`, set `False` for + repeated cross-sections) +- Repeated cross-section ATT(g,t) estimation using cross-sectional DRDID + (Sant'Anna & Zhao 2020, Section 4) +- Cross-sectional propensity score: model P(G=g | X) instead of panel + first-difference +- Cross-sectional outcome regression: model E[Y | X, G, T] instead of + E[ΔY | X, G] +- Influence functions for cross-sectional case (different from panel IF) +- Survey weight support for repeated cross-sections (weights apply per + observation, no unit-level collapse) +- Update `generate_did_data()` / `generate_staggered_data()` with + `panel=False` option for testing + +**Reference:** Sant'Anna, P.H.C. & Zhao, J. (2020). Sections 3 (panel) vs +4 (repeated cross-sections). Callaway, B. & Sant'Anna, P.H.C. (2021). +Section 4.1. + +**Scope:** CallawaySantAnna only. Other staggered estimators (SunAbraham, +ImputationDiD, TwoStageDiD, StackedDiD) are inherently panel methods. + +### 7c. Survey-Aware DiD Tutorial + +**Priority: High.** diff-diff is the only package (R or Python) with +design-based variance estimation for modern DiD estimators, but no one +knows this. A tutorial demonstrating the full workflow with realistic +survey data would make the capability discoverable. + +**What's needed:** +- Jupyter notebook: `docs/tutorials/16_survey_did.ipynb` +- Sections: + 1. Why survey design matters for DiD (variance inflation from clustering, + weight effects on point estimates — cite Solon, Haider & Wooldridge 2015) + 2. Setting up `SurveyDesign` (weights, strata, PSU, FPC) + 3. Basic DiD with survey design (compare naive vs. design-based SEs) + 4. Staggered DiD with survey weights (CallawaySantAnna) + 5. Replicate weights workflow (BRR/JKn for MEPS/ACS PUMS users) + 6. Subpopulation analysis + 7. DEFF diagnostics — interpreting design effects + 8. 
Comparison: show that R's `did` package with `weightsname` gives + survey-naive variance while diff-diff gives design-based variance +- Use realistic synthetic data mimicking ACS/CPS structure (stratified + multi-stage design with known treatment effect) +- Cross-reference from README, choosing_estimator.rst, and quickstart.rst + +### 7d. HonestDiD with Survey Variance ✅ + +**Priority: Medium.** Sensitivity analysis (Rambachan & Roth 2023) currently +uses cluster-robust or HC variance. Under complex survey designs, the +variance-covariance matrix should come from TSL or replicate weights. + +**What's needed:** +- Accept optional `survey_design` parameter in `HonestDiD` +- When provided, use survey vcov matrix instead of cluster-robust vcov + for computing sensitivity bounds +- Degrees of freedom from survey design (n_PSU - n_strata) for + t-distribution critical values +- Propagate through all three methods: relative magnitudes, smoothness, + and conditional least favorable (C-LF) +- The core optimization (LP/QP for bounds) is unchanged — only the input + vcov and df change + +**Reference:** Rambachan, A. & Roth, J. (2023). "A More Credible Approach +to Parallel Trends." *Review of Economic Studies* 90(5). + +**Why it matters:** A practitioner who runs CS with survey design but then +runs HonestDiD sensitivity analysis with cluster-robust SEs gets +inconsistent inference. The sensitivity bounds should respect the same +variance structure as the main estimates. 
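The degrees-of-freedom piece of 7d is small but consequential. A sketch of the `n_PSU - n_strata` rule (the `survey_df` helper below is hypothetical; the critical values are quoted from standard t tables, not computed):

```python
def survey_df(strata, psu):
    # Design-based degrees of freedom: number of distinct (stratum, PSU)
    # pairs minus number of strata -- the n_PSU - n_strata rule used for
    # t-distribution critical values.
    pairs = {(s, p) for s, p in zip(strata, psu)}
    return len(pairs) - len(set(strata))

strata = [0] * 50 + [1] * 50
psu = [i // 25 for i in range(100)]  # 4 PSUs nested in 2 strata
df = survey_df(strata, psu)
assert df == 2

# With df = 2, the two-sided 95% t critical value is ~4.303 versus 1.96 for
# the normal, so design-based CIs can be more than twice as wide when there
# are few PSUs -- exactly the case the widened-bounds test exercises.
z_crit, t_crit_df2 = 1.96, 4.303
assert t_crit_df2 / z_crit > 2
```

This is why passing cluster-robust variance into the sensitivity analysis after a design-based main estimate is inconsistent: the bounds would use the wrong reference distribution as well as the wrong vcov.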
diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index 050ae27..9fbb829 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -292,7 +292,7 @@ class TestParameterExtraction: def test_extract_from_multiperiod(self, mock_multiperiod_results): """Test extraction from MultiPeriodDiDResults.""" - (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = ( + (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods, _df) = ( _extract_event_study_params(mock_multiperiod_results) ) @@ -667,7 +667,7 @@ def test_multiperiod_sub_vcov_extraction(self, simple_panel_data): reference_period=3, ) - (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods) = ( + (beta_hat, sigma, num_pre, num_post, pre_periods, post_periods, _df) = ( _extract_event_study_params(results) ) @@ -890,7 +890,9 @@ def test_honest_did_filters_nan_pre_period_effects(self): ) # _extract_event_study_params should filter out period 0 (NaN) - beta_hat, sigma, num_pre, num_post, pre_p, post_p = _extract_event_study_params(results) + beta_hat, sigma, num_pre, num_post, pre_p, post_p, _df = _extract_event_study_params( + results + ) assert len(beta_hat) == 3 # periods 1, 3, 4 (period 0 filtered) assert num_pre == 1 # only period 1 assert num_post == 2 # periods 3, 4 @@ -1090,7 +1092,9 @@ def test_honest_did_nonmonotone_period_labels(self): interaction_indices=interaction_indices, ) - beta_hat, sigma, num_pre, num_post, pre_p, post_p = _extract_event_study_params(results) + beta_hat, sigma, num_pre, num_post, pre_p, post_p, _df = _extract_event_study_params( + results + ) # Pre-periods: 5, 6 (7 is reference, omitted) assert num_pre == 2 @@ -1110,6 +1114,194 @@ def test_honest_did_nonmonotone_period_labels(self): assert sigma[3, 3] == pytest.approx(0.2025) # period 2 variance +# ============================================================================= +# Tests for Phase 7d: Survey variance support +# 
============================================================================= + + +class TestSurveyVariance: + """Tests for HonestDiD survey variance support (Phase 7d).""" + + def test_df_survey_extracted_from_cs_results(self): + """df_survey is extracted from CallawaySantAnna survey_metadata.""" + from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + unit_ids = data["unit"].unique() + n_units = len(unit_ids) + unit_map = {uid: i for i, uid in enumerate(unit_ids)} + idx = data["unit"].map(unit_map).values + + data["weight"] = (1.0 + 0.5 * (np.arange(n_units) % 4))[idx] + data["stratum"] = (np.arange(n_units) // 25)[idx] + data["psu"] = (np.arange(n_units) // 5)[idx] + + sd = SurveyDesign(weights="weight", strata="stratum", psu="psu") + cs_result = CallawaySantAnna(base_period="universal").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + aggregate="event_study", + ) + + honest = HonestDiD(method="smoothness", M=0.0) + h_result = honest.fit(cs_result) + + # df_survey should be propagated + assert h_result.df_survey is not None + assert h_result.df_survey > 0 + assert h_result.survey_metadata is not None + + def test_event_study_vcov_computed(self): + """CallawaySantAnna event_study_vcov is computed and used by HonestDiD.""" + from diff_diff import CallawaySantAnna, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=6, seed=42) + cs_result = CallawaySantAnna(base_period="universal").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + ) + + # VCV should be computed + assert cs_result.event_study_vcov is not None + n_effects = len( + [ + e + for e, d in cs_result.event_study_effects.items() + if d.get("n_groups", 1) > 0 and np.isfinite(d.get("se", np.nan)) + ] + ) + assert cs_result.event_study_vcov.shape == (n_effects, n_effects) + + # Diagonal should be positive + 
for i in range(n_effects): + assert cs_result.event_study_vcov[i, i] > 0 + + def test_survey_df_widens_bounds(self): + """Survey df (t-distribution) should give wider CIs than normal.""" + from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + unit_ids = data["unit"].unique() + n_units = len(unit_ids) + unit_map = {uid: i for i, uid in enumerate(unit_ids)} + idx = data["unit"].map(unit_map).values + + data["weight"] = (1.0 + 0.3 * (np.arange(n_units) % 3))[idx] + # 2 strata, 4 PSUs total -> df = 4 - 2 = 2 (very low df) + data["stratum"] = (np.arange(n_units) // 50)[idx] + data["psu"] = (np.arange(n_units) // 25)[idx] + + sd = SurveyDesign(weights="weight", strata="stratum", psu="psu") + cs_result = CallawaySantAnna(base_period="universal").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + aggregate="event_study", + ) + + honest = HonestDiD(method="smoothness", M=0.0) + h_result = honest.fit(cs_result) + + # With df=2, t critical value (~4.3) >> z critical value (1.96) + # So CI width should be wider than 2*1.96*SE + ci_width = h_result.ci_ub - h_result.ci_lb + # Lower bound: normal-based CI width + normal_width = 2 * 1.96 * h_result.original_se + assert ci_width > normal_width + + def test_no_survey_gives_none_df(self): + """Without survey, df_survey should be None.""" + from diff_diff import CallawaySantAnna, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + cs_result = CallawaySantAnna(base_period="universal").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + ) + + honest = HonestDiD(method="smoothness", M=0.0) + h_result = honest.fit(cs_result) + + assert h_result.df_survey is None + assert h_result.survey_metadata is None + + def test_replicate_weight_uses_diagonal_fallback(self): + """Replicate-weight designs should NOT produce full 
event_study_vcov.""" + from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + unit_ids = data["unit"].unique() + n_units = len(unit_ids) + unit_map = {uid: i for i, uid in enumerate(unit_ids)} + idx = data["unit"].map(unit_map).values + + # Create replicate weights (4 replicates) + rng = np.random.default_rng(42) + data["weight"] = (1.0 + 0.3 * (np.arange(n_units) % 3))[idx] + for k in range(4): + data[f"repwt_{k}"] = data["weight"] * rng.uniform(0.8, 1.2, len(data)) + # Make constant within unit + unit_rw = data.groupby("unit")[f"repwt_{k}"].first() + data[f"repwt_{k}"] = data["unit"].map(unit_rw) + + sd = SurveyDesign( + weights="weight", + replicate_weights=[f"repwt_{k}" for k in range(4)], + replicate_method="JK1", + ) + cs_result = CallawaySantAnna().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + aggregate="event_study", + ) + + # event_study_vcov should be None (diagonal fallback for replicate designs) + assert cs_result.event_study_vcov is None + + def test_bootstrap_fit_clears_analytical_vcov(self): + """Bootstrap CS results should NOT carry analytical event_study_vcov.""" + from diff_diff import CallawaySantAnna, generate_staggered_data + + data = generate_staggered_data(n_units=100, n_periods=5, seed=42) + cs_result = CallawaySantAnna(n_bootstrap=49, seed=42, base_period="universal").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + ) + + # event_study_vcov should be None when bootstrap is used + # (prevents HonestDiD from mixing analytical VCV with bootstrap SEs) + assert cs_result.event_study_vcov is None + + # HonestDiD should still work (falls back to diagonal from bootstrap SEs) + honest = HonestDiD(method="relative_magnitude", M=1.0) + h_result = honest.fit(cs_result) + assert np.isfinite(h_result.original_se) + assert h_result.original_se > 0 + + # 
============================================================================= # Tests for Visualization (without matplotlib) # ============================================================================= diff --git a/tests/test_staggered_rc.py b/tests/test_staggered_rc.py new file mode 100644 index 0000000..dd71e5f --- /dev/null +++ b/tests/test_staggered_rc.py @@ -0,0 +1,520 @@ +"""Tests for Phase 7b: CallawaySantAnna repeated cross-section support. + +Covers: panel=False mode with reg/ipw/dr, covariates, survey weights, +aggregation, bootstrap, control group options, base period options, +and edge cases. +""" + +import numpy as np +import pytest + +from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def rc_data(): + """Basic repeated cross-section data.""" + return generate_staggered_data(n_units=200, n_periods=6, panel=False, seed=42) + + +@pytest.fixture(scope="module") +def rc_data_with_covariates(): + """RCS data with a covariate.""" + data = generate_staggered_data(n_units=200, n_periods=6, panel=False, seed=42) + rng = np.random.default_rng(42) + data["x1"] = rng.normal(0, 1, len(data)) + return data + + +@pytest.fixture(scope="module") +def rc_data_with_survey(): + """RCS data with survey weights.""" + data = generate_staggered_data(n_units=200, n_periods=6, panel=False, seed=42) + rng = np.random.default_rng(42) + data["x1"] = rng.normal(0, 1, len(data)) + data["weight"] = rng.uniform(0.5, 2.0, len(data)) + return data + + +# ============================================================================= +# Basic Fit +# ============================================================================= + + +class TestBasicFit: + """Basic repeated cross-section fit tests.""" + + def test_basic_reg(self, rc_data): + result = 
CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + @pytest.mark.parametrize("method", ["reg", "ipw", "dr"]) + def test_all_methods(self, rc_data, method): + result = CallawaySantAnna(estimation_method=method, panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + def test_panel_param_in_get_params(self): + cs = CallawaySantAnna(panel=False) + params = cs.get_params() + assert params["panel"] is False + + +# ============================================================================= +# Methods Agree Without Covariates +# ============================================================================= + + +class TestMethodsAgreeNoCovariates: + """Without covariates, reg/ipw/dr should give identical ATTs in RCS.""" + + def test_no_covariate_methods_agree(self, rc_data): + results = {} + for method in ["reg", "ipw", "dr"]: + r = CallawaySantAnna(estimation_method=method, panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + results[method] = r.overall_att + + np.testing.assert_allclose(results["reg"], results["ipw"], atol=1e-10) + np.testing.assert_allclose(results["reg"], results["dr"], atol=1e-10) + + +# ============================================================================= +# Treatment Effect Recovery +# ============================================================================= + + +class TestTreatmentEffectRecovery: + """Known DGP should recover approximately correct treatment effect.""" + + def test_positive_effect(self, rc_data): + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + # DGP has 
treatment_effect=2.0 by default + assert result.overall_att > 0 + assert abs(result.overall_att - 2.0) < 2.0 # within 2 SE roughly + + +# ============================================================================= +# Aggregation +# ============================================================================= + + +class TestAggregation: + """Aggregation types work with RCS.""" + + @pytest.mark.parametrize("agg", ["simple", "event_study", "group"]) + def test_aggregation(self, rc_data, agg): + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate=agg, + ) + assert np.isfinite(result.overall_att) + if agg == "event_study": + assert result.event_study_effects is not None + for e, info in result.event_study_effects.items(): + if info["n_groups"] > 0: + assert np.isfinite(info["effect"]) + if agg == "group": + assert result.group_effects is not None + + +# ============================================================================= +# Covariates +# ============================================================================= + + +class TestCovariates: + """Covariate-adjusted estimation in RCS.""" + + @pytest.mark.parametrize("method", ["reg", "ipw", "dr"]) + def test_with_covariates(self, rc_data_with_covariates, method): + result = CallawaySantAnna(estimation_method=method, panel=False).fit( + rc_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + +# ============================================================================= +# Survey Weights +# ============================================================================= + + +class TestSurveyWeights: + """Survey weights work with RCS (per-observation).""" + + def test_survey_weights_pweight(self, rc_data_with_survey): + sd = SurveyDesign(weights="weight") + result = 
CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data_with_survey, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_survey_covariates_dr(self, rc_data_with_survey): + """Combined: survey + covariates + DR + RCS.""" + sd = SurveyDesign(weights="weight") + result = CallawaySantAnna(estimation_method="dr", panel=False).fit( + rc_data_with_survey, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + +# ============================================================================= +# Control Group Options +# ============================================================================= + + +class TestControlGroup: + """Control group options work with RCS.""" + + def test_not_yet_treated(self, rc_data): + result = CallawaySantAnna( + estimation_method="reg", panel=False, control_group="not_yet_treated" + ).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + + def test_never_treated(self, rc_data): + result = CallawaySantAnna( + estimation_method="reg", panel=False, control_group="never_treated" + ).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + + +# ============================================================================= +# Base Period Options +# ============================================================================= + + +class TestBasePeriod: + """Base period options work with RCS.""" + + def test_universal(self, rc_data): + result = CallawaySantAnna( + estimation_method="reg", panel=False, base_period="universal" + ).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert 
np.isfinite(result.overall_att) + + def test_varying(self, rc_data): + result = CallawaySantAnna(estimation_method="reg", panel=False, base_period="varying").fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + + +# ============================================================================= +# Bootstrap +# ============================================================================= + + +class TestBootstrap: + """Bootstrap works with RCS.""" + + def test_bootstrap_reg(self, rc_data): + result = CallawaySantAnna( + estimation_method="reg", panel=False, n_bootstrap=49, seed=42 + ).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert result.bootstrap_results is not None + assert np.isfinite(result.overall_att) + + +# ============================================================================= +# Data Generator +# ============================================================================= + + +class TestDataGenerator: + """Test the RCS data generator.""" + + def test_rc_data_structure(self): + data = generate_staggered_data(n_units=100, n_periods=5, panel=False, seed=99) + # Each observation should have a unique unit ID + assert data["unit"].nunique() == len(data) + # Should have n_units * n_periods rows + assert len(data) == 100 * 5 + # Each period should have n_units observations + assert all(data.groupby("period")["unit"].count() == 100) + + def test_panel_data_unchanged(self): + """panel=True (default) should produce panel data.""" + data = generate_staggered_data(n_units=50, n_periods=4, panel=True, seed=42) + # Units should repeat across periods + assert data["unit"].nunique() < len(data) + assert data["unit"].nunique() == 50 + + +# ============================================================================= +# Edge Cases +# ============================================================================= + + +class TestEdgeCases: + """Edge cases for RCS.""" + + def 
test_empty_cell_nan(self): + """(g,t) cell with no observations should be NaN, not crash.""" + data = generate_staggered_data(n_units=50, n_periods=4, panel=False, seed=42) + # This should handle cells with few/no observations gracefully + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + ) + # Should produce at least some finite effects + finite_effects = [ + v["effect"] for v in result.group_time_effects.values() if np.isfinite(v["effect"]) + ] + assert len(finite_effects) > 0 + + +# ============================================================================= +# Methodology: IF corrections change SE +# ============================================================================= + + +class TestIFCorrections: + """Verify RCS DR/IPW IF corrections are non-trivial.""" + + def test_dr_se_differs_from_reg_rc(self, rc_data_with_covariates): + """DR and reg should give different SEs in RCS (DR has IF corrections).""" + r_reg = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + r_dr = CallawaySantAnna(estimation_method="dr", panel=False).fit( + rc_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + # SEs should differ (DR has nuisance IF corrections) + assert r_reg.overall_se != r_dr.overall_se + + def test_panel_field_on_results(self, rc_data): + """panel=False should be reflected on CallawaySantAnnaResults.""" + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert result.panel is False + + def test_summary_labels_rcs(self, rc_data): + """Summary should use 'obs' labels for RCS, not 'units'.""" + result = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + 
summary = result.summary() + assert "obs:" in summary + assert "units:" not in summary.split("\n")[3] # Treated line + + +# ============================================================================= +# Analytical vs Bootstrap SE convergence (proves IF scaling is correct) +# ============================================================================= + + +class TestAnalyticalBootstrapConvergence: + """Analytical SE should closely match bootstrap SE — proves IF magnitude is correct.""" + + def test_reg_se_matches_bootstrap(self, rc_data_with_covariates): + """Analytical reg SE should be within 10% of bootstrap SE.""" + r_analytical = CallawaySantAnna(estimation_method="reg", panel=False).fit( + rc_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + r_bootstrap = CallawaySantAnna( + estimation_method="reg", panel=False, n_bootstrap=499, seed=42 + ).fit( + rc_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + # ATTs should match (bootstrap doesn't change point estimate) + np.testing.assert_allclose(r_analytical.overall_att, r_bootstrap.overall_att, atol=1e-10) + # SEs should be within 10% (proves IF scaling is correct) + ratio = r_analytical.overall_se / r_bootstrap.overall_se + assert 0.9 < ratio < 1.1, ( + f"Analytical/bootstrap SE ratio {ratio:.3f} outside [0.9, 1.1] — " + f"analytical={r_analytical.overall_se:.4f}, bootstrap={r_bootstrap.overall_se:.4f}" + ) + + +# ============================================================================= +# Unequal Cohort Counts Across Periods +# ============================================================================= + + +class TestUnequalCohortCounts: + """Tests with n_gt != n_gs — catches normalizer/weight bugs.""" + + @pytest.fixture + def unequal_rc_data(self): + """RCS data where cohort sizes differ across periods.""" + rng = np.random.default_rng(77) + records = [] + # 4 periods, cohort g=2 treated at period 2 + 
for period in range(4): + # Vary n_per_period so cohort counts differ across periods + n_per_period = 100 + period * 30 # 100, 130, 160, 190 + for i in range(n_per_period): + # ~30% treated (cohort 2), ~70% never-treated + ft = 2 if rng.random() < 0.3 else 0 + treated = (ft > 0) and (period >= ft) + y = rng.normal(0, 1) + (2.0 if treated else 0.0) + records.append( + { + "unit": f"u{period}_{i}", + "period": period, + "outcome": y, + "first_treat": ft, + } + ) + import pandas as pd + + return pd.DataFrame(records) + + @pytest.mark.parametrize("method", ["reg", "ipw", "dr"]) + def test_finite_results_unequal(self, unequal_rc_data, method): + result = CallawaySantAnna(estimation_method=method, panel=False).fit( + unequal_rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + def test_dr_covariates_unequal(self, unequal_rc_data): + """DR with covariates under unequal cohort counts.""" + rng = np.random.default_rng(77) + unequal_rc_data["x1"] = rng.normal(0, 1, len(unequal_rc_data)) + result = CallawaySantAnna(estimation_method="dr", panel=False).fit( + unequal_rc_data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + + def test_bootstrap_unequal(self, unequal_rc_data): + """Bootstrap with unequal cohort counts.""" + result = CallawaySantAnna( + estimation_method="reg", panel=False, n_bootstrap=49, seed=42 + ).fit( + unequal_rc_data, + "outcome", + "unit", + "period", + "first_treat", + ) + assert result.bootstrap_results is not None + assert np.isfinite(result.overall_att) diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index f4781b6..b8ca4dc 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -926,35 +926,39 @@ def test_bootstrap_survey_supported(self, staggered_survey_data, 
survey_design_w assert np.isfinite(result.overall_att) assert np.isfinite(result.overall_se) - def test_ipw_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): - """IPW + covariates + survey should raise NotImplementedError.""" + def test_ipw_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): + """IPW + covariates + survey works (Phase 7a: nuisance IF corrections).""" data = staggered_survey_data.copy() data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) - with pytest.raises(NotImplementedError, match="covariates"): - CallawaySantAnna(estimation_method="ipw").fit( - data, - "outcome", - "unit", - "period", - "first_treat", - covariates=["x1"], - survey_design=survey_design_weights_only, - ) + result = CallawaySantAnna(estimation_method="ipw").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 - def test_dr_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): - """DR + covariates + survey should raise NotImplementedError.""" + def test_dr_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): + """DR + covariates + survey works (Phase 7a: nuisance IF corrections).""" data = staggered_survey_data.copy() data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) - with pytest.raises(NotImplementedError, match="covariates"): - CallawaySantAnna(estimation_method="dr").fit( - data, - "outcome", - "unit", - "period", - "first_treat", - covariates=["x1"], - survey_design=survey_design_weights_only, - ) + result = CallawaySantAnna(estimation_method="dr").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert 
np.isfinite(result.overall_se) + assert result.overall_se > 0 def test_reg_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): """Regression + covariates + survey should work (has nuisance IF correction).""" @@ -1517,13 +1521,22 @@ class TestCallawaySantAnnaFullDesignBootstrap: def test_bootstrap_full_design_cs(self, staggered_survey_data): """Bootstrap + full survey (strata+PSU+FPC) works for CS.""" sd = SurveyDesign( - weights="weight", strata="stratum", psu="psu", fpc="fpc", + weights="weight", + strata="stratum", + psu="psu", + fpc="fpc", ) result = CallawaySantAnna( - estimation_method="reg", n_bootstrap=30, seed=42, + estimation_method="reg", + n_bootstrap=30, + seed=42, ).fit( - staggered_survey_data, "outcome", "unit", "period", - "first_treat", survey_design=sd, + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, ) assert np.isfinite(result.overall_att) assert np.isfinite(result.overall_se) diff --git a/tests/test_survey_phase7a.py b/tests/test_survey_phase7a.py new file mode 100644 index 0000000..3c71efc --- /dev/null +++ b/tests/test_survey_phase7a.py @@ -0,0 +1,389 @@ +"""Tests for Phase 7a: CallawaySantAnna IPW/DR + covariates + survey support. + +Covers: DR nuisance IF corrections (PS + OR), IPW unblocking, +scale invariance, uniform-weight equivalence, aggregation, bootstrap, +and edge cases. 
+""" + +import numpy as np +import pytest + +from diff_diff import CallawaySantAnna, SurveyDesign, generate_staggered_data + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def staggered_data_with_covariates(): + """Staggered panel with a covariate and survey columns.""" + data = generate_staggered_data(n_units=200, n_periods=6, seed=42) + rng = np.random.default_rng(42) + + unit_ids = data["unit"].unique() + n_units = len(unit_ids) + unit_map = {uid: i for i, uid in enumerate(unit_ids)} + idx = data["unit"].map(unit_map).values + + # Covariate: unit-level (constant within unit, varies across units) + unit_x = rng.normal(0, 1, n_units) + data["x1"] = unit_x[idx] + + # Survey design columns (unit-level) + data["weight"] = (1.0 + 0.5 * (np.arange(n_units) % 5))[idx] + data["stratum"] = (np.arange(n_units) // 40)[idx] + data["psu"] = (np.arange(n_units) // 10)[idx] + + return data + + +@pytest.fixture(scope="module") +def survey_weights_only(): + return SurveyDesign(weights="weight") + + +@pytest.fixture(scope="module") +def survey_full_design(): + return SurveyDesign(weights="weight", strata="stratum", psu="psu") + + +# ============================================================================= +# Smoke Tests: IPW and DR with covariates + survey produce finite results +# ============================================================================= + + +class TestSmokeIPWDRSurvey: + """Basic smoke tests that IPW/DR + covariates + survey runs and returns + finite results.""" + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_finite_results(self, staggered_data_with_covariates, survey_weights_only, method): + result = CallawaySantAnna(estimation_method=method).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + 
survey_design=survey_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_event_study(self, staggered_data_with_covariates, survey_weights_only, method): + result = CallawaySantAnna(estimation_method=method).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + aggregate="event_study", + ) + assert result.event_study_effects is not None + for e, info in result.event_study_effects.items(): + if info["n_groups"] > 0: + assert np.isfinite(info["effect"]) + assert np.isfinite(info["se"]) + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_full_design(self, staggered_data_with_covariates, survey_full_design, method): + """Strata/PSU design with covariates + IPW/DR.""" + result = CallawaySantAnna(estimation_method=method).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_full_design, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.overall_se > 0 + + +# ============================================================================= +# Scale Invariance: doubling all weights doesn't change ATT +# ============================================================================= + + +class TestScaleInvariance: + """Multiplying all survey weights by a constant should not change ATT.""" + + @pytest.mark.parametrize("method", ["ipw", "dr", "reg"]) + def test_double_weights_same_att(self, staggered_data_with_covariates, method): + data = staggered_data_with_covariates.copy() + + # Fit with original weights + sd1 = SurveyDesign(weights="weight") + r1 = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd1, + ) 
+ + # Fit with doubled weights + data["weight2"] = data["weight"] * 2 + sd2 = SurveyDesign(weights="weight2") + r2 = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd2, + ) + + np.testing.assert_allclose(r1.overall_att, r2.overall_att, atol=1e-10) + + +# ============================================================================= +# Uniform Weights Match Unweighted +# ============================================================================= + + +class TestUniformWeightsMatchUnweighted: + """Survey weights = 1.0 for all should match the no-survey path.""" + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_uniform_weights(self, method): + data = generate_staggered_data(n_units=100, n_periods=5, seed=123) + rng = np.random.default_rng(123) + data["x1"] = rng.normal(0, 1, len(data)) + data["weight_ones"] = 1.0 + + # No survey + r_no_survey = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + + # Survey with uniform weights + sd = SurveyDesign(weights="weight_ones") + r_survey = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd, + ) + + # ATTs should be very close (not exact due to weight normalization) + np.testing.assert_allclose(r_no_survey.overall_att, r_survey.overall_att, rtol=1e-6) + + +# ============================================================================= +# IF Correction Non-Zero: corrected IF differs from plug-in +# ============================================================================= + + +class TestIFCorrectionNonZero: + """The DR nuisance IF corrections should make a difference (non-trivial).""" + + def test_dr_se_differs_from_reg(self, staggered_data_with_covariates, survey_weights_only): + """DR and reg methods should give different SEs (DR has 
PS correction).""" + r_reg = CallawaySantAnna(estimation_method="reg").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + r_dr = CallawaySantAnna(estimation_method="dr").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + # SEs should differ because DR has additional IF corrections + assert r_reg.overall_se != r_dr.overall_se + + +# ============================================================================= +# All 3 Methods Agree Without Covariates +# ============================================================================= + + +class TestMethodsAgreeNoCovariates: + """Without covariates, reg/ipw/dr should give identical ATTs under survey.""" + + def test_no_covariate_methods_agree(self, staggered_data_with_covariates, survey_weights_only): + results = {} + for method in ["reg", "ipw", "dr"]: + r = CallawaySantAnna(estimation_method=method).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_weights_only, + ) + results[method] = r.overall_att + + np.testing.assert_allclose(results["reg"], results["ipw"], atol=1e-10) + np.testing.assert_allclose(results["reg"], results["dr"], atol=1e-10) + + +# ============================================================================= +# Aggregation +# ============================================================================= + + +class TestAggregation: + """Aggregation types work with DR + covariates + survey.""" + + @pytest.mark.parametrize("agg", ["simple", "event_study", "group"]) + def test_aggregation(self, staggered_data_with_covariates, survey_weights_only, agg): + result = CallawaySantAnna(estimation_method="dr").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + 
survey_design=survey_weights_only, + aggregate=agg, + ) + assert np.isfinite(result.overall_att) + if agg == "event_study": + assert result.event_study_effects is not None + if agg == "group": + assert result.group_effects is not None + + +# ============================================================================= +# Bootstrap +# ============================================================================= + + +class TestBootstrap: + """Bootstrap with IPW/DR + covariates + survey.""" + + @pytest.mark.parametrize("method", ["ipw", "dr"]) + def test_bootstrap_runs(self, staggered_data_with_covariates, survey_weights_only, method): + result = CallawaySantAnna(estimation_method=method, n_bootstrap=49, seed=42).fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + assert np.isfinite(result.overall_att) + assert result.bootstrap_results is not None + + +# ============================================================================= +# Edge Cases +# ============================================================================= + + +class TestEdgeCases: + """Edge cases for IPW/DR + survey + covariates.""" + + def test_single_covariate(self): + """Works with a single binary covariate.""" + data = generate_staggered_data(n_units=100, seed=99) + rng = np.random.default_rng(99) + data["binary_x"] = rng.choice([0, 1], len(data)) + data["weight"] = rng.uniform(0.5, 2.0, len(data)).round(1) + # Make weights constant within unit + unit_w = data.groupby("unit")["weight"].first() + data["weight"] = data["unit"].map(unit_w) + + sd = SurveyDesign(weights="weight") + for method in ["ipw", "dr"]: + result = CallawaySantAnna(estimation_method=method).fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["binary_x"], + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + + def test_not_yet_treated_control(self, staggered_data_with_covariates, 
survey_weights_only): + """control_group='not_yet_treated' with DR + covariates + survey.""" + result = CallawaySantAnna(estimation_method="dr", control_group="not_yet_treated").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + assert np.isfinite(result.overall_att) + + def test_universal_base_period(self, staggered_data_with_covariates, survey_weights_only): + """base_period='universal' with DR + covariates + survey.""" + result = CallawaySantAnna(estimation_method="dr", base_period="universal").fit( + staggered_data_with_covariates, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_weights_only, + ) + assert np.isfinite(result.overall_att) + + +# ============================================================================= +# Non-Survey DR IF Corrections +# ============================================================================= + + +class TestNonSurveyDRIFCorrections: + """Verify that non-survey DR path also has IF corrections.""" + + def test_dr_se_differs_from_reg_no_survey(self): + """DR and reg should give different SEs without survey (IF corrections).""" + data = generate_staggered_data(n_units=150, n_periods=6, seed=42) + rng = np.random.default_rng(42) + data["x1"] = rng.normal(0, 1, len(data)) + + r_reg = CallawaySantAnna(estimation_method="reg").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + r_dr = CallawaySantAnna(estimation_method="dr").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + ) + # SEs should differ (DR has nuisance IF corrections) + assert r_reg.overall_se != r_dr.overall_se + # But ATTs should be similar (both consistent under correct specification) + assert abs(r_reg.overall_att - r_dr.overall_att) < 1.0