diff --git a/CHANGELOG.md b/CHANGELOG.md index 445f31f0..d9d4e13a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- **WooldridgeDiD (ETWFE)** estimator — Extended Two-Way Fixed Effects from Wooldridge (2025, 2023). Supports OLS, logit, and Poisson QMLE paths with ASF-based ATT and delta-method SEs. Four aggregation types (simple, group, calendar, event) matching Stata `jwdid_estat`. Alias: `ETWFE`. (PR #216, thanks @wenddymacro) + ## [2.8.4] - 2026-04-04 ### Added diff --git a/README.md b/README.md index b007b28f..42cc866c 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ Detailed guide: [`docs/llms-practitioner.txt`](docs/llms-practitioner.txt) - **Wild cluster bootstrap**: Valid inference with few clusters (<50) using Rademacher, Webb, or Mammen weights - **Panel data support**: Two-way fixed effects estimator for panel designs - **Multi-period analysis**: Event-study style DiD with period-specific treatment effects -- **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), and Efficient DiD (Chen, Sant'Anna & Xie 2025) estimators for heterogeneous treatment timing +- **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, Two-Stage DiD (Gardner 2022), Stacked DiD (Wing, Freedman & Hollingsworth 2024), Efficient DiD (Chen, Sant'Anna & Xie 2025), and Wooldridge ETWFE (2021/2023) estimators for heterogeneous treatment timing - **Triple Difference (DDD)**: Ortiz-Villavicencio & Sant'Anna (2025) estimators with proper covariate handling - **Synthetic DiD**: Combined DiD with synthetic control for improved robustness - **Triply Robust Panel (TROP)**: Factor-adjusted DiD with synthetic weights (Athey et al. 
2025) @@ -117,6 +117,7 @@ All estimators have short aliases for convenience: | `Stacked` | `StackedDiD` | Stacked DiD | | `Bacon` | `BaconDecomposition` | Goodman-Bacon decomposition | | `EDiD` | `EfficientDiD` | Efficient DiD | +| `ETWFE` | `WooldridgeDiD` | Wooldridge ETWFE (2021/2023) | `TROP` already uses its short canonical name and needs no alias. diff --git a/ROADMAP.md b/ROADMAP.md index 058f2ba1..d9868de3 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -82,11 +82,7 @@ Implements local projections for dynamic treatment effects. Doesn't require spec ### Nonlinear DiD -For outcomes where linear models are inappropriate (binary, count, bounded). - -- Logit/probit DiD for binary outcomes -- Poisson DiD for count outcomes -- Proper handling of incidence rate ratios and odds ratios +Implemented in `WooldridgeDiD` (alias `ETWFE`) — OLS, Poisson QMLE, and logit paths with ASF-based ATT. See [Tutorial 16](docs/tutorials/16_wooldridge_etwfe.ipynb). **Reference**: [Wooldridge (2023)](https://academic.oup.com/ectj/article/26/3/C31/7250479). *The Econometrics Journal*. diff --git a/TODO.md b/TODO.md index 9c1aad97..c76e0f11 100644 --- a/TODO.md +++ b/TODO.md @@ -72,6 +72,10 @@ Deferred items from PR reviews that were not addressed before merge. | StaggeredTripleDifference: per-cohort group-effect SEs include WIF (conservative vs R's wif=NULL). Documented in REGISTRY. Could override mixin for exact R match. | `staggered_triple_diff.py` | #245 | Low | | HonestDiD Delta^RM: uses naive FLCI instead of paper's ARP conditional/hybrid confidence sets (Sections 3.2.1-3.2.2). ARP infrastructure exists but moment inequality transformation needs calibration. CIs are conservative (wider, valid coverage). | `honest_did.py` | #248 | Medium | | Replicate weight tests use Fay-like BRR perturbations (0.5/1.5), not true half-sample BRR. Add true BRR regressions per estimator family. Existing `test_survey_phase6.py` covers true BRR at the helper level. 
| `tests/test_replicate_weight_expansion.py` | #253 | Low | +| WooldridgeDiD: QMLE sandwich uses `aweight` cluster-robust adjustment `(G/(G-1))*(n-1)/(n-k)` vs Stata's `G/(G-1)` only. Conservative (inflates SEs). Add `qmle` weight type if Stata golden values confirm material difference. | `wooldridge.py`, `linalg.py` | #216 | Medium | +| WooldridgeDiD: aggregation weights use cell-level n_{g,t} counts. Paper (W2025 Eqs. 7.2-7.4) defines cohort-share weights. Add optional `weights="cohort_share"` parameter to `aggregate()`. | `wooldridge_results.py` | #216 | Medium | +| WooldridgeDiD: canonical link requirement (W2023 Prop 3.1) not enforced — no warning if user applies wrong method to outcome type. Estimator is consistent regardless, but equivalence with imputation breaks. | `wooldridge.py` | #216 | Low | +| WooldridgeDiD: Stata `jwdid` golden value tests — add R/Stata reference script and `TestReferenceValues` class. | `tests/test_wooldridge.py` | #216 | Medium | #### Performance diff --git a/benchmarks/python/benchmark_wooldridge.py b/benchmarks/python/benchmark_wooldridge.py new file mode 100644 index 00000000..7930a39b --- /dev/null +++ b/benchmarks/python/benchmark_wooldridge.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +""" +Benchmark: WooldridgeDiD (ETWFE) Estimator (diff-diff WooldridgeDiD). + +Validates OLS ETWFE ATT(g,t) against Callaway-Sant'Anna on mpdta data +(Proposition 3.1 equivalence), and measures estimation timing. 
+ +Usage: + python benchmark_wooldridge.py --data path/to/mpdta.csv --output path/to/results.json +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +# IMPORTANT: Parse --backend and set environment variable BEFORE importing diff_diff +def _get_backend_from_args(): + """Parse --backend argument without importing diff_diff.""" + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--backend", default="auto", choices=["auto", "python", "rust"]) + args, _ = parser.parse_known_args() + return args.backend + +_requested_backend = _get_backend_from_args() +if _requested_backend in ("python", "rust"): + os.environ["DIFF_DIFF_BACKEND"] = _requested_backend + +# NOW import diff_diff and other dependencies (will see the env var) +import pandas as pd + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from diff_diff import WooldridgeDiD, HAS_RUST_BACKEND +from benchmarks.python.utils import Timer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark WooldridgeDiD (ETWFE) estimator" + ) + parser.add_argument("--data", required=True, help="Path to input CSV data (mpdta format)") + parser.add_argument("--output", required=True, help="Path to output JSON results") + parser.add_argument( + "--backend", default="auto", choices=["auto", "python", "rust"], + help="Backend to use: auto (default), python (pure Python), rust (Rust backend)" + ) + return parser.parse_args() + + +def get_actual_backend() -> str: + """Return the actual backend being used based on HAS_RUST_BACKEND.""" + return "rust" if HAS_RUST_BACKEND else "python" + + +def main(): + args = parse_args() + + actual_backend = get_actual_backend() + print(f"Using backend: {actual_backend}") + + print(f"Loading data from: {args.data}") + df = pd.read_csv(args.data) + + # Run OLS ETWFE estimation + print("Running WooldridgeDiD (OLS ETWFE) estimation...") + est = WooldridgeDiD(method="ols", 
control_group="not_yet_treated") + + with Timer() as estimation_timer: + results = est.fit( + df, + outcome="lemp", + unit="countyreal", + time="year", + cohort="first_treat", + ) + + estimation_time = estimation_timer.elapsed + + # Compute event study aggregation (timed separately so total_seconds includes it) + with Timer() as aggregation_timer: + results.aggregate("event") + total_time = estimation_time + aggregation_timer.elapsed + + # Store data info + n_units = len(df["countyreal"].unique()) + n_periods = len(df["year"].unique()) + n_obs = len(df) + + # Format ATT(g,t) effects + gt_effects_out = [] + for (g, t), cell in sorted(results.group_time_effects.items()): + gt_effects_out.append({ + "cohort": int(g), + "time": int(t), + "att": float(cell["att"]), + "se": float(cell["se"]), + }) + + # Format event study effects + es_effects = [] + if results.event_study_effects: + for rel_t, effect_data in sorted(results.event_study_effects.items()): + es_effects.append({ + "event_time": int(rel_t), + "att": float(effect_data["att"]), + "se": float(effect_data["se"]), + }) + + output = { + "estimator": "diff_diff.WooldridgeDiD", + "method": "ols", + "control_group": "not_yet_treated", + "backend": actual_backend, + # Overall ATT + "overall_att": float(results.overall_att), + "overall_se": float(results.overall_se), + # Group-time ATT(g,t) + "group_time_effects": gt_effects_out, + # Event study + "event_study": es_effects, + # Timing + "timing": { + "estimation_seconds": estimation_time, + "total_seconds": total_time, + }, + # Metadata + "metadata": { + "n_units": n_units, + "n_periods": n_periods, + "n_obs": n_obs, + "n_cohorts": len(results.groups), + }, + } + + print(f"Writing results to: {args.output}") + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2) + + print(f"Overall ATT: {results.overall_att:.6f} (SE: {results.overall_se:.6f})") + print(f"Completed in {total_time:.3f} seconds") + return output + + +if __name__ == "__main__": + main() diff --git
a/diff_diff/__init__.py b/diff_diff/__init__.py index 423f2386..2fe60801 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -164,6 +164,8 @@ TROPResults, trop, ) +from diff_diff.wooldridge import WooldridgeDiD +from diff_diff.wooldridge_results import WooldridgeDiDResults from diff_diff.utils import ( WildBootstrapResults, check_parallel_trends, @@ -210,6 +212,7 @@ Stacked = StackedDiD Bacon = BaconDecomposition EDiD = EfficientDiD +ETWFE = WooldridgeDiD __version__ = "2.8.4" __all__ = [ @@ -276,6 +279,10 @@ "EfficientDiDResults", "EDiDBootstrapResults", "EDiD", + # WooldridgeDiD (ETWFE) + "WooldridgeDiD", + "WooldridgeDiDResults", + "ETWFE", # Visualization "plot_bacon", "plot_event_study", diff --git a/diff_diff/linalg.py b/diff_diff/linalg.py index 968c87a9..22041bfe 100644 --- a/diff_diff/linalg.py +++ b/diff_diff/linalg.py @@ -2342,3 +2342,69 @@ def _compute_confidence_interval( upper = estimate + critical_value * se return (lower, upper) + + +def solve_poisson( + X: np.ndarray, + y: np.ndarray, + max_iter: int = 200, + tol: float = 1e-8, + init_beta: Optional[np.ndarray] = None, +) -> Tuple[np.ndarray, np.ndarray]: + """Poisson IRLS (Newton-Raphson with log link). + + Does NOT prepend an intercept — caller must include one if needed. + Returns (beta, W_final) where W_final = mu_hat (used for sandwich vcov). + + Parameters + ---------- + X : (n, k) design matrix (caller provides intercept / group FE dummies) + y : (n,) non-negative count outcomes + max_iter : maximum IRLS iterations + tol : convergence threshold on sup-norm of coefficient change + init_beta : optional starting coefficient vector; if None, zeros are used + with the first column treated as the intercept and initialized to + log(mean(y)) to improve convergence for large-scale outcomes. 
+ + Returns + ------- + beta : (k,) coefficient vector + W : (n,) final fitted means mu_hat (weights for sandwich vcov) + """ + n, k = X.shape + if init_beta is not None: + beta = init_beta.copy() + else: + beta = np.zeros(k) + # Initialise the intercept to log(mean(y)) so the first IRLS step + # starts near the unconditional mean rather than exp(0)=1, which + # causes overflow when y is large (e.g. employment levels). + mean_y = float(np.mean(y)) + if mean_y > 0: + beta[0] = np.log(mean_y) + for _ in range(max_iter): + eta = np.clip(X @ beta, -500, 500) + mu = np.exp(eta) + score = X.T @ (y - mu) # gradient of log-likelihood + hess = X.T @ (mu[:, None] * X) # -Hessian = X'WX, W=diag(mu) + try: + delta = np.linalg.solve(hess + 1e-12 * np.eye(k), score) + except np.linalg.LinAlgError: + break + # Damped step: cap the maximum coefficient change to avoid overshooting + max_step = np.max(np.abs(delta)) + if max_step > 1.0: + delta = delta / max_step + beta_new = beta + delta + if np.max(np.abs(beta_new - beta)) < tol: + beta = beta_new + break + beta = beta_new + else: + warnings.warn( + "solve_poisson did not converge in {} iterations".format(max_iter), + RuntimeWarning, + stacklevel=2, + ) + mu_final = np.exp(np.clip(X @ beta, -500, 500)) + return beta, mu_final diff --git a/diff_diff/wooldridge.py b/diff_diff/wooldridge.py new file mode 100644 index 00000000..c703aa9d --- /dev/null +++ b/diff_diff/wooldridge.py @@ -0,0 +1,826 @@ +"""WooldridgeDiD: Extended Two-Way Fixed Effects (ETWFE) estimator. + +Implements Wooldridge (2025, 2023) ETWFE, faithful to the Stata jwdid package. + +References +---------- +Wooldridge (2025). Two-Way Fixed Effects, the Two-Way Mundlak Regression, + and Difference-in-Differences Estimators. Empirical Economics, 69(5), 2545-2587. +Wooldridge (2023). Simple approaches to nonlinear difference-in-differences + with panel data. The Econometrics Journal, 26(3), C31-C66. +Friosavila (2021). jwdid: Stata module. SSC s459114. 
+""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from diff_diff.linalg import compute_robust_vcov, solve_logit, solve_ols, solve_poisson +from diff_diff.utils import safe_inference, within_transform +from diff_diff.wooldridge_results import WooldridgeDiDResults + +_VALID_METHODS = ("ols", "logit", "poisson") +_VALID_CONTROL_GROUPS = ("never_treated", "not_yet_treated") +_VALID_BOOTSTRAP_WEIGHTS = ("rademacher", "webb", "mammen") + + +def _logistic(x: np.ndarray) -> np.ndarray: + return 1.0 / (1.0 + np.exp(-x)) + + +def _logistic_deriv(x: np.ndarray) -> np.ndarray: + p = _logistic(x) + return p * (1.0 - p) + + +def _compute_weighted_agg( + gt_effects: Dict, + gt_weights: Dict, + gt_keys: List, + gt_vcov: Optional[np.ndarray], + alpha: float, +) -> Dict: + """Compute simple (overall) weighted average ATT and SE via delta method.""" + post_keys = [(g, t) for (g, t) in gt_keys if t >= g] + w_total = sum(gt_weights.get(k, 0) for k in post_keys) + if w_total == 0: + att = float("nan") + se = float("nan") + else: + att = ( + sum(gt_weights.get(k, 0) * gt_effects[k]["att"] for k in post_keys if k in gt_effects) + / w_total + ) + if gt_vcov is not None: + w_vec = np.array( + [gt_weights.get(k, 0) / w_total if k in post_keys else 0.0 for k in gt_keys] + ) + var = float(w_vec @ gt_vcov @ w_vec) + se = float(np.sqrt(max(var, 0.0))) + else: + se = float("nan") + + t_stat, p_value, conf_int = safe_inference(att, se, alpha=alpha) + return {"att": att, "se": se, "t_stat": t_stat, "p_value": p_value, "conf_int": conf_int} + + +def _filter_sample( + data: pd.DataFrame, + unit: str, + time: str, + cohort: str, + control_group: str, + anticipation: int, +) -> pd.DataFrame: + """Return the analysis sample following jwdid selection rules. + + For "not_yet_treated": keep all observations from treated units (pre- and + post-treatment) plus all never-treated and not-yet-treated observations. 
+ For "never_treated": keep only post-treatment observations from treated + units (t >= g - anticipation) plus all never-treated observations. + Pre-treatment observations from treated units are excluded so they do not + serve as implicit controls in the regression baseline. + """ + df = data.copy() + # Normalise never-treated: fill NaN cohort with 0 + df[cohort] = df[cohort].fillna(0) + + if control_group == "never_treated": + # Post-treatment obs from treated units + all never-treated obs. + # Pre-treatment obs from treated units are excluded so the + # counterfactual is identified solely from never-treated units. + treated_mask = (df[cohort] > 0) & (df[time] >= df[cohort] - anticipation) + control_mask = df[cohort] == 0 + else: # not_yet_treated + # All treated-unit obs + never-treated + not-yet-treated obs + treated_mask = df[cohort] > 0 + control_mask = (df[cohort] == 0) | (df[cohort] > df[time]) + + return df[treated_mask | control_mask].copy() + + +def _build_interaction_matrix( + data: pd.DataFrame, + cohort: str, + time: str, + anticipation: int, +) -> Tuple[np.ndarray, List[str], List[Tuple[Any, Any]]]: + """Build the saturated cohort×time interaction design matrix. 
+ + Returns + ------- + X_int : (n, n_cells) binary indicator matrix + col_names : list of string labels "g{g}_t{t}" + gt_keys : list of (g, t) tuples in same column order + """ + groups = sorted(g for g in data[cohort].unique() if g > 0) + times = sorted(data[time].unique()) + cohort_vals = data[cohort].values + time_vals = data[time].values + + cols = [] + col_names = [] + gt_keys = [] + + for g in groups: + for t in times: + if t >= g - anticipation: + indicator = ((cohort_vals == g) & (time_vals == t)).astype(float) + cols.append(indicator) + col_names.append(f"g{g}_t{t}") + gt_keys.append((g, t)) + + if not cols: + return np.empty((len(data), 0)), [], [] + return np.column_stack(cols), col_names, gt_keys + + +def _prepare_covariates( + data: pd.DataFrame, + exovar: Optional[List[str]], + xtvar: Optional[List[str]], + xgvar: Optional[List[str]], + cohort: str, + time: str, + demean_covariates: bool, + groups: List[Any], +) -> Optional[np.ndarray]: + """Build covariate matrix following jwdid covariate type conventions. + + Returns None if no covariates, else (n, k) array. + """ + parts = [] + + if exovar: + parts.append(data[exovar].values.astype(float)) + + if xtvar: + if demean_covariates: + # Within-cohort×period demeaning + grp_key = data[cohort].astype(str) + "_" + data[time].astype(str) + tmp = data[xtvar].copy() + for col in xtvar: + tmp[col] = tmp[col] - tmp.groupby(grp_key)[col].transform("mean") + parts.append(tmp.values.astype(float)) + else: + parts.append(data[xtvar].values.astype(float)) + + if xgvar: + for g in groups: + g_indicator = (data[cohort] == g).values.astype(float) + for col in xgvar: + parts.append((g_indicator * data[col].values).reshape(-1, 1)) + + if not parts: + return None + return np.hstack([p if p.ndim == 2 else p.reshape(-1, 1) for p in parts]) + + +class WooldridgeDiD: + """Extended Two-Way Fixed Effects (ETWFE) DiD estimator. 
+ + Implements the Wooldridge (2025) saturated cohort×time regression and + Wooldridge (2023) nonlinear extensions (logit, Poisson). Produces all + four ``jwdid_estat`` aggregation types: simple, group, calendar, event. + + Parameters + ---------- + method : {"ols", "logit", "poisson"} + Estimation method. "ols" for continuous outcomes; "logit" for binary + or fractional outcomes; "poisson" for count data. + control_group : {"not_yet_treated", "never_treated"} + Which units serve as the comparison group. "not_yet_treated" (jwdid + default) uses all untreated observations at each time period; + "never_treated" uses only units never treated throughout the sample. + anticipation : int + Number of periods before treatment onset to include as treatment cells + (anticipation effects). 0 means no anticipation. + demean_covariates : bool + If True (jwdid default), ``xtvar`` covariates are demeaned within each + cohort×period cell before entering the regression. Set to False to + replicate jwdid's ``xasis`` option. + alpha : float + Significance level for confidence intervals. + cluster : str or None + Column name to use for cluster-robust SEs. Defaults to the ``unit`` + identifier passed to ``fit()``. + n_bootstrap : int + Number of bootstrap replications. 0 disables bootstrap. + bootstrap_weights : {"rademacher", "webb", "mammen"} + Bootstrap weight distribution. + seed : int or None + Random seed for reproducibility. + rank_deficient_action : {"warn", "error", "silent"} + How to handle rank-deficient design matrices.
+ """ + + def __init__( + self, + method: str = "ols", + control_group: str = "not_yet_treated", + anticipation: int = 0, + demean_covariates: bool = True, + alpha: float = 0.05, + cluster: Optional[str] = None, + n_bootstrap: int = 0, + bootstrap_weights: str = "rademacher", + seed: Optional[int] = None, + rank_deficient_action: str = "warn", + ) -> None: + if method not in _VALID_METHODS: + raise ValueError(f"method must be one of {_VALID_METHODS}, got {method!r}") + if control_group not in _VALID_CONTROL_GROUPS: + raise ValueError( + f"control_group must be one of {_VALID_CONTROL_GROUPS}, got {control_group!r}" + ) + if anticipation < 0: + raise ValueError(f"anticipation must be >= 0, got {anticipation}") + if bootstrap_weights not in _VALID_BOOTSTRAP_WEIGHTS: + raise ValueError( + f"bootstrap_weights must be one of {_VALID_BOOTSTRAP_WEIGHTS}, " + f"got {bootstrap_weights!r}" + ) + + self.method = method + self.control_group = control_group + self.anticipation = anticipation + self.demean_covariates = demean_covariates + self.alpha = alpha + self.cluster = cluster + self.n_bootstrap = n_bootstrap + self.bootstrap_weights = bootstrap_weights + self.seed = seed + self.rank_deficient_action = rank_deficient_action + + self.is_fitted_: bool = False + self._results: Optional[WooldridgeDiDResults] = None + + @property + def results_(self) -> WooldridgeDiDResults: + if not self.is_fitted_: + raise RuntimeError("Call fit() before accessing results_") + return self._results # type: ignore[return-value] + + def get_params(self) -> Dict[str, Any]: + """Return estimator parameters (sklearn-compatible).""" + return { + "method": self.method, + "control_group": self.control_group, + "anticipation": self.anticipation, + "demean_covariates": self.demean_covariates, + "alpha": self.alpha, + "cluster": self.cluster, + "n_bootstrap": self.n_bootstrap, + "bootstrap_weights": self.bootstrap_weights, + "seed": self.seed, + "rank_deficient_action": self.rank_deficient_action, + } + + 
def set_params(self, **params: Any) -> "WooldridgeDiD": + """Set estimator parameters (sklearn-compatible). Returns self.""" + for key, value in params.items(): + if not hasattr(self, key): + raise ValueError(f"Unknown parameter: {key!r}") + setattr(self, key, value) + return self + + def fit( + self, + data: pd.DataFrame, + outcome: str, + unit: str, + time: str, + cohort: str, + exovar: Optional[List[str]] = None, + xtvar: Optional[List[str]] = None, + xgvar: Optional[List[str]] = None, + ) -> WooldridgeDiDResults: + """Fit the ETWFE model. See class docstring for parameter details. + + Parameters + ---------- + data : DataFrame with panel data (long format) + outcome : outcome column name + unit : unit identifier column + time : time period column + cohort : first treatment period (0 or NaN = never treated) + exovar : time-invariant covariates added without interaction/demeaning + xtvar : time-varying covariates (demeaned within cohort×period cells + when ``demean_covariates=True``) + xgvar : covariates interacted with each cohort indicator + """ + df = data.copy() + df[cohort] = df[cohort].fillna(0) + + # 1. Filter to analysis sample + sample = _filter_sample(df, unit, time, cohort, self.control_group, self.anticipation) + + # 2. Build interaction matrix + X_int, int_col_names, gt_keys = _build_interaction_matrix( + sample, cohort=cohort, time=time, anticipation=self.anticipation + ) + + # 3. 
Covariates + groups = sorted(g for g in sample[cohort].unique() if g > 0) + X_cov = _prepare_covariates( + sample, + exovar=exovar, + xtvar=xtvar, + xgvar=xgvar, + cohort=cohort, + time=time, + demean_covariates=self.demean_covariates, + groups=groups, + ) + + all_regressors = int_col_names.copy() + if X_cov is not None: + X_design = np.hstack([X_int, X_cov]) + for i in range(X_cov.shape[1]): + all_regressors.append(f"_cov_{i}") + else: + X_design = X_int + + if self.method == "ols": + results = self._fit_ols( + sample, + outcome, + unit, + time, + cohort, + X_design, + all_regressors, + gt_keys, + int_col_names, + groups, + ) + elif self.method == "logit": + results = self._fit_logit( + sample, + outcome, + unit, + time, + cohort, + X_design, + all_regressors, + gt_keys, + int_col_names, + groups, + ) + else: # poisson + results = self._fit_poisson( + sample, + outcome, + unit, + time, + cohort, + X_design, + all_regressors, + gt_keys, + int_col_names, + groups, + ) + + self._results = results + self.is_fitted_ = True + return results + + def _count_control_units(self, sample: pd.DataFrame, unit: str, cohort: str, time: str) -> int: + """Count control units consistent with control_group setting.""" + n_never = int(sample[sample[cohort] == 0][unit].nunique()) + if self.control_group == "not_yet_treated": + # Also count future-treated units that contribute pre-treatment obs + nyt = sample[(sample[cohort] > 0) & (sample[time] < sample[cohort])][unit].nunique() + return n_never + int(nyt) + return n_never + + def _fit_ols( + self, + sample: pd.DataFrame, + outcome: str, + unit: str, + time: str, + cohort: str, + X_design: np.ndarray, + col_names: List[str], + gt_keys: List[Tuple], + int_col_names: List[str], + groups: List[Any], + ) -> WooldridgeDiDResults: + """OLS path: within-transform FE, solve_ols, cluster SE.""" + # 4. 
Within-transform: absorb unit + time FE + all_vars = [outcome] + [f"_x{i}" for i in range(X_design.shape[1])] + tmp = sample[[unit, time]].copy() + tmp[outcome] = sample[outcome].values + for i in range(X_design.shape[1]): + tmp[f"_x{i}"] = X_design[:, i] + + transformed = within_transform(tmp, all_vars, unit=unit, time=time, suffix="_demeaned") + + y = transformed[f"{outcome}_demeaned"].values + X_cols = [f"_x{i}_demeaned" for i in range(X_design.shape[1])] + X = transformed[X_cols].values + + # 5. Cluster IDs (default: unit level) + cluster_col = self.cluster if self.cluster else unit + cluster_ids = sample[cluster_col].values + + # 6. Solve OLS + coefs, resids, vcov = solve_ols( + X, + y, + cluster_ids=cluster_ids, + return_vcov=True, + rank_deficient_action=self.rank_deficient_action, + column_names=col_names, + ) + + # 7. Extract β_{g,t} and build gt_effects dict + gt_effects: Dict[Tuple, Dict] = {} + gt_weights: Dict[Tuple, int] = {} + for idx, (g, t) in enumerate(gt_keys): + if idx >= len(coefs): + break + att = float(coefs[idx]) + se = float(np.sqrt(max(vcov[idx, idx], 0.0))) if vcov is not None else float("nan") + t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha) + gt_effects[(g, t)] = { + "att": att, + "se": se, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + } + gt_weights[(g, t)] = int(((sample[cohort] == g) & (sample[time] == t)).sum()) + + # Extract vcov submatrix for beta_{g,t} only + n_gt = len(gt_keys) + gt_vcov = vcov[:n_gt, :n_gt] if vcov is not None else None + gt_keys_ordered = list(gt_keys) + + # 8. 
Simple aggregation (always computed) + overall = _compute_weighted_agg( + gt_effects, gt_weights, gt_keys_ordered, gt_vcov, self.alpha + ) + + # Metadata + n_treated = int(sample[sample[cohort] > 0][unit].nunique()) + n_control = self._count_control_units(sample, unit, cohort, time) + all_times = sorted(sample[time].unique().tolist()) + + results = WooldridgeDiDResults( + group_time_effects=gt_effects, + overall_att=overall["att"], + overall_se=overall["se"], + overall_t_stat=overall["t_stat"], + overall_p_value=overall["p_value"], + overall_conf_int=overall["conf_int"], + method=self.method, + control_group=self.control_group, + groups=groups, + time_periods=all_times, + n_obs=len(sample), + n_treated_units=n_treated, + n_control_units=n_control, + alpha=self.alpha, + _gt_weights=gt_weights, + _gt_vcov=gt_vcov, + _gt_keys=gt_keys_ordered, + ) + + # 9. Optional multiplier bootstrap (overrides analytic SE for overall ATT) + if self.n_bootstrap > 0: + rng = np.random.default_rng(self.seed) + units_arr = sample[unit].values + unique_units = np.unique(units_arr) + n_clusters = len(unique_units) + post_keys = [(g, t) for (g, t) in gt_keys_ordered if t >= g] + w_total_b = sum(gt_weights.get(k, 0) for k in post_keys) + boot_atts: List[float] = [] + for _ in range(self.n_bootstrap): + if self.bootstrap_weights == "rademacher": + unit_weights = rng.choice([-1.0, 1.0], size=n_clusters) + elif self.bootstrap_weights == "webb": + unit_weights = rng.choice( + [-np.sqrt(1.5), -1.0, -np.sqrt(0.5), np.sqrt(0.5), 1.0, np.sqrt(1.5)], + size=n_clusters, + ) + else: # mammen + phi = (1 + np.sqrt(5)) / 2 + unit_weights = rng.choice( + [-(phi - 1), phi], + p=[phi / np.sqrt(5), (phi - 1) / np.sqrt(5)], + size=n_clusters, + ) + obs_weights = unit_weights[np.searchsorted(unique_units, units_arr)] + y_boot = y + obs_weights * resids + coefs_b, _, _ = solve_ols( + X, + y_boot, + cluster_ids=cluster_ids, + return_vcov=False, + rank_deficient_action="silent", + ) + if w_total_b > 0: + att_b = 
( + sum( + gt_weights.get(k, 0) * float(coefs_b[i]) + for i, k in enumerate(gt_keys) + if k in post_keys and i < len(coefs_b) + ) + / w_total_b + ) + boot_atts.append(att_b) + if boot_atts: + boot_se = float(np.std(boot_atts, ddof=1)) + t_stat_b, p_b, ci_b = safe_inference(results.overall_att, boot_se, alpha=self.alpha) + results.overall_se = boot_se + results.overall_t_stat = t_stat_b + results.overall_p_value = p_b + results.overall_conf_int = ci_b + + return results + + def _fit_logit( + self, + sample: pd.DataFrame, + outcome: str, + unit: str, + time: str, + cohort: str, + X_int: np.ndarray, + col_names: List[str], + gt_keys: List[Tuple], + int_col_names: List[str], + groups: List[Any], + ) -> WooldridgeDiDResults: + """Logit path: cohort + time additive FEs + solve_logit + ASF ATT. + + Matches Stata jwdid method(logit): logit y [treatment_interactions] + i.gvar i.tvar — cohort main effects + time main effects (additive), + not cohort×time saturated group FEs. + """ + n_int = len(int_col_names) + + # Design matrix: treatment interactions + cohort FEs + time FEs + # This matches Stata's `i.gvar i.tvar` specification. 
+ cohort_dummies = pd.get_dummies(sample[cohort], drop_first=True).values.astype(float) + time_dummies = pd.get_dummies(sample[time], drop_first=True).values.astype(float) + X_full = np.hstack([X_int, cohort_dummies, time_dummies]) + + y = sample[outcome].values.astype(float) + cluster_col = self.cluster if self.cluster else unit + cluster_ids = sample[cluster_col].values + + beta, probs = solve_logit( + X_full, + y, + rank_deficient_action=self.rank_deficient_action, + ) + # solve_logit prepends intercept — beta[0] is intercept, beta[1:] are X_full cols + beta_int_cols = beta[1 : n_int + 1] # treatment interaction coefficients + + # Handle rank-deficient designs: zero out NaN entries so downstream + # matrix ops don't propagate NaN (dropped columns contribute nothing) + nan_mask = np.isnan(beta) + if np.any(nan_mask): + beta = np.where(nan_mask, 0.0, beta) + + # QMLE sandwich vcov via shared linalg backend + resids = y - probs + X_with_intercept = np.column_stack([np.ones(len(y)), X_full]) + vcov_full = compute_robust_vcov( + X_with_intercept, + resids, + cluster_ids=cluster_ids, + weights=probs * (1 - probs), # logit QMLE bread: (X'WX)^{-1} + weight_type="aweight", # unweighted scores for QMLE sandwich + ) + + # ASF ATT(g,t) for treated units in each cell + gt_effects: Dict[Tuple, Dict] = {} + gt_weights: Dict[Tuple, int] = {} + gt_grads: Dict[Tuple, np.ndarray] = {} # store per-cell gradients for aggregate SE + for idx, (g, t) in enumerate(gt_keys): + if idx >= n_int: + break + cell_mask = (sample[cohort] == g) & (sample[time] == t) + if cell_mask.sum() == 0: + continue + # Skip cells whose interaction coefficient was dropped (rank deficiency) + delta = beta_int_cols[idx] + if np.isnan(delta): + continue + eta_base = X_with_intercept[cell_mask] @ beta + # eta_base already contains the treatment effect (D_{g,t}=1 in cell). + # Counterfactual: eta_0 = eta_base - delta (treatment switched off). 
+ # ATT = E[Λ(η_1)] - E[Λ(η_0)] = E[Λ(η_base)] - E[Λ(η_base - δ)] + eta_0 = eta_base - delta + att = float(np.mean(_logistic(eta_base) - _logistic(eta_0))) + # Delta method gradient: d(ATT)/d(β) + # for p ≠ int_idx: mean_i[(Λ'(η_1) - Λ'(η_0)) * X_p] + # for p = int_idx: mean_i[Λ'(η_1)] + d_diff = _logistic_deriv(eta_base) - _logistic_deriv(eta_0) + grad = np.mean(X_with_intercept[cell_mask] * d_diff[:, None], axis=0) + grad[1 + idx] = float(np.mean(_logistic_deriv(eta_base))) + se = float(np.sqrt(max(grad @ vcov_full @ grad, 0.0))) + t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha) + gt_effects[(g, t)] = { + "att": att, + "se": se, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + "_gradient": grad.copy(), + } + gt_weights[(g, t)] = int(cell_mask.sum()) + gt_grads[(g, t)] = grad + + gt_keys_ordered = [k for k in gt_keys if k in gt_effects] + # ATT-level covariance: J @ vcov_full @ J' where J rows are per-cell gradients + if gt_keys_ordered: + J = np.array([gt_grads[k] for k in gt_keys_ordered]) + gt_vcov = J @ vcov_full @ J.T + else: + gt_vcov = None + + # Overall SE via joint delta method: ∇β(overall_att) = Σ w_k/w_total * grad_k + post_keys = [(g, t) for (g, t) in gt_keys_ordered if t >= g] + w_total = sum(gt_weights.get(k, 0) for k in post_keys) + if w_total > 0 and post_keys: + overall_att = sum(gt_weights[k] * gt_effects[k]["att"] for k in post_keys) / w_total + agg_grad = sum((gt_weights[k] / w_total) * gt_grads[k] for k in post_keys) + overall_se = float(np.sqrt(max(agg_grad @ vcov_full @ agg_grad, 0.0))) + t_stat, p_value, conf_int = safe_inference(overall_att, overall_se, alpha=self.alpha) + overall = { + "att": overall_att, + "se": overall_se, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + } + else: + overall = _compute_weighted_agg( + gt_effects, gt_weights, gt_keys_ordered, None, self.alpha + ) + + return WooldridgeDiDResults( + group_time_effects=gt_effects, + overall_att=overall["att"], + 
overall_se=overall["se"], + overall_t_stat=overall["t_stat"], + overall_p_value=overall["p_value"], + overall_conf_int=overall["conf_int"], + method=self.method, + control_group=self.control_group, + groups=groups, + time_periods=sorted(sample[time].unique().tolist()), + n_obs=len(sample), + n_treated_units=int(sample[sample[cohort] > 0][unit].nunique()), + n_control_units=self._count_control_units(sample, unit, cohort, time), + alpha=self.alpha, + _gt_weights=gt_weights, + _gt_vcov=gt_vcov, + _gt_keys=gt_keys_ordered, + ) + + def _fit_poisson( + self, + sample: pd.DataFrame, + outcome: str, + unit: str, + time: str, + cohort: str, + X_int: np.ndarray, + col_names: List[str], + gt_keys: List[Tuple], + int_col_names: List[str], + groups: List[Any], + ) -> WooldridgeDiDResults: + """Poisson path: cohort + time additive FEs + solve_poisson + ASF ATT. + + Matches Stata jwdid method(poisson): poisson y [treatment_interactions] + i.gvar i.tvar — cohort main effects + time main effects (additive), + not cohort×time saturated group FEs. + """ + n_int = len(int_col_names) + + # Design matrix: intercept + treatment interactions + cohort FEs + time FEs. + # Matches Stata's `i.gvar i.tvar` + treatment interaction specification. + # solve_poisson does not prepend an intercept, so we include one explicitly. + intercept = np.ones((len(sample), 1)) + cohort_dummies = pd.get_dummies(sample[cohort], drop_first=True).values.astype(float) + time_dummies = pd.get_dummies(sample[time], drop_first=True).values.astype(float) + X_full = np.hstack([intercept, X_int, cohort_dummies, time_dummies]) + # Treatment interaction coefficients start at column index 1. 
+ + y = sample[outcome].values.astype(float) + cluster_col = self.cluster if self.cluster else unit + cluster_ids = sample[cluster_col].values + + beta, mu_hat = solve_poisson(X_full, y) + + # QMLE sandwich vcov via shared linalg backend + resids = y - mu_hat + vcov_full = compute_robust_vcov( + X_full, + resids, + cluster_ids=cluster_ids, + weights=mu_hat, # Poisson QMLE bread: (X'WX)^{-1} + weight_type="aweight", # unweighted scores for QMLE sandwich + ) + + # Treatment interaction coefficients: beta[1 : 1+n_int]. Copy before the + # NaN handling below so the rank-deficiency skip in the loop still sees + # dropped (NaN) entries. + beta_int = beta[1 : 1 + n_int].copy() + + # Handle rank-deficient designs (mirrors the logit path): zero out NaN + # entries so X_full @ beta doesn't propagate NaN from dropped columns + nan_mask = np.isnan(beta) + if np.any(nan_mask): + beta = np.where(nan_mask, 0.0, beta) + + # ASF ATT(g,t) for treated units in each cell. + # eta_base = X_full @ beta already includes the treatment effect (D_{g,t}=1). + # Counterfactual: eta_0 = eta_base - delta (treatment switched off). + # ATT = E[exp(η_1)] - E[exp(η_0)] = E[exp(η_base)] - E[exp(η_base - δ)] + gt_effects: Dict[Tuple, Dict] = {} + gt_weights: Dict[Tuple, int] = {} + gt_grads: Dict[Tuple, np.ndarray] = {} # per-cell gradients for aggregate SE + for idx, (g, t) in enumerate(gt_keys): + if idx >= n_int: + break + cell_mask = (sample[cohort] == g) & (sample[time] == t) + if cell_mask.sum() == 0: + continue + # Skip cells whose interaction coefficient was dropped (rank deficiency) + delta = beta_int[idx] + if np.isnan(delta): + continue + eta_base = np.clip(X_full[cell_mask] @ beta, -500, 500) + eta_0 = eta_base - delta + mu_1 = np.exp(eta_base) + mu_0 = np.exp(eta_0) + att = float(np.mean(mu_1 - mu_0)) + # Delta method gradient: + # for p ≠ int_idx: mean_i[(μ_1 - μ_0) * X_p] + # for p = int_idx: mean_i[μ_1] + diff_mu = mu_1 - mu_0 + grad = np.mean(X_full[cell_mask] * diff_mu[:, None], axis=0) + grad[1 + idx] = float(np.mean(mu_1)) + se = float(np.sqrt(max(grad @ vcov_full @ grad, 0.0))) + t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha) + gt_effects[(g, t)] = { + "att": att, + "se": se, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + "_gradient": grad.copy(), + } + gt_weights[(g, t)] =
int(cell_mask.sum()) + gt_grads[(g, t)] = grad + + gt_keys_ordered = [k for k in gt_keys if k in gt_effects] + # ATT-level covariance: J @ vcov_full @ J' where J rows are per-cell gradients + if gt_keys_ordered: + J = np.array([gt_grads[k] for k in gt_keys_ordered]) + gt_vcov = J @ vcov_full @ J.T + else: + gt_vcov = None + + # Overall SE via joint delta method + post_keys = [(g, t) for (g, t) in gt_keys_ordered if t >= g] + w_total = sum(gt_weights.get(k, 0) for k in post_keys) + if w_total > 0 and post_keys: + overall_att = sum(gt_weights[k] * gt_effects[k]["att"] for k in post_keys) / w_total + agg_grad = sum((gt_weights[k] / w_total) * gt_grads[k] for k in post_keys) + overall_se = float(np.sqrt(max(agg_grad @ vcov_full @ agg_grad, 0.0))) + t_stat, p_value, conf_int = safe_inference(overall_att, overall_se, alpha=self.alpha) + overall = { + "att": overall_att, + "se": overall_se, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + } + else: + overall = _compute_weighted_agg( + gt_effects, gt_weights, gt_keys_ordered, None, self.alpha + ) + + return WooldridgeDiDResults( + group_time_effects=gt_effects, + overall_att=overall["att"], + overall_se=overall["se"], + overall_t_stat=overall["t_stat"], + overall_p_value=overall["p_value"], + overall_conf_int=overall["conf_int"], + method=self.method, + control_group=self.control_group, + groups=groups, + time_periods=sorted(sample[time].unique().tolist()), + n_obs=len(sample), + n_treated_units=int(sample[sample[cohort] > 0][unit].nunique()), + n_control_units=self._count_control_units(sample, unit, cohort, time), + alpha=self.alpha, + _gt_weights=gt_weights, + _gt_vcov=gt_vcov, + _gt_keys=gt_keys_ordered, + ) diff --git a/diff_diff/wooldridge_results.py b/diff_diff/wooldridge_results.py new file mode 100644 index 00000000..6196e3c1 --- /dev/null +++ b/diff_diff/wooldridge_results.py @@ -0,0 +1,333 @@ +"""Results class for WooldridgeDiD (ETWFE) estimator.""" + +from __future__ import annotations + +from 
dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from diff_diff.utils import safe_inference + + +@dataclass +class WooldridgeDiDResults: + """Results from WooldridgeDiD.fit(). + + Core output is ``group_time_effects``: a dict keyed by (cohort_g, time_t) + with per-cell ATT estimates and inference. Call ``.aggregate(type)`` to + compute any of the four jwdid_estat aggregation types. + """ + + # ------------------------------------------------------------------ # + # Core cohort×time estimates # + # ------------------------------------------------------------------ # + group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]] + """key=(g,t), value={att, se, t_stat, p_value, conf_int}""" + + # ------------------------------------------------------------------ # + # Simple (overall) aggregation — always populated at fit time # + # ------------------------------------------------------------------ # + overall_att: float + overall_se: float + overall_t_stat: float + overall_p_value: float + overall_conf_int: Tuple[float, float] + + # ------------------------------------------------------------------ # + # Other aggregations — populated by .aggregate() # + # ------------------------------------------------------------------ # + group_effects: Optional[Dict[Any, Dict]] = field(default=None, repr=False) + calendar_effects: Optional[Dict[Any, Dict]] = field(default=None, repr=False) + event_study_effects: Optional[Dict[int, Dict]] = field(default=None, repr=False) + + # ------------------------------------------------------------------ # + # Metadata # + # ------------------------------------------------------------------ # + method: str = "ols" + control_group: str = "not_yet_treated" + groups: List[Any] = field(default_factory=list) + time_periods: List[Any] = field(default_factory=list) + n_obs: int = 0 + n_treated_units: int = 0 + n_control_units: int = 0 + alpha: float = 0.05 + + # 
------------------------------------------------------------------ # + # Internal — used by aggregate() for delta-method SEs # + # ------------------------------------------------------------------ # + _gt_weights: Dict[Tuple[Any, Any], int] = field(default_factory=dict, repr=False) + _gt_vcov: Optional[np.ndarray] = field(default=None, repr=False) + """Delta-method covariance of the ATT(g,t) estimates (rows/columns ordered as ``_gt_keys``).""" + _gt_keys: List[Tuple[Any, Any]] = field(default_factory=list, repr=False) + """Ordered list of (g,t) keys corresponding to _gt_vcov columns.""" + + # ------------------------------------------------------------------ # + # Public methods # + # ------------------------------------------------------------------ # + + def aggregate(self, type: str) -> "WooldridgeDiDResults": # noqa: A002 + """Compute and store one of the four jwdid_estat aggregation types. + + Parameters + ---------- + type : "simple" | "group" | "calendar" | "event" + + Returns self for chaining. + """ + valid = ("simple", "group", "calendar", "event") + if type not in valid: + raise ValueError(f"type must be one of {valid}, got {type!r}") + + gt = self.group_time_effects + weights = self._gt_weights + vcov = self._gt_vcov + keys_ordered = self._gt_keys if self._gt_keys else sorted(gt.keys()) + + def _agg_se(w_vec: np.ndarray) -> float: + """Delta-method SE for the linear combination w'ATT given the cell-level ATT covariance.""" + if vcov is None or len(w_vec) != vcov.shape[0]: + return float("nan") + return float(np.sqrt(max(w_vec @ vcov @ w_vec, 0.0))) + + def _build_effect(att: float, se: float) -> Dict[str, Any]: + t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha) + return { + "att": att, + "se": se, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + } + + if type == "simple": + # Overall ATT is computed at fit time and stored in the overall_* + # fields; nothing to recompute, but keep the method callable.
+ pass + + elif type == "group": + result: Dict[Any, Dict] = {} + for g in self.groups: + cells = [(g2, t) for (g2, t) in keys_ordered if g2 == g and t >= g] + if not cells: + continue + w_total = sum(weights.get(c, 0) for c in cells) + if w_total == 0: + continue + att = sum(weights.get(c, 0) * gt[c]["att"] for c in cells) / w_total + # delta-method weights vector over all keys_ordered + w_vec = np.array( + [weights.get(c, 0) / w_total if c in cells else 0.0 for c in keys_ordered] + ) + se = _agg_se(w_vec) + result[g] = _build_effect(att, se) + self.group_effects = result + + elif type == "calendar": + result = {} + for t in self.time_periods: + cells = [(g, t2) for (g, t2) in keys_ordered if t2 == t and t >= g] + if not cells: + continue + w_total = sum(weights.get(c, 0) for c in cells) + if w_total == 0: + continue + att = sum(weights.get(c, 0) * gt[c]["att"] for c in cells) / w_total + w_vec = np.array( + [weights.get(c, 0) / w_total if c in cells else 0.0 for c in keys_ordered] + ) + se = _agg_se(w_vec) + result[t] = _build_effect(att, se) + self.calendar_effects = result + + elif type == "event": + all_k = sorted({t - g for (g, t) in keys_ordered}) + result = {} + for k in all_k: + cells = [(g, t) for (g, t) in keys_ordered if t - g == k] + if not cells: + continue + w_total = sum(weights.get(c, 0) for c in cells) + if w_total == 0: + continue + att = sum(weights.get(c, 0) * gt[c]["att"] for c in cells) / w_total + w_vec = np.array( + [weights.get(c, 0) / w_total if c in cells else 0.0 for c in keys_ordered] + ) + se = _agg_se(w_vec) + result[k] = _build_effect(att, se) + self.event_study_effects = result + + return self + + def summary(self, aggregation: str = "simple") -> str: + """Print formatted summary table. 
+ + Parameters + ---------- + aggregation : which aggregation to display ("simple", "group", "calendar", "event") + """ + lines = [ + "=" * 70, + " Wooldridge Extended Two-Way Fixed Effects (ETWFE) Results", + "=" * 70, + f"Method: {self.method}", + f"Control group: {self.control_group}", + f"Observations: {self.n_obs}", + f"Treated units: {self.n_treated_units}", + f"Control units: {self.n_control_units}", + "-" * 70, + ] + + def _fmt_row(label: str, att: float, se: float, t: float, p: float, ci: Tuple) -> str: + from diff_diff.results import _get_significance_stars # type: ignore + + stars = _get_significance_stars(p) if not np.isnan(p) else "" + ci_lo = f"{ci[0]:.4f}" if not np.isnan(ci[0]) else "NaN" + ci_hi = f"{ci[1]:.4f}" if not np.isnan(ci[1]) else "NaN" + return ( + f"{label:<22} {att:>10.4f} {se:>10.4f} {t:>8.3f} " + f"{p:>8.4f}{stars} [{ci_lo}, {ci_hi}]" + ) + + header = ( + f"{'Parameter':<22} {'Estimate':>10} {'Std. Err.':>10} " + f"{'t-stat':>8} {'P>|t|':>8} [95% CI]" + ) + lines.append(header) + lines.append("-" * 70) + + if aggregation == "simple": + lines.append( + _fmt_row( + "ATT (simple)", + self.overall_att, + self.overall_se, + self.overall_t_stat, + self.overall_p_value, + self.overall_conf_int, + ) + ) + elif aggregation == "group" and self.group_effects: + for g, eff in sorted(self.group_effects.items()): + lines.append( + _fmt_row( + f"ATT(g={g})", + eff["att"], + eff["se"], + eff["t_stat"], + eff["p_value"], + eff["conf_int"], + ) + ) + elif aggregation == "calendar" and self.calendar_effects: + for t, eff in sorted(self.calendar_effects.items()): + lines.append( + _fmt_row( + f"ATT(t={t})", + eff["att"], + eff["se"], + eff["t_stat"], + eff["p_value"], + eff["conf_int"], + ) + ) + elif aggregation == "event" and self.event_study_effects: + for k, eff in sorted(self.event_study_effects.items()): + label = f"ATT(k={k})" + (" [pre]" if k < 0 else "") + lines.append( + _fmt_row( + label, + eff["att"], + eff["se"], + eff["t_stat"], + 
eff["p_value"], + eff["conf_int"], + ) + ) + else: + lines.append(f" (call .aggregate({aggregation!r}) first)") + + lines.append("=" * 70) + return "\n".join(lines) + + def to_dataframe(self, aggregation: str = "event") -> pd.DataFrame: + """Export aggregated effects to a DataFrame. + + Parameters + ---------- + aggregation : "simple" | "group" | "calendar" | "event" | "gt" + Use "gt" to export raw group-time effects. + """ + if aggregation == "gt": + rows = [] + for (g, t), eff in sorted(self.group_time_effects.items()): + row = {"cohort": g, "time": t, "relative_period": t - g} + # Drop internal keys (e.g. _gradient arrays) and expand conf_int + # so every export path yields scalar columns + row.update( + {k: v for k, v in eff.items() if k != "conf_int" and not k.startswith("_")} + ) + row["conf_int_lo"] = eff["conf_int"][0] + row["conf_int_hi"] = eff["conf_int"][1] + rows.append(row) + return pd.DataFrame(rows) + + mapping = { + "simple": [ + { + "label": "ATT", + "att": self.overall_att, + "se": self.overall_se, + "t_stat": self.overall_t_stat, + "p_value": self.overall_p_value, + "conf_int_lo": self.overall_conf_int[0], + "conf_int_hi": self.overall_conf_int[1], + } + ], + "group": [ + { + "cohort": g, + **{k: v for k, v in eff.items() if k != "conf_int"}, + "conf_int_lo": eff["conf_int"][0], + "conf_int_hi": eff["conf_int"][1], + } + for g, eff in sorted((self.group_effects or {}).items()) + ], + "calendar": [ + { + "time": t, + **{k: v for k, v in eff.items() if k != "conf_int"}, + "conf_int_lo": eff["conf_int"][0], + "conf_int_hi": eff["conf_int"][1], + } + for t, eff in sorted((self.calendar_effects or {}).items()) + ], + "event": [ + { + "relative_period": k, + **{kk: vv for kk, vv in eff.items() if kk != "conf_int"}, + "conf_int_lo": eff["conf_int"][0], + "conf_int_hi": eff["conf_int"][1], + } + for k, eff in sorted((self.event_study_effects or {}).items()) + ], + } + rows = mapping.get(aggregation, []) + return pd.DataFrame(rows) + + def plot_event_study(self, **kwargs) -> None: + """Event study plot.
Calls aggregate('event') if needed.""" + if self.event_study_effects is None: + self.aggregate("event") + from diff_diff.visualization import plot_event_study # type: ignore + + effects = {k: v["att"] for k, v in (self.event_study_effects or {}).items()} + se = {k: v["se"] for k, v in (self.event_study_effects or {}).items()} + plot_event_study(effects=effects, se=se, alpha=self.alpha, **kwargs) + + def __repr__(self) -> str: + n_gt = len(self.group_time_effects) + att_str = f"{self.overall_att:.4f}" if not np.isnan(self.overall_att) else "NaN" + se_str = f"{self.overall_se:.4f}" if not np.isnan(self.overall_se) else "NaN" + p_str = f"{self.overall_p_value:.4f}" if not np.isnan(self.overall_p_value) else "NaN" + return ( + f"WooldridgeDiDResults(" + f"ATT={att_str}, SE={se_str}, p={p_str}, " + f"n_gt={n_gt}, method={self.method!r})" + ) diff --git a/docs/api/index.rst b/docs/api/index.rst index 4dfd7ae7..92c506e0 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -25,6 +25,7 @@ Core estimator classes for DiD analysis: diff_diff.ContinuousDiD diff_diff.EfficientDiD diff_diff.TwoStageDiD + diff_diff.WooldridgeDiD diff_diff.BaconDecomposition Results Classes @@ -57,6 +58,7 @@ Result containers returned by estimators: diff_diff.TwoStageDiDResults diff_diff.TwoStageBootstrapResults diff_diff.BaconDecompositionResults + diff_diff.wooldridge_results.WooldridgeDiDResults diff_diff.Comparison2x2 Visualization @@ -233,6 +235,7 @@ Estimators continuous_did efficient_did two_stage + wooldridge_etwfe bacon Diagnostics & Inference diff --git a/docs/api/wooldridge_etwfe.rst b/docs/api/wooldridge_etwfe.rst new file mode 100644 index 00000000..a2e643dc --- /dev/null +++ b/docs/api/wooldridge_etwfe.rst @@ -0,0 +1,169 @@ +Wooldridge Extended Two-Way Fixed Effects (ETWFE) +=================================================== + +Extended Two-Way Fixed Effects estimator from Wooldridge (2021, 2023), +equivalent to the Stata ``jwdid`` package (Rios-Avila 2021).
+ +This module implements ETWFE via a single saturated regression that: + +1. **Estimates ATT(g,t)** for each cohort×time treatment cell simultaneously +2. **Supports linear (OLS), Poisson QMLE, and logit** link functions +3. **Uses ASF-based ATT** for nonlinear models: E[f(η₁)] − E[f(η₀)] +4. **Computes delta-method SEs** for all aggregations (event, group, calendar, simple) +5. **Matches Stata jwdid** output for both OLS and nonlinear paths, up to the +   small-sample SE adjustments documented in the Methodology Registry + +**When to use WooldridgeDiD:** + +- Staggered adoption design with heterogeneous treatment timing +- Nonlinear outcomes (binary, count, non-negative continuous) +- You want a single-regression approach matching Stata's ``jwdid`` +- You need event-study, group, calendar, or simple ATT aggregations + +**References:** + +- Wooldridge, J. M. (2021). Two-Way Fixed Effects, the Two-Way Mundlak + Regression, and Difference-in-Differences Estimators. *SSRN 3906345*. +- Wooldridge, J. M. (2023). Simple approaches to nonlinear + difference-in-differences with panel data. *The Econometrics Journal*, + 26(3), C31–C66. +- Rios-Avila, F. (2021). ``jwdid``: Stata module for ETWFE. SSC s459114. + +.. module:: diff_diff.wooldridge + +WooldridgeDiD +-------------- + +Main estimator class for Wooldridge ETWFE. + +.. autoclass:: diff_diff.WooldridgeDiD + :members: + :undoc-members: + :show-inheritance: + + .. rubric:: Methods + + .. autosummary:: + + ~WooldridgeDiD.fit + ~WooldridgeDiD.get_params + ~WooldridgeDiD.set_params + +WooldridgeDiDResults +--------------------- + +Results container returned by ``WooldridgeDiD.fit()``. + +.. autoclass:: diff_diff.wooldridge_results.WooldridgeDiDResults + :members: + :undoc-members: + :show-inheritance: + + .. rubric:: Methods + + ..
autosummary:: + + ~WooldridgeDiDResults.aggregate + ~WooldridgeDiDResults.summary + +Example Usage +------------- + +Basic OLS (matches Stata ``jwdid y, ivar(unit) tvar(time) gvar(cohort)``):: + + import pandas as pd + from diff_diff import WooldridgeDiD + + df = pd.read_stata("mpdta.dta") + df['first_treat'] = df['first_treat'].astype(int) + + m = WooldridgeDiD() + r = m.fit(df, outcome='lemp', unit='countyreal', time='year', cohort='first_treat') + + r.aggregate('event').aggregate('group').aggregate('simple') + print(r.summary('event')) + print(r.summary('group')) + print(r.summary('simple')) + +View cohort×time cell estimates (post-treatment):: + + for (g, t), v in sorted(r.group_time_effects.items()): + if t >= g: + print(f"g={g} t={t} ATT={v['att']:.4f} SE={v['se']:.4f}") + +Poisson QMLE for non-negative outcomes +(matches Stata ``jwdid emp, method(poisson)``):: + + import numpy as np + df['emp'] = np.exp(df['lemp']) + + m_pois = WooldridgeDiD(method='poisson') + r_pois = m_pois.fit(df, outcome='emp', unit='countyreal', + time='year', cohort='first_treat') + r_pois.aggregate('event').aggregate('group').aggregate('simple') + print(r_pois.summary('simple')) + +Logit for binary outcomes +(matches Stata ``jwdid hi_emp, method(logit)``):: + + # Illustrative binary outcome: above-median log employment + df['hi_emp'] = (df['lemp'] > df['lemp'].median()).astype(int) + + m_logit = WooldridgeDiD(method='logit') + r_logit = m_logit.fit(df, outcome='hi_emp', unit='countyreal', + time='year', cohort='first_treat') + r_logit.aggregate('group').aggregate('simple') + print(r_logit.summary('group')) + +Aggregation Methods +------------------- + +Call ``.aggregate(type)`` before ``.summary(type)``: + +..
list-table:: + :header-rows: 1 + :widths: 15 30 25 + + * - Type + - Description + - Stata equivalent + * - ``'event'`` + - ATT by relative time k = t − g + - ``estat event`` + * - ``'group'`` + - ATT averaged across post-treatment periods per cohort + - ``estat group`` + * - ``'calendar'`` + - ATT averaged across cohorts per calendar period + - ``estat calendar`` + * - ``'simple'`` + - Overall weighted average ATT + - ``estat simple`` + +Comparison with Other Staggered Estimators +------------------------------------------ + +.. list-table:: + :header-rows: 1 + :widths: 20 27 27 26 + + * - Feature + - WooldridgeDiD (ETWFE) + - CallawaySantAnna + - ImputationDiD + * - Approach + - Single saturated regression + - Separate 2×2 DiD per cell + - Impute Y(0) via FE model + * - Nonlinear outcomes + - Yes (Poisson, Logit) + - No + - No + * - Covariates + - Via regression (linear index) + - OR, IPW, DR + - Supported + * - SE for aggregations + - Delta method + - Multiplier bootstrap + - Multiplier bootstrap + * - Stata equivalent + - ``jwdid`` + - ``csdid`` + - ``did_imputation`` diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 554acd28..0add0945 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -15,6 +15,7 @@ This document provides the academic foundations and key implementation requireme - [ImputationDiD](#imputationdid) - [TwoStageDiD](#twostagedid) - [StackedDiD](#stackeddid) + - [WooldridgeDiD (ETWFE)](#wooldridgedid-etwfe) 3. [Advanced Estimators](#advanced-estimators) - [SyntheticDiD](#syntheticdid) - [TripleDifference](#tripledifference) @@ -1072,6 +1073,106 @@ The paper text states a stricter bound (T_min + 1) but the R code by the co-auth --- +## WooldridgeDiD (ETWFE) + +**Primary source:** Wooldridge, J. M. (2025). Two-way fixed effects, the two-way Mundlak regression, and difference-in-differences estimators. *Empirical Economics*, 69(5), 2545–2587. 
(Published version of the 2021 SSRN working paper NBER WP 29154.) + +**Secondary source:** Wooldridge, J. M. (2023). Simple approaches to nonlinear difference-in-differences with panel data. *The Econometrics Journal*, 26(3), C31–C66. https://doi.org/10.1093/ectj/utad016 + +**Application reference:** Nagengast, A. J., Rios-Avila, F., & Yotov, Y. V. (2026). The European single market and intra-EU trade: an assessment with heterogeneity-robust difference-in-differences methods. *Economica*, 93(369), 298–331. + +**Reference implementation:** Stata: `jwdid` package (Rios-Avila, 2021). R: `etwfe` package (McDermott, 2023). + +**Key implementation requirements:** + +*Core estimand:* + + ATT(g, t) = E[Y_it(g) - Y_it(0) | G_i = g, T = t] for t >= g + +where `g` is cohort (first treatment period), `t` is calendar time. + +*OLS design matrix (Wooldridge 2025, Section 5):* + +The saturated ETWFE regression includes: +1. Unit fixed effects (absorbed via within-transformation or as dummies) +2. Time fixed effects (absorbed or as dummies) +3. Cohort×time treatment interactions: `I(G_i = g) * I(T = t)` for each post-treatment (g, t) cell +4. Additional covariates X_it interacted with cohort×time indicators (optional) + +The interaction coefficient `δ_{g,t}` identifies `ATT(g, t)` under parallel trends. + +*Nonlinear extensions (Wooldridge 2023):* + +For binary outcomes (logit) and count outcomes (Poisson), Wooldridge (2023) provides an +Average Structural Function (ASF) approach. For each treated cell (g, t): + + ATT(g, t) = mean_i[m(η_i + δ_{g,t}) - m(η_i)] over units i in cell (g, t) + +where `m(·)` is the inverse link function (logistic or exp; written `m` to avoid clashing with the cohort index `g`), `η_i` is the individual linear predictor +(fixed effects + controls), and `δ_{g,t}` is the interaction coefficient from the nonlinear model.
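The per-cell ASF calculation described above can be sketched in a few lines of NumPy. This is a minimal illustration for a single logit cell with made-up linear predictors; it is not the package implementation, which additionally handles rank deficiency and computes delta-method SEs:

```python
import numpy as np

def asf_att_logit(eta_treated, delta):
    """ASF-based ATT for one (g, t) cell under a logit link.

    eta_treated: fitted linear predictors for treated units in the cell,
    *including* the treatment interaction coefficient delta.
    """
    def expit(z):
        return 1.0 / (1.0 + np.exp(-z))

    eta_counterfactual = eta_treated - delta  # treatment switched off
    return float(np.mean(expit(eta_treated) - expit(eta_counterfactual)))

# Three treated units with illustrative linear predictors
att = asf_att_logit(np.array([0.5, 1.2, -0.3]), delta=0.8)  # ~0.18
```

The Poisson path follows the same pattern with `np.exp` in place of the logistic function.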
+ +*Standard errors:* +- OLS: Cluster-robust sandwich estimator at the unit level (default) +- Logit/Poisson: QMLE sandwich `(X'WX)^{-1} meat (X'WX)^{-1}` via `compute_robust_vcov(..., weights=w, weight_type="aweight")` where `w = p_i(1-p_i)` for logit or `w = μ_i` for Poisson +- Delta-method SEs for ATT(g,t) from nonlinear models: `Var(ATT) = ∇θ' Σ_β ∇θ` +- Joint delta method for overall ATT: `agg_grad = Σ_k (w_k/w_total) * ∇θ_k` +- **Deviation from R:** R's `etwfe` package uses `fixest` for nonlinear paths; this implementation uses direct QMLE via `compute_robust_vcov` to avoid a statsmodels/fixest dependency. +- **Note:** QMLE sandwich uses `weight_type="aweight"` which applies `(G/(G-1)) * ((n-1)/(n-k))` small-sample adjustment. Stata `jwdid` uses `G/(G-1)` only. The `(n-1)/(n-k)` term is conservative (inflates SEs slightly). For typical ETWFE panels where n >> k, the difference is negligible. + +*Aggregations (matching `jwdid_estat`):* +- `simple`: Weighted average across all post-treatment (g, t) cells with weights `n_{g,t}`: + + ATT_overall = Σ_{(g,t): t≥g} n_{g,t} · ATT(g,t) / Σ_{(g,t): t≥g} n_{g,t} + + Cell weight `n_{g,t}` = count of obs in cohort g at time t in estimation sample. + - **Note:** Cell-level weighting (n_{g,t} observation counts) matches Stata `jwdid_estat` behavior. Differs from W2025 Eqs. 7.2-7.4 cohort-share weights that account for the number of post-treatment periods per cohort. + +- `group`: Weighted average across t for each cohort g +- `calendar`: Weighted average across g for each calendar time t +- `event`: Weighted average across (g, t) cells by relative period k = t - g + +*Covariates:* +- `exovar`: Time-invariant covariates, added without demeaning (corresponds to W2025 Eq. 5.2 `x_i`) +- `xtvar`: Time-varying covariates, demeaned within cohort×period cells when `demean_covariates=True` (corresponds to W2025 Eq. 
10.2 `x_hat_itgs = x_it - x_bar_gs`) +- `xgvar`: Covariates interacted with each cohort indicator +- **Note:** `xtvar` demeaning operates at the cohort×period level (W2025 Eq. 10.2), not the cohort level (W2025 Eq. 5.2). These are identical for time-constant covariates but differ for time-varying covariates. + +*Control groups:* +- `not_yet_treated` (default): Control pool includes units not yet treated at time t (same as Callaway-Sant'Anna) +- `never_treated`: Control pool restricted to never-treated units only + +*Edge cases:* +- Single cohort (no staggered adoption): Reduces to standard 2×2 DiD +- Missing cohorts: Only cohorts observed in the data are included in interactions +- Anticipation: When `anticipation > 0`, interactions include periods `t >= g - anticipation` +- Never-treated control only: Pre-treatment periods still estimable as placebo ATTs +- **Note:** Poisson QMLE with cohort+time dummies (not unit dummies) is consistent even in short panels (Wooldridge 1999, JBES). The exponential mean function is unique in that incidental parameters from group dummies do not cause inconsistency. +- **Note:** Logit path uses cohort×time additive dummies (not unit dummies) to avoid incidental parameters bias — a standard limitation of logit FE in short panels. This matches Stata `jwdid method(logit)` which uses `i.gvar i.tvar`. + +*Algorithm:* +1. Identify cohorts G and time periods T from data +2. Build within-transformed design matrix (absorb unit + time FE) +3. Append cohort×time interaction columns for all post-treatment cells +4. Fit OLS/logit/Poisson +5. For nonlinear: compute ASF-based ATT(g,t) and delta-method SEs per cell +6. For OLS: extract δ_{g,t} coefficients directly as ATT(g,t) +7. Compute overall ATT as weighted average; store full vcov for aggregate SEs +8. 
Optionally run multiplier bootstrap for overall SE + +**Requirements checklist:** +- [x] Saturated cohort×time interaction design matrix +- [x] Unit + time FE absorption (within-transformation) +- [x] OLS, logit (IRLS), and Poisson (IRLS) fitting methods +- [x] Cluster-robust SEs at unit level for all methods +- [x] ASF-based ATT for nonlinear methods with delta-method SEs +- [x] Joint delta-method SE for aggregate ATT in nonlinear models +- [x] Four aggregation types: simple, group, calendar, event +- [x] Both control groups: not_yet_treated, never_treated +- [x] Anticipation parameter support +- [x] Multiplier bootstrap (Rademacher/Webb/Mammen) for OLS overall SE + +--- + # Advanced Estimators ## SyntheticDiD diff --git a/docs/tutorials/16_wooldridge_etwfe.ipynb b/docs/tutorials/16_wooldridge_etwfe.ipynb new file mode 100644 index 00000000..acbd9643 --- /dev/null +++ b/docs/tutorials/16_wooldridge_etwfe.ipynb @@ -0,0 +1,498 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a1b2c3d4", + "metadata": {}, + "source": "# Wooldridge Extended Two-Way Fixed Effects (ETWFE)\n\nThis tutorial demonstrates the `WooldridgeDiD` estimator (alias: `ETWFE`), which implements Wooldridge's (2021, 2023) Extended Two-Way Fixed Effects approach — the basis of the Stata `jwdid` package.\n\n**What ETWFE does:** Estimates cohort×time Average Treatment Effects (ATT(g,t)) via a single saturated regression that interacts treatment indicators with cohort×time cells. Unlike standard TWFE, it correctly handles heterogeneous treatment effects across cohorts and time periods. 
The key insight is to include all cohort×time interaction terms simultaneously, with additive cohort and time fixed effects.\n\n**Key features:**\n- Follows the Stata `jwdid` specification (OLS and nonlinear paths; see Methodology Registry for documented SE/aggregation deviations)\n- Supports **linear (OLS)**, **Poisson**, and **logit** link functions\n- Nonlinear ATTs use the Average Structural Function (ASF): E[f(η₁)] − E[f(η₀)]\n- Delta-method standard errors for all aggregations\n- Cluster-robust sandwich variance\n\n**Topics covered:**\n1. Basic OLS estimation\n2. Cohort×time cell estimates ATT(g,t)\n3. Aggregation: event-study, group, simple\n4. Poisson QMLE for count / non-negative outcomes\n5. Logit for binary outcomes\n6. Comparison with Callaway-Sant'Anna\n7. Parameter reference and guidance\n\n*Prerequisites: [Tutorial 02](02_staggered_did.ipynb) (Staggered DiD).*\n\n*See also: [Tutorial 15](15_efficient_did.ipynb) for Efficient DiD, [Tutorial 11](11_imputation_did.ipynb) for Imputation DiD.*" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2c3d4e5", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from diff_diff import WooldridgeDiD, CallawaySantAnna, generate_staggered_data\n", + "\n", + "try:\n", + " import matplotlib.pyplot as plt\n", + " plt.style.use('seaborn-v0_8-whitegrid')\n", + " HAS_MATPLOTLIB = True\n", + "except ImportError:\n", + " HAS_MATPLOTLIB = False\n", + " print(\"matplotlib not installed - visualization examples will be skipped\")" + ] + }, + { + "cell_type": "markdown", + "id": "c3d4e5f6", + "metadata": {}, + "source": [ + "## Data Setup\n", + "\n", + "We use `generate_staggered_data()` to create a balanced panel with 3 treatment cohorts, a never-treated group, and a known ATT of 2.0. 
This makes it easy to verify estimation accuracy.\n", + "\n", + "We also demonstrate with the **mpdta** dataset (Callaway & Sant'Anna 2021), which contains county-level log employment data with staggered minimum-wage adoption — the canonical benchmark for staggered DiD methods." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4e5f6a7", + "metadata": {}, + "outputs": [], + "source": [ + "# Simulated data\n", + "data = generate_staggered_data(\n", + " n_units=300, n_periods=10, treatment_effect=2.0,\n", + " dynamic_effects=False, seed=42\n", + ")\n", + "\n", + "print(f\"Shape: {data.shape}\")\n", + "print(f\"Cohorts: {sorted(data['first_treat'].unique())}\")\n", + "print(f\"Periods: {sorted(data['period'].unique())}\")\n", + "print()\n", + "data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "e5f6a7b8", + "metadata": {}, + "source": [ + "## Basic OLS Estimation\n", + "\n", + "The default `method='ols'` fits a single regression with:\n", + "- Treatment interaction dummies (one per treatment cohort × post-treatment period cell)\n", + "- Additive cohort fixed effects (`i.gvar` in Stata)\n", + "- Additive time fixed effects (`i.tvar` in Stata)\n", + "\n", + "This matches Stata's `jwdid y, ivar(unit) tvar(time) gvar(cohort)`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6a7b8c9", + "metadata": {}, + "outputs": [], + "source": [ + "m = WooldridgeDiD() # default: method='ols'\n", + "r = m.fit(data, outcome='outcome', unit='unit', time='period', cohort='first_treat')\n", + "\n", + "# Compute aggregations\n", + "r.aggregate('event').aggregate('group').aggregate('simple')\n", + "\n", + "print(r.summary('simple'))" + ] + }, + { + "cell_type": "markdown", + "id": "a7b8c9d0", + "metadata": {}, + "source": [ + "## Cohort×Time Cell Estimates ATT(g,t)\n", + "\n", + "The raw building blocks are ATT(g,t) — the treatment effect for cohort `g` at calendar time `t`. 
These are stored in `r.group_time_effects` and correspond to Stata's regression output table (`first_treat#year#c.__tr__`).\n", + "\n", + "Post-treatment cells have `t >= g`; pre-treatment cells (`t < g`) serve as placebo checks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8c9d0e1", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Post-treatment ATT(g,t) cells\")\n", + "print(\"{:>8} {:>8} | {:>10} {:>10} {:>7} {:>7}\".format(\n", + " \"cohort\", \"year\", \"Coef.\", \"Std.Err.\", \"t\", \"P>|t|\"))\n", + "print(\"-\" * 60)\n", + "\n", + "for (g, t), v in sorted(r.group_time_effects.items()):\n", + " if t < g:\n", + " continue\n", + " row = \"{:>8} {:>8} | {:>10.4f} {:>10.4f} {:>7.2f} {:>7.3f}\".format(\n", + " int(g), int(t), v['att'], v['se'], v['t_stat'], v['p_value']\n", + " )\n", + " print(row)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9d0e1f2", + "metadata": {}, + "outputs": [], + "source": [ + "# Also show pre-treatment placebo cells\n", + "print(\"Pre-treatment placebo ATT(g,t) cells (should be ~0 under parallel trends)\")\n", + "print(\"{:>8} {:>8} | {:>10} {:>10} {:>7} {:>7}\".format(\n", + " \"cohort\", \"year\", \"Coef.\", \"Std.Err.\", \"t\", \"P>|t|\"))\n", + "print(\"-\" * 60)\n", + "\n", + "for (g, t), v in sorted(r.group_time_effects.items()):\n", + " if t >= g:\n", + " continue\n", + " row = \"{:>8} {:>8} | {:>10.4f} {:>10.4f} {:>7.2f} {:>7.3f}\".format(\n", + " int(g), int(t), v['att'], v['se'], v['t_stat'], v['p_value']\n", + " )\n", + " print(row)" + ] + }, + { + "cell_type": "markdown", + "id": "d0e1f2a3", + "metadata": {}, + "source": [ + "## Aggregation Methods\n", + "\n", + "ETWFE supports four aggregation types, matching Stata's `estat` post-estimation commands:\n", + "\n", + "| Python | Stata | Description |\n", + "|--------|-------|-------------|\n", + "| `aggregate('event')` | `estat event` | By relative time k = t − g |\n", + "| `aggregate('group')` | `estat 
group` | By treatment cohort g |\n", + "| `aggregate('calendar')` | `estat calendar` | By calendar time t |\n", + "| `aggregate('simple')` | `estat simple` | Overall weighted average ATT |\n", + "\n", + "Standard errors use the delta method, propagating uncertainty from the cell-level ATT covariance matrix." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1f2a3b4", + "metadata": {}, + "outputs": [], + "source": [ + "# Event-study aggregation: ATT by relative time k = t - g\n", + "print(r.summary('event'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2a3b4c5", + "metadata": {}, + "outputs": [], + "source": [ + "# Group aggregation: ATT averaged across post-treatment periods for each cohort\n", + "print(r.summary('group'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3b4c5d6", + "metadata": {}, + "outputs": [], + "source": [ + "# Simple ATT: overall weighted average\n", + "print(r.summary('simple'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4c5d6e7", + "metadata": {}, + "outputs": [], + "source": [ + "# Event study plot\n", + "if HAS_MATPLOTLIB:\n", + " es = r.event_study_effects\n", + " ks = sorted(es.keys())\n", + " atts = [es[k]['att'] for k in ks]\n", + " lo = [es[k]['conf_int'][0] for k in ks]\n", + " hi = [es[k]['conf_int'][1] for k in ks]\n", + "\n", + " fig, ax = plt.subplots(figsize=(9, 5))\n", + " ax.errorbar(ks, atts, yerr=[np.array(atts) - np.array(lo), np.array(hi) - np.array(atts)],\n", + " fmt='o-', capsize=4, color='steelblue', label='ETWFE (OLS)')\n", + " ax.axhline(0, color='black', linestyle='--', linewidth=0.8)\n", + " ax.axvline(-0.5, color='red', linestyle=':', linewidth=0.8, label='Treatment onset')\n", + " ax.set_xlabel('Relative period (k = t − g)')\n", + " ax.set_ylabel('ATT')\n", + " ax.set_title('ETWFE Event Study')\n", + " ax.legend()\n", + " plt.tight_layout()\n", + " plt.show()\n", + "else:\n", + " print(\"Install 
matplotlib to see the event study plot: pip install matplotlib\")" + ] + }, + { + "cell_type": "markdown", + "id": "c5d6e7f8", + "metadata": {}, + "source": [ + "## Poisson QMLE for Count / Non-Negative Outcomes\n", + "\n", + "`method='poisson'` fits a Poisson QMLE regression. This is valid for any non-negative continuous outcome, not just count data — the Poisson log-likelihood produces consistent estimates whenever the conditional mean is correctly specified as exp(Xβ).\n", + "\n", + "The ATT is computed as the **Average Structural Function (ASF) difference**:\n", + "\n", + "$$\\text{ATT}(g,t) = \\frac{1}{N_{g,t}} \\sum_{i \\in g,t} \\left[\\exp(\\eta_{i,1}) - \\exp(\\eta_{i,0})\\right]$$\n", + "\n", + "where η₁ = Xβ (with treatment) and η₀ = Xβ − δ (counterfactual without treatment).\n", + "\n", + "This matches Stata's `jwdid y, method(poisson)`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6e7f8a9", + "metadata": {}, + "outputs": [], + "source": [ + "# Simulate a non-negative outcome (e.g., employment level)\n", + "data_pois = data.copy()\n", + "data_pois['emp'] = np.exp(data_pois['outcome'] / 4 + 3) # positive outcome\n", + "\n", + "m_pois = WooldridgeDiD(method='poisson')\n", + "r_pois = m_pois.fit(data_pois, outcome='emp', unit='unit', time='period', cohort='first_treat')\n", + "r_pois.aggregate('event').aggregate('group').aggregate('simple')\n", + "\n", + "print(r_pois.summary('simple'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7f8a9b0", + "metadata": {}, + "outputs": [], + "source": [ + "# Cohort×time cells (post-treatment, Poisson)\n", + "print(\"Poisson ATT(g,t) — post-treatment cells\")\n", + "print(\"{:>8} {:>8} | {:>10} {:>10} {:>7} {:>7}\".format(\n", + " \"cohort\", \"year\", \"ATT\", \"Std.Err.\", \"t\", \"P>|t|\"))\n", + "print(\"-\" * 60)\n", + "\n", + "for (g, t), v in sorted(r_pois.group_time_effects.items()):\n", + " if t < g:\n", + " continue\n", + " print(\"{:>8} {:>8} | {:>10.4f} 
{:>10.4f} {:>7.2f} {:>7.3f}\".format(\n", + " int(g), int(t), v['att'], v['se'], v['t_stat'], v['p_value']\n", + " ))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8a9b0c1", + "metadata": {}, + "outputs": [], + "source": [ + "print(r_pois.summary('event'))\n", + "print(r_pois.summary('group'))" + ] + }, + { + "cell_type": "markdown", + "id": "a9b0c1d2", + "metadata": {}, + "source": [ + "## Logit for Binary Outcomes\n", + "\n", + "`method='logit'` fits a logit model and computes ATT as the ASF probability difference:\n", + "\n", + "$$\\text{ATT}(g,t) = \\frac{1}{N_{g,t}} \\sum_{i \\in g,t} \\left[\\Lambda(\\eta_{i,1}) - \\Lambda(\\eta_{i,0})\\right]$$\n", + "\n", + "where Λ(·) is the logistic function. Standard errors use the delta method.\n", + "\n", + "This matches Stata's `jwdid y, method(logit)`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0c1d2e3", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a binary outcome\n", + "data_logit = data.copy()\n", + "median_val = data_logit.loc[data_logit['period'] == data_logit['period'].min(), 'outcome'].median()\n", + "data_logit['hi_outcome'] = (data_logit['outcome'] > median_val).astype(int)\n", + "\n", + "print(f\"Binary outcome mean: {data_logit['hi_outcome'].mean():.3f}\")\n", + "\n", + "m_logit = WooldridgeDiD(method='logit')\n", + "r_logit = m_logit.fit(data_logit, outcome='hi_outcome', unit='unit', time='period', cohort='first_treat')\n", + "r_logit.aggregate('event').aggregate('group').aggregate('simple')\n", + "\n", + "print(r_logit.summary('simple'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1d2e3f4", + "metadata": {}, + "outputs": [], + "source": [ + "print(r_logit.summary('group'))" + ] + }, + { + "cell_type": "markdown", + "id": "d2e3f4a5", + "metadata": {}, + "source": "## mpdta: Real-World Example\n\nThe **mpdta** dataset (Callaway & Sant'Anna 2021) contains county-level log employment (`lemp`) data with 
staggered minimum-wage adoption (`first_treat` = year of treatment, 0 = never treated). It is the canonical benchmark for staggered DiD methods.\n\nThis follows Stata's `jwdid lemp, ivar(countyreal) tvar(year) gvar(first_treat)` specification. See the Methodology Registry for documented SE/aggregation deviations." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3f4a5b6", + "metadata": {}, + "outputs": [], + "source": "from diff_diff import load_mpdta\n\nmpdta = load_mpdta()\nprint(f\"mpdta loaded: {mpdta.shape}\")\nprint(f\"Cohorts: {sorted(mpdta['first_treat'].unique())}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4a5b6c7", + "metadata": {}, + "outputs": [], + "source": "# OLS — matches: jwdid lemp, ivar(countyreal) tvar(year) gvar(first_treat)\nm_ols = WooldridgeDiD(method='ols')\nr_ols = m_ols.fit(mpdta, outcome='lemp', unit='countyreal', time='year', cohort='first_treat')\nr_ols.aggregate('event').aggregate('group').aggregate('simple')\nprint(r_ols.summary('event'))" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5b6c7d8", + "metadata": {}, + "outputs": [], + "source": "# cohort x time ATT cells (post-treatment)\n# Matches Stata: first_treat#year#c.__tr__ output table\nprint(\"ATT(g,t) — post-treatment cells (matches Stata jwdid output)\")\nprint(\"{:>6} {:>6} | {:>9} {:>9} {:>7} {:>7}\".format(\n \"cohort\", \"year\", \"Coef.\", \"Std.Err.\", \"t\", \"P>|t|\"))\nprint(\"-\" * 55)\nfor (g, t), v in sorted(r_ols.group_time_effects.items()):\n if t < g:\n continue\n print(\"{:>6} {:>6} | {:>9.4f} {:>9.4f} {:>7.2f} {:>7.3f}\".format(\n g, t, v['att'], v['se'], v['t_stat'], v['p_value']))" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6c7d8e9", + "metadata": {}, + "outputs": [], + "source": "# Poisson — matches: gen emp=exp(lemp) / jwdid emp, method(poisson)\nmpdta['emp'] = np.exp(mpdta['lemp'])\n\nm_pois2 = WooldridgeDiD(method='poisson')\nr_pois2 = m_pois2.fit(mpdta, 
outcome='emp', unit='countyreal', time='year', cohort='first_treat')\nr_pois2.aggregate('event').aggregate('group').aggregate('simple')\n\nprint(r_pois2.summary('event'))\nprint(r_pois2.summary('group'))\nprint(r_pois2.summary('simple'))" + }, + { + "cell_type": "markdown", + "id": "c7d8e9f0", + "metadata": {}, + "source": [ + "## Comparison with Callaway-Sant'Anna\n", + "\n", + "ETWFE and Callaway-Sant'Anna are both valid for staggered designs. Under homogeneous treatment effects and additive parallel trends, they should produce similar ATT(g,t) point estimates. Key differences:\n", + "\n", + "| Aspect | WooldridgeDiD (ETWFE) | CallawaySantAnna |\n", + "|--------|----------------------|------------------|\n", + "| Approach | Single saturated regression | Separate 2×2 DiD per cell |\n", + "| Nonlinear outcomes | Yes (Poisson, Logit) | No |\n", + "| Covariates | Via regression (linear index) | OR, IPW, DR |\n", + "| SE for aggregations | Delta method | Multiplier bootstrap |\n", + "| Stata equivalent | `jwdid` | `csdid` |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8e9f0a1", + "metadata": {}, + "outputs": [], + "source": [ + "# Compare overall ATT: ETWFE vs Callaway-Sant'Anna\n", + "cs = CallawaySantAnna()\n", + "r_cs = cs.fit(data, outcome='outcome', unit='unit', time='period', first_treat='first_treat')\n", + "\n", + "m_etwfe = WooldridgeDiD(method='ols')\n", + "r_etwfe = m_etwfe.fit(data, outcome='outcome', unit='unit', time='period', cohort='first_treat')\n", + "r_etwfe.aggregate('simple')\n", + "\n", + "print(\"Overall ATT Comparison (true effect = 2.0)\")\n", + "print(\"=\" * 60)\n", + "print(\"{:<25} {:>10} {:>10} {:>12}\".format(\"Estimator\", \"ATT\", \"SE\", \"95% CI\"))\n", + "print(\"-\" * 60)\n", + "\n", + "for name, est_r in [(\"WooldridgeDiD (ETWFE)\", r_etwfe), (\"CallawaySantAnna\", r_cs)]:\n", + " ci = est_r.overall_conf_int\n", + " print(\"{:<25} {:>10.4f} {:>10.4f} [{:.3f}, {:.3f}]\".format(\n", + " name, 
est_r.overall_att, est_r.overall_se, ci[0], ci[1]\n", + " ))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9f0a1b2", + "metadata": {}, + "outputs": [], + "source": [ + "# Event-study comparison\n", + "r_cs_es = CallawaySantAnna().fit(\n", + " data, outcome='outcome', unit='unit', time='period',\n", + " first_treat='first_treat', aggregate='event_study'\n", + ")\n", + "\n", + "if HAS_MATPLOTLIB:\n", + " es_etwfe = r_etwfe.event_study_effects\n", + " es_cs = {int(row['relative_period']): row\n", + " for _, row in r_cs_es.to_dataframe(level='event_study').iterrows()}\n", + "\n", + " ks = sorted(es_etwfe.keys())\n", + "\n", + " fig, ax = plt.subplots(figsize=(10, 5))\n", + " offset = 0.1\n", + "\n", + " atts_e = [es_etwfe[k]['att'] for k in ks]\n", + " lo_e = [es_etwfe[k]['conf_int'][0] for k in ks]\n", + " hi_e = [es_etwfe[k]['conf_int'][1] for k in ks]\n", + " ax.errorbar([k - offset for k in ks], atts_e,\n", + " yerr=[np.array(atts_e) - np.array(lo_e), np.array(hi_e) - np.array(atts_e)],\n", + " fmt='o-', capsize=4, color='steelblue', label='ETWFE')\n", + "\n", + " ks_cs = sorted(es_cs.keys())\n", + " atts_cs = [es_cs[k]['effect'] for k in ks_cs]\n", + " lo_cs = [es_cs[k]['conf_int_lower'] for k in ks_cs]\n", + " hi_cs = [es_cs[k]['conf_int_upper'] for k in ks_cs]\n", + " ax.errorbar([k + offset for k in ks_cs], atts_cs,\n", + " yerr=[np.array(atts_cs) - np.array(lo_cs), np.array(hi_cs) - np.array(atts_cs)],\n", + " fmt='s--', capsize=4, color='darkorange', label='Callaway-Sant\\'Anna')\n", + "\n", + " ax.axhline(0, color='black', linestyle='--', linewidth=0.8)\n", + " ax.axvline(-0.5, color='red', linestyle=':', linewidth=0.8)\n", + " ax.set_xlabel('Relative period (k = t − g)')\n", + " ax.set_ylabel('ATT')\n", + " ax.set_title('Event Study: ETWFE vs Callaway-Sant\\'Anna')\n", + " ax.legend()\n", + " plt.tight_layout()\n", + " plt.show()\n", + "else:\n", + " print(\"Install matplotlib to see the comparison plot: pip install 
matplotlib\")" + ] + }, + { + "cell_type": "markdown", + "id": "f0a1b2c3", + "metadata": {}, + "source": "## Summary\n\n**Key takeaways:**\n\n1. **ETWFE via a single regression**: all ATT(g,t) cells estimated jointly, not separately — computationally efficient and internally consistent\n2. **OLS path** follows the Stata `jwdid` specification: additive cohort + time FEs, treatment interaction dummies\n3. **Nonlinear paths** (Poisson, Logit) use the ASF formula: E[f(η₁)] − E[f(η₀)] — the appropriate ATT definition for nonlinear models\n4. **Four aggregations** mirror Stata's `estat` commands: event, group, calendar, simple\n5. **Delta-method SEs** for all aggregations, including nonlinear paths\n6. **When to prefer ETWFE**: nonlinear outcomes, or when a single-regression framework is preferred\n7. **When to prefer CS/ImputationDiD**: covariate adjustment via IPW/DR, or multiplier bootstrap inference\n\n**Parameter reference:**\n\n| Parameter | Default | Description |\n|-----------|---------|-------------|\n| `method` | `'ols'` | `'ols'`, `'poisson'`, or `'logit'` |\n| `control_group` | `'not_yet_treated'` | `'not_yet_treated'` or `'never_treated'` |\n| `anticipation` | `0` | Anticipation periods before treatment |\n| `alpha` | `0.05` | Significance level |\n| `cluster` | `None` | Column for clustering (default: unit variable) |\n\n**References:**\n- Wooldridge, J. M. (2021). Two-Way Fixed Effects, the Two-Way Mundlak Regression, and Difference-in-Differences Estimators. *SSRN 3906345*.\n- Wooldridge, J. M. (2023). Simple approaches to nonlinear difference-in-differences with panel data. *The Econometrics Journal*, 26(3), C31–C66.\n- Rios-Avila, F. (2021). `jwdid`: Stata module for ETWFE. 
SSC s459114.\n\n*See also: [Tutorial 02](02_staggered_did.ipynb) for Callaway-Sant'Anna, [Tutorial 15](15_efficient_did.ipynb) for Efficient DiD.*" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/docs/tutorials/README.md b/docs/tutorials/README.md index 0b44c8ea..82ff81ec 100644 --- a/docs/tutorials/README.md +++ b/docs/tutorials/README.md @@ -51,7 +51,16 @@ Efficient Difference-in-Differences (Chen, Sant'Anna & Xie 2025): - Event study and group-level aggregation - Bootstrap inference and diagnostics -### 16. Survey-Aware DiD (`16_survey_did.ipynb`) +### 16. Wooldridge ETWFE (`16_wooldridge_etwfe.ipynb`) +Wooldridge Extended Two-Way Fixed Effects (ETWFE) for staggered DiD: +- Basic OLS estimation with cohort x time ATT cells +- Aggregation methods: event-study, group, calendar, simple +- Poisson QMLE for count / non-negative outcomes +- Logit for binary outcomes +- Comparison with Callaway-Sant'Anna +- Delta-method standard errors + +### Survey-Aware DiD (`16_survey_did.ipynb`) Survey-aware DiD with complex sampling designs (strata, PSU, FPC, weights): - Why survey design matters for DiD inference - Setting up `SurveyDesign` (weights, strata, PSU, FPC) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 0bde37d9..cc49c113 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -10,6 +10,7 @@ compute_r_squared, compute_robust_vcov, solve_ols, + solve_poisson, ) @@ -1856,3 +1857,46 @@ def test_solve_ols_no_runtime_warnings(self): f"{[str(x.message) for x in runtime_warnings]}" ) assert np.allclose(coefficients, beta_true, atol=0.1) + + +class TestSolvePoisson: + def test_basic_convergence(self): + """solve_poisson converges on simple count data.""" + rng = np.random.default_rng(42) + n = 200 + X = 
np.column_stack([np.ones(n), rng.standard_normal((n, 2))]) + true_beta = np.array([0.5, 0.3, -0.2]) + mu = np.exp(X @ true_beta) + y = rng.poisson(mu).astype(float) + beta, W = solve_poisson(X, y) + assert beta.shape == (3,) + assert W.shape == (n,) + assert np.allclose(beta, true_beta, atol=0.15) + + def test_returns_weights(self): + """solve_poisson returns final mu weights for vcov computation.""" + rng = np.random.default_rng(0) + n = 100 + X = np.column_stack([np.ones(n), rng.standard_normal(n)]) + y = rng.poisson(2.0, size=n).astype(float) + beta, W = solve_poisson(X, y) + assert (W > 0).all() + + def test_non_negative_output(self): + """Fitted mu = exp(Xb) should be strictly positive.""" + rng = np.random.default_rng(1) + n = 50 + X = np.column_stack([np.ones(n), rng.standard_normal(n)]) + y = rng.poisson(1.0, size=n).astype(float) + beta, W = solve_poisson(X, y) + mu_hat = np.exp(X @ beta) + assert (mu_hat > 0).all() + + def test_no_intercept_prepended(self): + """solve_poisson does NOT add intercept (caller's responsibility).""" + rng = np.random.default_rng(2) + n = 80 + X = np.column_stack([np.ones(n), rng.standard_normal(n)]) + y = rng.poisson(1.5, size=n).astype(float) + beta, _ = solve_poisson(X, y) + assert len(beta) == 2 # not 3 diff --git a/tests/test_wooldridge.py b/tests/test_wooldridge.py new file mode 100644 index 00000000..3be4a5d4 --- /dev/null +++ b/tests/test_wooldridge.py @@ -0,0 +1,779 @@ +"""Tests for WooldridgeDiD estimator and WooldridgeDiDResults.""" + +import numpy as np +import pandas as pd +import pytest + +from diff_diff.wooldridge import ( + WooldridgeDiD, + _build_interaction_matrix, + _filter_sample, + _prepare_covariates, +) +from diff_diff.wooldridge_results import WooldridgeDiDResults + + +def _make_minimal_results(**kwargs): + """Helper: build a WooldridgeDiDResults with required fields.""" + defaults = dict( + group_time_effects={ + (2, 2): { + "att": 1.0, + "se": 0.5, + "t_stat": 2.0, + "p_value": 0.04, + "conf_int": 
(0.02, 1.98), + }, + (2, 3): { + "att": 1.5, + "se": 0.6, + "t_stat": 2.5, + "p_value": 0.01, + "conf_int": (0.32, 2.68), + }, + (3, 3): { + "att": 0.8, + "se": 0.4, + "t_stat": 2.0, + "p_value": 0.04, + "conf_int": (0.02, 1.58), + }, + }, + overall_att=1.1, + overall_se=0.35, + overall_t_stat=3.14, + overall_p_value=0.002, + overall_conf_int=(0.41, 1.79), + group_effects=None, + calendar_effects=None, + event_study_effects=None, + method="ols", + control_group="not_yet_treated", + groups=[2, 3], + time_periods=[1, 2, 3], + n_obs=300, + n_treated_units=100, + n_control_units=200, + alpha=0.05, + _gt_weights={(2, 2): 50, (2, 3): 50, (3, 3): 30}, + _gt_vcov=None, + ) + defaults.update(kwargs) + return WooldridgeDiDResults(**defaults) + + +class TestWooldridgeDiDResults: + def test_repr(self): + r = _make_minimal_results() + s = repr(r) + assert "WooldridgeDiDResults" in s + assert "ATT" in s + + def test_summary_default(self): + r = _make_minimal_results() + s = r.summary() + assert "1.1" in s or "ATT" in s + + def test_to_dataframe_event(self): + r = _make_minimal_results() + r.aggregate("event") + df = r.to_dataframe("event") + assert isinstance(df, pd.DataFrame) + assert "att" in df.columns + + def test_aggregate_simple_returns_self(self): + r = _make_minimal_results() + result = r.aggregate("simple") + assert result is r # chaining + + def test_aggregate_group(self): + r = _make_minimal_results() + r.aggregate("group") + assert r.group_effects is not None + assert 2 in r.group_effects + assert 3 in r.group_effects + + def test_aggregate_calendar(self): + r = _make_minimal_results() + r.aggregate("calendar") + assert r.calendar_effects is not None + assert 2 in r.calendar_effects or 3 in r.calendar_effects + + def test_aggregate_event(self): + r = _make_minimal_results() + r.aggregate("event") + assert r.event_study_effects is not None + # relative period 0 (treatment period itself) should be present + assert 0 in r.event_study_effects or 1 in 
r.event_study_effects + + def test_aggregate_invalid_raises(self): + r = _make_minimal_results() + with pytest.raises(ValueError, match="type"): + r.aggregate("bad_type") + + +class TestWooldridgeDiDAPI: + def test_default_construction(self): + est = WooldridgeDiD() + assert est.method == "ols" + assert est.control_group == "not_yet_treated" + assert est.anticipation == 0 + assert est.demean_covariates is True + assert est.alpha == 0.05 + assert est.cluster is None + assert est.n_bootstrap == 0 + assert est.bootstrap_weights == "rademacher" + assert est.seed is None + assert est.rank_deficient_action == "warn" + assert not est.is_fitted_ + + def test_invalid_method_raises(self): + with pytest.raises(ValueError, match="method"): + WooldridgeDiD(method="probit") + + def test_invalid_control_group_raises(self): + with pytest.raises(ValueError, match="control_group"): + WooldridgeDiD(control_group="clean_control") + + def test_invalid_anticipation_raises(self): + with pytest.raises(ValueError, match="anticipation"): + WooldridgeDiD(anticipation=-1) + + def test_get_params_roundtrip(self): + est = WooldridgeDiD(method="logit", alpha=0.1, anticipation=1) + params = est.get_params() + assert params["method"] == "logit" + assert params["alpha"] == 0.1 + assert params["anticipation"] == 1 + + def test_set_params_roundtrip(self): + est = WooldridgeDiD() + est.set_params(alpha=0.01, n_bootstrap=100) + assert est.alpha == 0.01 + assert est.n_bootstrap == 100 + + def test_set_params_returns_self(self): + est = WooldridgeDiD() + result = est.set_params(alpha=0.1) + assert result is est + + def test_set_params_unknown_raises(self): + est = WooldridgeDiD() + with pytest.raises(ValueError, match="Unknown"): + est.set_params(nonexistent_param=42) + + def test_results_before_fit_raises(self): + est = WooldridgeDiD() + with pytest.raises(RuntimeError, match="fit"): + _ = est.results_ + + +def _make_panel(n_units=10, n_periods=5, treat_share=0.5, seed=0): + """Create a simple balanced 
panel for testing.""" + rng = np.random.default_rng(seed) + units = np.arange(n_units) + n_treated = int(n_units * treat_share) + # Two cohorts: half treated in period 3, rest never treated + cohort = np.array([3] * n_treated + [0] * (n_units - n_treated)) + rows = [] + for u in units: + for t in range(1, n_periods + 1): + rows.append( + { + "unit": u, + "time": t, + "cohort": cohort[u], + "y": rng.standard_normal(), + "x1": rng.standard_normal(), + } + ) + return pd.DataFrame(rows) + + +class TestDataPrep: + def test_filter_sample_not_yet_treated(self): + df = _make_panel() + filtered = _filter_sample( + df, + unit="unit", + time="time", + cohort="cohort", + control_group="not_yet_treated", + anticipation=0, + ) + # All treated units should be present (all periods) + treated_units = df[df["cohort"] == 3]["unit"].unique() + assert set(treated_units).issubset(filtered["unit"].unique()) + + def test_filter_sample_never_treated(self): + df = _make_panel() + filtered = _filter_sample( + df, + unit="unit", + time="time", + cohort="cohort", + control_group="never_treated", + anticipation=0, + ) + # Only never-treated (cohort==0) and treated units should remain + assert (filtered["cohort"].isin([0, 3])).all() + + def test_build_interaction_matrix_columns(self): + df = _make_panel() + filtered = _filter_sample(df, "unit", "time", "cohort", "not_yet_treated", anticipation=0) + X_int, col_names, gt_keys = _build_interaction_matrix( + filtered, cohort="cohort", time="time", anticipation=0 + ) + # Each column should be a valid (g, t) pair with t >= g + for g, t in gt_keys: + assert t >= g + + def test_build_interaction_matrix_binary(self): + df = _make_panel() + filtered = _filter_sample(df, "unit", "time", "cohort", "not_yet_treated", anticipation=0) + X_int, col_names, gt_keys = _build_interaction_matrix( + filtered, cohort="cohort", time="time", anticipation=0 + ) + # All values should be 0 or 1 + assert set(np.unique(X_int)).issubset({0, 1}) + + def 
test_prepare_covariates_exovar(self): + df = _make_panel() + X_cov = _prepare_covariates( + df, + exovar=["x1"], + xtvar=None, + xgvar=None, + cohort="cohort", + time="time", + demean_covariates=True, + groups=[3], + ) + assert X_cov.shape[0] == len(df) + assert X_cov.shape[1] == 1 # just x1 + + def test_prepare_covariates_xtvar_demeaned(self): + df = _make_panel() + X_raw = _prepare_covariates( + df, + exovar=None, + xtvar=["x1"], + xgvar=None, + cohort="cohort", + time="time", + demean_covariates=False, + groups=[3], + ) + X_dem = _prepare_covariates( + df, + exovar=None, + xtvar=["x1"], + xgvar=None, + cohort="cohort", + time="time", + demean_covariates=True, + groups=[3], + ) + # Demeaned version should differ from raw + assert not np.allclose(X_raw, X_dem) + + +class TestWooldridgeDiDFitOLS: + @pytest.fixture + def mpdta(self): + from diff_diff.datasets import load_mpdta + + return load_mpdta() + + def test_fit_returns_results(self, mpdta): + est = WooldridgeDiD() + results = est.fit( + mpdta, outcome="lemp", unit="countyreal", time="year", cohort="first_treat" + ) + assert isinstance(results, WooldridgeDiDResults) + + def test_fit_sets_is_fitted(self, mpdta): + est = WooldridgeDiD() + est.fit(mpdta, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + assert est.is_fitted_ + + def test_overall_att_finite(self, mpdta): + est = WooldridgeDiD() + r = est.fit(mpdta, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + assert np.isfinite(r.overall_att) + assert np.isfinite(r.overall_se) + assert r.overall_se > 0 + + def test_group_time_effects_populated(self, mpdta): + est = WooldridgeDiD() + r = est.fit(mpdta, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + assert len(r.group_time_effects) > 0 + for (g, t), eff in r.group_time_effects.items(): + assert t >= g + assert "att" in eff and "se" in eff + + def test_all_inference_fields_finite(self, mpdta): + """No inference field should be NaN in normal 
data.""" + est = WooldridgeDiD() + r = est.fit(mpdta, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + assert np.isfinite(r.overall_t_stat) + assert np.isfinite(r.overall_p_value) + assert all(np.isfinite(c) for c in r.overall_conf_int) + + def test_never_treated_control_group(self, mpdta): + est = WooldridgeDiD(control_group="never_treated") + r = est.fit(mpdta, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + assert len(r.group_time_effects) > 0 + + def test_metadata_correct(self, mpdta): + est = WooldridgeDiD() + r = est.fit(mpdta, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + assert r.method == "ols" + assert r.n_obs > 0 + assert r.n_treated_units > 0 + assert r.n_control_units > 0 + + +class TestAggregations: + @pytest.fixture + def fitted(self): + from diff_diff.datasets import load_mpdta + + df = load_mpdta() + est = WooldridgeDiD() + return est.fit(df, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + + def test_simple_matches_manual_weighted_average(self, fitted): + """simple ATT must equal manually computed weighted average of ATT(g,t).""" + gt = fitted.group_time_effects + w = fitted._gt_weights + post_keys = [(g, t) for (g, t) in w if t >= g] + w_total = sum(w[k] for k in post_keys) + manual_att = sum(w[k] * gt[k]["att"] for k in post_keys) / w_total + assert abs(fitted.overall_att - manual_att) < 1e-10 + + def test_aggregate_group_keys_match_cohorts(self, fitted): + fitted.aggregate("group") + assert set(fitted.group_effects.keys()) == set(fitted.groups) + + def test_aggregate_event_relative_periods(self, fitted): + fitted.aggregate("event") + for k in fitted.event_study_effects: + assert isinstance(k, (int, np.integer)) + + def test_aggregate_calendar_finite(self, fitted): + fitted.aggregate("calendar") + for t, eff in fitted.calendar_effects.items(): + assert np.isfinite(eff["att"]) + + def test_summary_runs(self, fitted): + s = fitted.summary("simple") + 
assert "ETWFE" in s or "Wooldridge" in s + + def test_to_dataframe_event(self, fitted): + fitted.aggregate("event") + df = fitted.to_dataframe("event") + assert "relative_period" in df.columns + assert "att" in df.columns + + def test_to_dataframe_gt(self, fitted): + df = fitted.to_dataframe("gt") + assert "cohort" in df.columns + assert "time" in df.columns + assert len(df) == len(fitted.group_time_effects) + + +class TestWooldridgeDiDLogit: + @pytest.fixture + def binary_panel(self): + """Simulated binary outcome panel with known positive ATT.""" + rng = np.random.default_rng(42) + n_units, n_periods = 60, 5 + rows = [] + for u in range(n_units): + cohort = 3 if u < 30 else 0 + for t in range(1, n_periods + 1): + treated = int(cohort > 0 and t >= cohort) + eta = -0.5 + 1.0 * treated + 0.1 * rng.standard_normal() + y = int(rng.random() < 1 / (1 + np.exp(-eta))) + rows.append({"unit": u, "time": t, "cohort": cohort, "y": y}) + return pd.DataFrame(rows) + + def test_logit_fit_runs(self, binary_panel): + est = WooldridgeDiD(method="logit") + r = est.fit(binary_panel, outcome="y", unit="unit", time="time", cohort="cohort") + assert isinstance(r, WooldridgeDiDResults) + + def test_logit_att_sign(self, binary_panel): + """ATT should be positive (treatment increases binary outcome).""" + est = WooldridgeDiD(method="logit") + r = est.fit(binary_panel, outcome="y", unit="unit", time="time", cohort="cohort") + assert r.overall_att > 0 + + def test_logit_se_positive(self, binary_panel): + est = WooldridgeDiD(method="logit") + r = est.fit(binary_panel, outcome="y", unit="unit", time="time", cohort="cohort") + assert r.overall_se > 0 + + def test_logit_method_stored(self, binary_panel): + est = WooldridgeDiD(method="logit") + r = est.fit(binary_panel, outcome="y", unit="unit", time="time", cohort="cohort") + assert r.method == "logit" + + +class TestWooldridgeDiDPoisson: + @pytest.fixture + def count_panel(self): + rng = np.random.default_rng(7) + n_units, n_periods = 60, 5 + 
rows = [] + for u in range(n_units): + cohort = 3 if u < 30 else 0 + for t in range(1, n_periods + 1): + treated = int(cohort > 0 and t >= cohort) + mu = np.exp(0.5 + 0.8 * treated + 0.1 * rng.standard_normal()) + y = rng.poisson(mu) + rows.append({"unit": u, "time": t, "cohort": cohort, "y": float(y)}) + return pd.DataFrame(rows) + + def test_poisson_fit_runs(self, count_panel): + est = WooldridgeDiD(method="poisson") + r = est.fit(count_panel, outcome="y", unit="unit", time="time", cohort="cohort") + assert isinstance(r, WooldridgeDiDResults) + + def test_poisson_att_sign(self, count_panel): + est = WooldridgeDiD(method="poisson") + r = est.fit(count_panel, outcome="y", unit="unit", time="time", cohort="cohort") + assert r.overall_att > 0 + + def test_poisson_se_positive(self, count_panel): + est = WooldridgeDiD(method="poisson") + r = est.fit(count_panel, outcome="y", unit="unit", time="time", cohort="cohort") + assert r.overall_se > 0 + + +class TestBootstrap: + @pytest.mark.slow + def test_multiplier_bootstrap_ols(self, ci_params): + """Bootstrap path runs and yields an SE of plausible magnitude relative to the ATT.""" + from diff_diff.datasets import load_mpdta + + df = load_mpdta() + n_boot = ci_params.bootstrap(50, min_n=19) + est = WooldridgeDiD(n_bootstrap=n_boot, seed=42) + r = est.fit(df, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + assert abs(r.overall_se - r.overall_att) / max(abs(r.overall_att), 1e-8) < 10 + + def test_bootstrap_zero_disables(self): + from diff_diff.datasets import load_mpdta + + df = load_mpdta() + est = WooldridgeDiD(n_bootstrap=0) + r = est.fit(df, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + assert np.isfinite(r.overall_se) + + +class TestMethodologyCorrectness: + def test_ols_att_finite_never_treated(self): + """Overall ATT on mpdta is finite with never_treated controls.""" + from diff_diff.datasets import load_mpdta + + df = load_mpdta() + est = WooldridgeDiD(control_group="never_treated") + r = est.fit(df, 
outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + assert np.isfinite(r.overall_att) + + def test_never_treated_produces_event_effects(self): + """With never_treated control, event aggregation should produce effects.""" + from diff_diff.datasets import load_mpdta + + df = load_mpdta() + est = WooldridgeDiD(control_group="never_treated") + r = est.fit(df, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + r.aggregate("event") + assert r.event_study_effects is not None + assert len(r.event_study_effects) > 0 + assert all(k >= 0 for k in r.event_study_effects.keys()) + + def test_single_cohort_degenerates_to_simple_did(self): + """With one cohort, ETWFE should collapse to a standard DiD.""" + rng = np.random.default_rng(0) + n = 100 + rows = [] + for u in range(n): + cohort = 2 if u < 50 else 0 + for t in [1, 2]: + treated = int(cohort > 0 and t >= cohort) + y = 1.0 * treated + rng.standard_normal() + rows.append({"unit": u, "time": t, "cohort": cohort, "y": y}) + df = pd.DataFrame(rows) + r = WooldridgeDiD().fit(df, outcome="y", unit="unit", time="time", cohort="cohort") + assert len(r.group_time_effects) == 1 + assert abs(r.overall_att - 1.0) < 0.5 + + def test_aggregation_weights_sum_to_one(self): + """Simple aggregation weights should sum to 1.""" + from diff_diff.datasets import load_mpdta + + df = load_mpdta() + r = WooldridgeDiD().fit( + df, outcome="lemp", unit="countyreal", time="year", cohort="first_treat" + ) + w = r._gt_weights + post_keys = [(g, t) for (g, t) in w if t >= g] + w_total = sum(w[k] for k in post_keys) + norm_weights = [w[k] / w_total for k in post_keys] + assert abs(sum(norm_weights) - 1.0) < 1e-10 + + def test_logit_delta_gradient_matches_finite_difference(self): + """Analytic delta-method gradient is finite and produces non-negative SE.""" + from diff_diff.datasets import load_mpdta + + df = load_mpdta().copy() + df["lemp_bin"] = (df["lemp"] > df["lemp"].median()).astype(int) + + est = 
WooldridgeDiD(method="logit") + results = est.fit( + df, outcome="lemp_bin", unit="countyreal", time="year", cohort="first_treat" + ) + + grad_found = False + for key, cell in results.group_time_effects.items(): + if "_gradient" not in cell: + continue + grad_found = True + analytic_grad = cell["_gradient"] + assert np.all(np.isfinite(analytic_grad)), f"Non-finite gradient at {key}" + assert cell["se"] >= 0, f"Negative SE at {key}" + assert grad_found, "No _gradient entries found in group_time_effects" + + def test_poisson_delta_gradient_finite_check(self): + """Poisson gradient entries are finite and produce non-negative SE.""" + from diff_diff.datasets import load_mpdta + + df = load_mpdta().copy() + df["emp_count"] = np.exp(df["lemp"]).round().astype(int) + + est = WooldridgeDiD(method="poisson") + results = est.fit( + df, outcome="emp_count", unit="countyreal", time="year", cohort="first_treat" + ) + + grad_found = False + for key, cell in results.group_time_effects.items(): + if "_gradient" not in cell: + continue + grad_found = True + assert np.all(np.isfinite(cell["_gradient"])), f"Non-finite gradient at {key}" + assert cell["se"] >= 0 + assert grad_found, "No _gradient entries found in group_time_effects" + + def test_ols_etwfe_att_matches_callaway_santanna(self): + """OLS ETWFE ATT(g,t) equals CallawaySantAnna ATT(g,t) (Proposition 3.1).""" + from diff_diff import CallawaySantAnna + from diff_diff.datasets import load_mpdta + + mpdta = load_mpdta() + + etwfe = WooldridgeDiD(method="ols", control_group="not_yet_treated") + cs = CallawaySantAnna(control_group="not_yet_treated") + + er = etwfe.fit(mpdta, outcome="lemp", unit="countyreal", time="year", cohort="first_treat") + cr = cs.fit( + mpdta, outcome="lemp", unit="countyreal", time="year", first_treat="first_treat" + ) + + matched = 0 + for key, effect in er.group_time_effects.items(): + if key in cr.group_time_effects: + cs_att = cr.group_time_effects[key]["effect"] + np.testing.assert_allclose( + 
effect["att"], + cs_att, + atol=5e-3, + err_msg=f"ATT mismatch at {key}: ETWFE={effect['att']:.4f}, CS={cs_att:.4f}", + ) + matched += 1 + assert matched > 0, "No matching (g,t) keys found between ETWFE and CS" + + +class TestExports: + def test_top_level_import(self): + from diff_diff import ETWFE, WooldridgeDiD + + assert ETWFE is WooldridgeDiD + + def test_alias_etwfe(self): + import diff_diff + + assert hasattr(diff_diff, "ETWFE") + assert diff_diff.ETWFE is diff_diff.WooldridgeDiD + + +class TestAnticipation: + def test_anticipation_includes_pre_treatment_cells(self): + """With anticipation=1, cells include t >= g-1 (one period before treatment).""" + rng = np.random.default_rng(42) + rows = [] + for u in range(40): + cohort = 3 if u < 20 else 0 + for t in range(1, 6): + y = rng.standard_normal() + (1.0 if cohort > 0 and t >= cohort else 0) + rows.append({"unit": u, "time": t, "cohort": cohort, "y": y}) + df = pd.DataFrame(rows) + est = WooldridgeDiD(anticipation=1) + r = est.fit(df, outcome="y", unit="unit", time="time", cohort="cohort") + # With anticipation=1, should have cells for t >= g-1 = 2 + keys = list(r.group_time_effects.keys()) + min_t = min(t for (g, t) in keys) + assert min_t == 2, f"Expected min t=2 with anticipation=1, got {min_t}" + + +class TestXgvarCovariates: + def test_xgvar_fit_runs(self): + """xgvar covariates should not crash and should produce finite results.""" + rng = np.random.default_rng(0) + rows = [] + for u in range(60): + cohort = 3 if u < 30 else 0 + x1 = rng.standard_normal() + for t in range(1, 6): + y = rng.standard_normal() + 0.5 * x1 + rows.append({"unit": u, "time": t, "cohort": cohort, "y": y, "x1": x1}) + df = pd.DataFrame(rows) + est = WooldridgeDiD() + r = est.fit(df, outcome="y", unit="unit", time="time", cohort="cohort", xgvar=["x1"]) + assert np.isfinite(r.overall_att) + assert np.isfinite(r.overall_se) + assert r.overall_se > 0 + + +class TestAllEventuallyTreated: + def 
test_no_never_treated_not_yet_treated_control(self): + """All units eventually treated, using not_yet_treated control group. + + With not_yet_treated control, the latest cohort's post-treatment + cells are rank-deficient (no controls remain). The estimator drops + those columns, so we check that at least the earlier cohort cells + produce finite ATT effects and the overall ATT is computed from them. + """ + rng = np.random.default_rng(7) + rows = [] + for u in range(200): + # Three cohorts: t=3, t=5, t=8 — wide gaps give plenty of + # not-yet-treated controls for the earlier cohorts. + if u < 70: + cohort = 3 + elif u < 140: + cohort = 5 + else: + cohort = 8 + for t in range(1, 10): + treated = int(t >= cohort) + y = rng.standard_normal() + 1.5 * treated + rows.append({"unit": u, "time": t, "cohort": cohort, "y": y}) + df = pd.DataFrame(rows) + est = WooldridgeDiD(control_group="not_yet_treated") + r = est.fit(df, outcome="y", unit="unit", time="time", cohort="cohort") + # At least some cells should have finite ATT + finite_cells = [k for k, v in r.group_time_effects.items() if np.isfinite(v["att"])] + assert len(finite_cells) > 0 + # Early cohort (g=3) should have identifiable effects + early_finite = [k for k in finite_cells if k[0] == 3] + assert len(early_finite) > 0 + for k in early_finite: + assert r.group_time_effects[k]["att"] > 0 # true effect is 1.5 + + +class TestEmptyCells: + def test_sparse_panel_no_crash(self): + """Panel where some cohort-time cells have few/no obs should not crash.""" + rng = np.random.default_rng(3) + rows = [] + for u in range(80): + cohort = 3 if u < 20 else (5 if u < 40 else 0) + for t in range(1, 7): + y = rng.standard_normal() + (1.0 if cohort > 0 and t >= cohort else 0) + rows.append({"unit": u, "time": t, "cohort": cohort, "y": y}) + df = pd.DataFrame(rows) + est = WooldridgeDiD() + r = est.fit(df, outcome="y", unit="unit", time="time", cohort="cohort") + assert np.isfinite(r.overall_att) + r.aggregate("event") + assert 
r.event_study_effects is not None + + +class TestMpdtaLogitPoisson: + @pytest.fixture + def mpdta(self): + from diff_diff.datasets import load_mpdta + + return load_mpdta() + + def test_logit_on_mpdta(self, mpdta): + """Logit fit on binary outcome derived from mpdta should produce finite results.""" + df = mpdta.copy() + df["lemp_bin"] = (df["lemp"] > df["lemp"].median()).astype(int) + est = WooldridgeDiD(method="logit") + r = est.fit(df, outcome="lemp_bin", unit="countyreal", time="year", cohort="first_treat") + assert np.isfinite(r.overall_att) + assert np.isfinite(r.overall_se) + assert r.overall_se > 0 + r.aggregate("event") + assert r.event_study_effects is not None + + def test_poisson_on_mpdta(self, mpdta): + """Poisson fit on exp(lemp) should produce finite results.""" + df = mpdta.copy() + df["emp"] = np.exp(df["lemp"]) + est = WooldridgeDiD(method="poisson") + r = est.fit(df, outcome="emp", unit="countyreal", time="year", cohort="first_treat") + assert np.isfinite(r.overall_att) + assert np.isfinite(r.overall_se) + assert r.overall_se > 0 + r.aggregate("simple") + assert np.isfinite(r.overall_att) + + +class TestControlGroupDistinction: + """P0 regression test: never_treated and not_yet_treated must differ.""" + + def test_never_treated_differs_from_not_yet_treated(self): + """On multi-cohort data with never-treated group, the two control + group settings must produce different overall ATT estimates.""" + from diff_diff.datasets import load_mpdta + + df = load_mpdta() + r_nyt = WooldridgeDiD(control_group="not_yet_treated").fit( + df, outcome="lemp", unit="countyreal", time="year", cohort="first_treat" + ) + r_nt = WooldridgeDiD(control_group="never_treated").fit( + df, outcome="lemp", unit="countyreal", time="year", cohort="first_treat" + ) + assert np.isfinite(r_nyt.overall_att) + assert np.isfinite(r_nt.overall_att) + # They must differ — if they don't, control_group is a no-op + assert r_nyt.overall_att != r_nt.overall_att, ( + f"never_treated ATT 
({r_nt.overall_att:.6f}) == not_yet_treated ATT " + f"({r_nyt.overall_att:.6f}); control_group has no effect" + ) + + def test_never_treated_fewer_observations(self): + """never_treated should produce a smaller estimation sample than + not_yet_treated (pre-treatment obs from treated units excluded).""" + from diff_diff.datasets import load_mpdta + + df = load_mpdta() + r_nyt = WooldridgeDiD(control_group="not_yet_treated").fit( + df, outcome="lemp", unit="countyreal", time="year", cohort="first_treat" + ) + r_nt = WooldridgeDiD(control_group="never_treated").fit( + df, outcome="lemp", unit="countyreal", time="year", cohort="first_treat" + ) + assert r_nt.n_obs < r_nyt.n_obs + + +class TestBootstrapWeightsValidation: + def test_invalid_bootstrap_weights_raises(self): + with pytest.raises(ValueError, match="bootstrap_weights"): + WooldridgeDiD(bootstrap_weights="invalid_dist")
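The logit and Poisson fixtures above share one data-generating pattern: a 60-unit, 5-period staggered panel where the first half of units adopts treatment at t = 3 and the rest never adopt (cohort 0). For readers replicating these tests, here is a standalone sketch of that generator using only numpy/pandas; `simulate_binary_panel` is an illustrative name, not part of the library.

```python
import numpy as np
import pandas as pd


def simulate_binary_panel(n_units=60, n_periods=5, att=1.0, seed=42):
    """Staggered-adoption panel with a binary outcome and known positive ATT.

    Half the units adopt at t=3 (cohort=3); the rest are never treated,
    encoded as cohort=0 — the same convention the fixtures above use.
    """
    rng = np.random.default_rng(seed)
    rows = []
    for u in range(n_units):
        cohort = 3 if u < n_units // 2 else 0
        for t in range(1, n_periods + 1):
            treated = int(cohort > 0 and t >= cohort)
            # Latent index with treatment shift, passed through a logistic link
            eta = -0.5 + att * treated + 0.1 * rng.standard_normal()
            y = int(rng.random() < 1 / (1 + np.exp(-eta)))
            rows.append({"unit": u, "time": t, "cohort": cohort, "y": y})
    return pd.DataFrame(rows)


df = simulate_binary_panel()
is_post = (df["cohort"] > 0) & (df["time"] >= df["cohort"])
post_rate = df.loc[is_post, "y"].mean()    # treated cells, higher success rate
other_rate = df.loc[~is_post, "y"].mean()  # untreated cells
print(len(df), post_rate > other_rate)
```

With `att=1.0` the success probability moves from roughly sigmoid(-0.5) ≈ 0.38 in untreated cells to roughly sigmoid(0.5) ≈ 0.62 in treated cells, which is why the sign tests above can assert a strictly positive ATT.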