From 1a7ad392121593a926f2cad000282636c13f9a48 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 11 May 2026 14:38:44 +0200 Subject: [PATCH 1/2] Add selected loss.py targets to unified calibration --- changelog.d/945.added | 1 + .../calibration/target_config.yaml | 68 ++++++ .../db/create_field_valid_values.py | 1 + policyengine_us_data/db/etl_irs_soi.py | 214 ++++++++++++++++- .../db/etl_national_targets.py | 197 +++++++++++++++- tests/unit/calibration/test_target_config.py | 91 +++++++ tests/unit/test_etl_irs_soi_overlay.py | 223 ++++++++++++++++++ tests/unit/test_etl_national_targets.py | 110 +++++++++ 8 files changed, 902 insertions(+), 3 deletions(-) create mode 100644 changelog.d/945.added diff --git a/changelog.d/945.added b/changelog.d/945.added new file mode 100644 index 000000000..ca604daa2 --- /dev/null +++ b/changelog.d/945.added @@ -0,0 +1 @@ +Added selected legacy `loss.py` target families to unified calibration target ETL and selection. diff --git a/policyengine_us_data/calibration/target_config.yaml b/policyengine_us_data/calibration/target_config.yaml index 50aa05547..837c90d8f 100644 --- a/policyengine_us_data/calibration/target_config.yaml +++ b/policyengine_us_data/calibration/target_config.yaml @@ -48,15 +48,23 @@ include: # REMOVED: is_pregnant — 100% unachievable across all 51 state geos - variable: snap geo_level: state + - variable: household_count + geo_level: state + domain_variable: snap - variable: tanf geo_level: state - variable: adjusted_gross_income geo_level: state + - variable: rent + geo_level: state - variable: spm_unit_count geo_level: state domain_variable: tanf - variable: eitc geo_level: state + - variable: tax_unit_count + geo_level: state + domain_variable: eitc - variable: refundable_ctc geo_level: state - variable: non_refundable_ctc @@ -107,6 +115,42 @@ include: - variable: adjusted_gross_income geo_level: national domain_variable: adjusted_gross_income,filing_status,income_tax_before_credits + - variable: tax_unit_count + geo_level: national + domain_variable: adjusted_gross_income,income_tax_before_credits,irs_employment_income + - variable: irs_employment_income + geo_level: national + domain_variable: adjusted_gross_income,income_tax_before_credits,irs_employment_income + - variable: tax_unit_count + geo_level: national + domain_variable: adjusted_gross_income,filing_status,income_tax_before_credits,irs_employment_income + - variable: irs_employment_income + geo_level: national + domain_variable: adjusted_gross_income,filing_status,income_tax_before_credits,irs_employment_income + - variable: tax_unit_count + geo_level: national + domain_variable: adjusted_gross_income,income_tax_before_credits,pension_income + - variable: pension_income + geo_level: national + domain_variable: adjusted_gross_income,income_tax_before_credits,pension_income + - variable: tax_unit_count + geo_level: national + domain_variable: adjusted_gross_income,filing_status,income_tax_before_credits,pension_income + - variable: pension_income + geo_level: national + domain_variable: adjusted_gross_income,filing_status,income_tax_before_credits,pension_income + - variable: tax_unit_count + geo_level: national + domain_variable: adjusted_gross_income,income_tax_before_credits,social_security + - variable: social_security + geo_level: national + domain_variable: adjusted_gross_income,income_tax_before_credits,social_security + - variable: tax_unit_count + geo_level: national + domain_variable: adjusted_gross_income,filing_status,income_tax_before_credits,social_security + - variable: social_security + geo_level: national + domain_variable: adjusted_gross_income,filing_status,income_tax_before_credits,social_security # === NATIONAL — wealth target (Federal Reserve SCF, no filer filter) === - variable: net_worth @@ -126,6 +170,9 @@ include: domain_variable: medicare_enrolled - variable: medicare_part_b_premium geo_level: national + - variable: medicare_part_b_premium + geo_level: national + domain_variable: age - variable: real_estate_taxes geo_level: national - variable: rent @@ -149,6 +196,9 @@ include: - variable: spm_unit_count geo_level: national domain_variable: tanf + - variable: household_count + geo_level: national + domain_variable: spm_unit_energy_subsidy_reported - variable: tip_income geo_level: national - variable: unemployment_compensation @@ -280,6 +330,12 @@ include: - variable: unemployment_compensation geo_level: national domain_variable: unemployment_compensation + - variable: refundable_american_opportunity_credit + geo_level: national + domain_variable: refundable_american_opportunity_credit + - variable: education_tax_credits + geo_level: national + domain_variable: education_tax_credits # === NATIONAL — IRS SOI filer count targets (restored: |rel_err| < 10%) === - variable: tax_unit_count @@ -311,6 +367,18 @@ include: - variable: tax_unit_count geo_level: national domain_variable: total_self_employment_income + - variable: tax_unit_count + geo_level: national + domain_variable: refundable_american_opportunity_credit + - variable: tax_unit_count + geo_level: national + domain_variable: education_tax_credits + - variable: tax_unit_count + geo_level: national + domain_variable: real_estate_taxes,tax_unit_itemizes + - variable: tax_unit_count + geo_level: state + domain_variable: real_estate_taxes,tax_unit_itemizes # === NATIONAL — identity / population count targets from old loss.py === - variable: person_count diff --git a/policyengine_us_data/db/create_field_valid_values.py b/policyengine_us_data/db/create_field_valid_values.py index 6795132bc..649408e52 100644 --- a/policyengine_us_data/db/create_field_valid_values.py +++ b/policyengine_us_data/db/create_field_valid_values.py @@ -70,6 +70,7 @@ def populate_field_valid_values(session: Session) -> None: source_values = [ ("source", "Census ACS S0101", "survey"), ("source", "IRS SOI", "administrative"), + ("source", "IRS EITC Central", "administrative"), ("source", "CMS Marketplace", "administrative"), ("source", "CMS 2024 OEP state metal status PUF", "administrative"), ("source", "CMS Medicaid", "administrative"), diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index e89658c0e..c17b6aeb0 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -202,6 +202,11 @@ def _skip_coarse_state_agi_person_count_target(geo_type: str, agi_stub: int) -> "adjusted_gross_income": "adjusted_gross_income", "count": "tax_unit_count", } +SOI_TAXABLE_AGI_DOMAIN_TARGET_VARIABLES = { + "employment_income": "irs_employment_income", + "total_pension_income": "pension_income", + "total_social_security": "social_security", +} SOI_FILING_STATUS_CONSTRAINTS = { "Single": ("==", "SINGLE"), "Head of Household": ("==", "HEAD_OF_HOUSEHOLD"), @@ -819,6 +824,79 @@ def _get_or_create_national_taxable_agi_filing_status_stratum( return stratum +def _get_or_create_national_taxable_agi_domain_filing_status_stratum( + session: Session, + national_filer_stratum_id: int, + *, + domain_variable: str, + agi_lower_bound: float, + agi_upper_bound: float, + filing_status: str, +) -> Stratum: + note = ( + "National taxable filers, AGI >= " + f"{agi_lower_bound}, AGI < {agi_upper_bound}, {domain_variable} > 0" + ) + filing_constraint = SOI_FILING_STATUS_CONSTRAINTS.get(filing_status) + if filing_constraint is not None: + note += f", filing status = {filing_status}" + + stratum = session.exec( + select(Stratum).where( + Stratum.parent_stratum_id == national_filer_stratum_id, + Stratum.notes == note, + ) + ).first() + if stratum: + return stratum + + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="income_tax_before_credits", + operation=">", + value="0", + ), + StratumConstraint( + constraint_variable="adjusted_gross_income", + operation=">=", + value=str(agi_lower_bound), + ), + StratumConstraint( + constraint_variable="adjusted_gross_income", + operation="<", + value=str(agi_upper_bound), + ), + StratumConstraint( + constraint_variable=domain_variable, + operation=">", + value="0", + ), + ] + if filing_constraint is not None: + operation, value = filing_constraint + constraints.append( + StratumConstraint( + constraint_variable="filing_status", + operation=operation, + value=value, + ) + ) + + stratum = Stratum( + parent_stratum_id=national_filer_stratum_id, + notes=note, + ) + stratum.constraints_rel.extend(constraints) + session.add(stratum) + session.flush() + return stratum + + def load_national_geography_ctc_targets( session: Session, national_filer_stratum_id: int, geography_year: int ) -> None: @@ -1002,6 +1080,46 @@ def load_national_taxable_agi_filing_status_targets( ) +def load_national_taxable_agi_domain_filing_status_targets( + session: Session, + national_filer_stratum_id: int, + target_year: int, +) -> None: + """Create positive-domain SOI income targets by AGI band and filing status.""" + soi = select_best_tracked_soi_rows(load_tracked_soi_targets(), target_year) + rows = soi[ + soi["Variable"].isin(SOI_TAXABLE_AGI_DOMAIN_TARGET_VARIABLES) + & (soi["Taxable only"]) + & (soi["AGI upper bound"] > 10_000) + ].copy() + + for _, row in rows.iterrows(): + source_variable = row["Variable"] + target_variable = SOI_TAXABLE_AGI_DOMAIN_TARGET_VARIABLES[source_variable] + stratum = _get_or_create_national_taxable_agi_domain_filing_status_stratum( + session, + national_filer_stratum_id, + domain_variable=target_variable, + agi_lower_bound=float(row["AGI lower bound"]), + agi_upper_bound=float(row["AGI upper bound"]), + filing_status=row["Filing status"], + ) + notes = ( + f"Publication 1304 {row['SOI table']} taxable AGI/filing-status " + f"{source_variable} target " + f"(source year {int(row['Year'])}, row {int(row['XLSX row'])})" + ) + _upsert_target( + session, + stratum_id=stratum.stratum_id, + variable="tax_unit_count" if bool(row["Count"]) else target_variable, + period=int(row["Year"]), + value=float(row["Value"]), + source="IRS SOI", + notes=notes, + ) + + def load_national_workbook_soi_targets( session: Session, national_filer_stratum_id: int, target_year: int ) -> None: @@ -1051,6 +1169,78 @@ def load_national_workbook_soi_targets( ) +def load_state_eitc_claim_count_targets( + session: Session, + filer_strata: dict, + target_year: int, +) -> None: + """Create state EITC claimant-count targets from IRS EITC Central controls.""" + path = CALIBRATION_FOLDER / "eitc_claim_controls.csv" + if not path.exists(): + return + + controls = pd.read_csv(path, comment="#") + years = sorted(int(year) for year in controls["year"].unique()) + prior_years = [year for year in years if year <= int(target_year)] + data_year = max(prior_years) if prior_years else min(years) + state_rows = controls[ + (controls["year"].astype(int) == data_year) + & controls["GEO_ID"].str.startswith("0400000US") + ].copy() + + for row in state_rows.itertuples(index=False): + state_fips = int(str(row.GEO_ID)[-2:]) + parent_stratum_id = filer_strata["state"].get(state_fips) + if parent_stratum_id is None: + continue + + note = f"State FIPS {state_fips} EITC claimants" + stratum = session.exec( + select(Stratum).where( + Stratum.parent_stratum_id == parent_stratum_id, + Stratum.notes == note, + ) + ).first() + if not stratum: + stratum = Stratum( + parent_stratum_id=parent_stratum_id, + notes=note, + ) + stratum.constraints_rel.extend( + [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(state_fips), + ), + StratumConstraint( + constraint_variable="eitc", + operation=">", + value="0", + ), + ] + ) + session.add(stratum) + session.flush() + + _upsert_target( + session, + stratum_id=stratum.stratum_id, + variable="tax_unit_count", + period=data_year, + value=float(row.Returns), + source="IRS EITC Central", + notes=( + f"IRS EITC Central state EITC return count (source year {data_year})" + ), + ) + + def extract_state_fine_agi_data(year: int) -> pd.DataFrame: """Download the state-level SOI file (in55cmcsv) with stubs 9 and 10.""" year_prefix = _year_prefix(year) @@ -1399,7 +1589,12 @@ def transform_soi_data(raw_df): return converted -def load_soi_data(long_dfs, year, national_year: Optional[int] = None): +def load_soi_data( + long_dfs, + year, + national_year: Optional[int] = None, + target_year: Optional[int] = None, +): """Load a list of databases into the db, critically dependent on order""" DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" @@ -1519,10 +1714,20 @@ def load_soi_data(long_dfs, year, national_year: Optional[int] = None): filer_strata["national"], national_year, ) + load_national_taxable_agi_domain_filing_status_targets( + session, + filer_strata["national"], + national_year, + ) load_national_fine_agi_targets(session, filer_strata["national"], national_year) load_national_ltcg_agi_targets(session, filer_strata["national"], national_year) load_state_fine_agi_targets(session, filer_strata, year) + load_state_eitc_claim_count_targets( + session, + filer_strata, + target_year or national_year or year, + ) session.commit() # Load EITC data -------------------------------------------------------- @@ -2027,7 +2232,12 @@ def add_lag_arg(parser): long_dfs = transform_soi_data(raw_df) # Load --------------------- - load_soi_data(long_dfs, geography_year, national_year=national_year) + load_soi_data( + long_dfs, + geography_year, + national_year=national_year, + target_year=dataset_year, + ) if __name__ == "__main__": diff --git a/policyengine_us_data/db/etl_national_targets.py b/policyengine_us_data/db/etl_national_targets.py index f5a817ddb..6c42d6c85 100644 --- a/policyengine_us_data/db/etl_national_targets.py +++ b/policyengine_us_data/db/etl_national_targets.py @@ -3,7 +3,7 @@ from sqlmodel import Session, create_engine, select import pandas as pd -from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.storage import CALIBRATION_FOLDER, STORAGE_FOLDER from policyengine_us_data.db.create_database_tables import ( Stratum, StratumConstraint, @@ -23,11 +23,196 @@ from policyengine_us_data.utils.db import ( DEFAULT_YEAR, etl_argparser, + get_geographic_strata, ) WIC_NATIONAL_ANNUAL_SUMMARY_SOURCE = ( "https://www.fns.usda.gov/sites/default/files/resource-files/wisummary-4.xlsx" ) +MEDICARE_PART_B_AGE_TARGET_YEAR = 2024 + + +def _best_available_yeared_csv(stem: str, requested_year: int): + paths_by_year = {} + for path in CALIBRATION_FOLDER.glob(f"{stem}_*.csv"): + try: + year = int(path.stem.removeprefix(f"{stem}_")) + except ValueError: + continue + paths_by_year[year] = path + + if not paths_by_year: + return None, None + + years = sorted(paths_by_year) + prior_years = [year for year in years if year <= int(requested_year)] + year = max(prior_years) if prior_years else years[0] + return paths_by_year[year], year + + +def _upsert_baseline_target( + session: Session, + *, + stratum_id: int, + variable: str, + period: int, + value: float, + source: str, + notes: str, +) -> None: + existing_target = session.exec( + select(Target).where( + Target.stratum_id == stratum_id, + Target.variable == variable, + Target.period == period, + Target.reform_id == 0, + ) + ).first() + if existing_target: + existing_target.value = value + existing_target.source = source + existing_target.notes = notes + existing_target.active = True + return + + session.add( + Target( + stratum_id=stratum_id, + variable=variable, + period=period, + value=value, + active=True, + source=source, + notes=notes, + ) + ) + + +def extract_state_acs_housing_cost_targets(year: int = DEFAULT_YEAR): + """Load the best available state ACS housing-cost target file.""" + path, data_year = _best_available_yeared_csv("acs_housing_costs", year) + if path is None: + return pd.DataFrame(), None + + targets = pd.read_csv(path, dtype={"state_fips": str}) + return targets, data_year + + +def load_state_acs_rent_targets(targets: pd.DataFrame, year: int) -> None: + """Load state aggregate contract rent targets from ACS housing-cost data.""" + if targets.empty: + return + + database_url = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + engine = create_engine(database_url) + + with Session(engine) as session: + geo_strata = get_geographic_strata(session) + for row in targets.itertuples(index=False): + state_fips = int(str(row.state_fips)) + stratum_id = geo_strata["state"].get(state_fips) + if stratum_id is None: + continue + + _upsert_baseline_target( + session, + stratum_id=stratum_id, + variable="rent", + period=int(year), + value=float(row.annual_contract_rent), + source="PolicyEngine", + notes=( + "Census ACS state aggregate contract rent, annualized from " + "monthly ACS aggregate contract rent | Source: Census ACS " + f"{year} 1-year table B25060" + ), + ) + + session.commit() + + +def extract_medicare_part_b_age_targets() -> pd.DataFrame: + """Load Medicare Part B premium age-bucket targets.""" + path = CALIBRATION_FOLDER / "healthcare_spending.csv" + if not path.exists(): + return pd.DataFrame() + + targets = pd.read_csv(path) + return targets.loc[:, ~targets.columns.duplicated()].copy() + + +def load_medicare_part_b_age_targets(targets: pd.DataFrame) -> None: + """Load national Medicare Part B premium targets by 10-year age bucket.""" + if targets.empty: + return + + database_url = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + engine = create_engine(database_url) + + targets = targets.copy() + targets["age_10_year_lower_bound"] = targets["age_10_year_lower_bound"].astype(int) + top_age_lower_bound = int(targets["age_10_year_lower_bound"].max()) + + with Session(engine) as session: + us_stratum = session.exec( + select(Stratum).where(Stratum.parent_stratum_id.is_(None)) + ).first() + if not us_stratum: + raise ValueError( + "National stratum not found. Run create_initial_strata.py first." + ) + + for _, row in targets.iterrows(): + age_lower_bound = int(row["age_10_year_lower_bound"]) + is_top_bucket = age_lower_bound == top_age_lower_bound + if is_top_bucket: + note = f"National people age {age_lower_bound}+" + else: + note = f"National people age {age_lower_bound}-{age_lower_bound + 9}" + + stratum = session.exec( + select(Stratum).where( + Stratum.parent_stratum_id == us_stratum.stratum_id, + Stratum.notes == note, + ) + ).first() + if not stratum: + stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + notes=note, + ) + stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="age", + operation=">=", + value=str(age_lower_bound), + ) + ) + if not is_top_bucket: + stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="age", + operation="<", + value=str(age_lower_bound + 10), + ) + ) + session.add(stratum) + session.flush() + + _upsert_baseline_target( + session, + stratum_id=stratum.stratum_id, + variable="medicare_part_b_premium", + period=MEDICARE_PART_B_AGE_TARGET_YEAR, + value=float(row["medicare_part_b_premiums"]), + source="PolicyEngine", + notes=( + "Legacy healthcare_spending.csv Medicare Part B premium " + "age-bucket target" + ), + ) + + session.commit() def extract_national_targets(year: int = DEFAULT_YEAR): @@ -867,6 +1052,16 @@ def main(): tax_expenditure_df, conditional_targets, ) + state_acs_targets, state_acs_year = extract_state_acs_housing_cost_targets( + year=time_period + ) + if state_acs_year is not None: + print("Loading state ACS rent targets...") + load_state_acs_rent_targets(state_acs_targets, state_acs_year) + + medicare_part_b_age_targets = extract_medicare_part_b_age_targets() + print("Loading Medicare Part B age-bucket targets...") + load_medicare_part_b_age_targets(medicare_part_b_age_targets) print("\nETL pipeline complete!") diff --git a/tests/unit/calibration/test_target_config.py b/tests/unit/calibration/test_target_config.py index c48af09ea..aed916a95 100644 --- a/tests/unit/calibration/test_target_config.py +++ b/tests/unit/calibration/test_target_config.py @@ -471,6 +471,97 @@ def test_training_config_includes_wic_national_targets(self): "domain_variable": "wic", } in include_rules + def test_training_config_includes_ported_loss_py_target_families(self): + config = load_target_config( + str( + Path(__file__).resolve().parents[3] + / "policyengine_us_data" + / "calibration" + / "target_config.yaml" + ) + ) + + include_rules = config["include"] + assert { + "variable": "household_count", + "geo_level": "state", + "domain_variable": "snap", + } in include_rules + assert { + "variable": "household_count", + "geo_level": "national", + "domain_variable": "spm_unit_energy_subsidy_reported", + } in include_rules + assert {"variable": "rent", "geo_level": "state"} in include_rules + assert { + "variable": "tax_unit_count", + "geo_level": "state", + "domain_variable": "eitc", + } in include_rules + assert { + "variable": "tax_unit_count", + "geo_level": "state", + "domain_variable": "real_estate_taxes,tax_unit_itemizes", + } in include_rules + assert { + "variable": "medicare_part_b_premium", + "geo_level": "national", + "domain_variable": "age", + } in include_rules + assert { + "variable": "refundable_american_opportunity_credit", + "geo_level": "national", + "domain_variable": "refundable_american_opportunity_credit", + } in include_rules + assert { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": "education_tax_credits", + } in include_rules + + def test_training_config_includes_soi_income_agi_grid_targets(self): + config = load_target_config( + str( + Path(__file__).resolve().parents[3] + / "policyengine_us_data" + / "calibration" + / "target_config.yaml" + ) + ) + + include_rules = config["include"] + variables = [ + "irs_employment_income", + "pension_income", + "social_security", + ] + for variable in variables: + all_domain = f"adjusted_gross_income,income_tax_before_credits,{variable}" + filing_status_domain = ( + "adjusted_gross_income,filing_status," + f"income_tax_before_credits,{variable}" + ) + assert { + "variable": variable, + "geo_level": "national", + "domain_variable": all_domain, + } in include_rules + assert { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": all_domain, + } in include_rules + assert { + "variable": variable, + "geo_level": "national", + "domain_variable": filing_status_domain, + } in include_rules + assert { + "variable": "tax_unit_count", + "geo_level": "national", + "domain_variable": filing_status_domain, + } in include_rules + class TestCalibrationPackageRoundTrip: def test_round_trip(self, sample_targets, tmp_path): diff --git a/tests/unit/test_etl_irs_soi_overlay.py b/tests/unit/test_etl_irs_soi_overlay.py index cb9c610fd..b0f12586d 100644 --- a/tests/unit/test_etl_irs_soi_overlay.py +++ b/tests/unit/test_etl_irs_soi_overlay.py @@ -26,8 +26,10 @@ load_national_geography_ctc_agi_targets, load_national_geography_ctc_targets, load_national_ltcg_agi_targets, + load_national_taxable_agi_domain_filing_status_targets, load_national_taxable_agi_filing_status_targets, load_national_workbook_soi_targets, + load_state_eitc_claim_count_targets, ) @@ -860,3 +862,224 @@ def test_load_national_taxable_agi_filing_status_targets_creates_structured_rows assert ("income_tax_before_credits", ">", "0") in count_constraints assert ("adjusted_gross_income", ">=", "20000.0") in count_constraints assert ("adjusted_gross_income", "<", "25000.0") in count_constraints + + +def test_load_national_taxable_agi_domain_filing_status_targets_creates_structured_rows( + monkeypatch, tmp_path +): + db_uri, engine = _create_test_engine(tmp_path) + soi_rows = pd.DataFrame( + [ + { + "Year": 2023, + "SOI table": "Table 1.4", + "XLSX column": "G", + "XLSX row": 19, + "Variable": "employment_income", + "Filing status": "All", + "AGI lower bound": 50_000.0, + "AGI upper bound": 75_000.0, + "Count": False, + "Taxable only": True, + "Full population": False, + "Value": 1_000_000.0, + }, + { + "Year": 2023, + "SOI table": "Table 1.4", + "XLSX column": "F", + "XLSX row": 19, + "Variable": "employment_income", + "Filing status": "All", + "AGI lower bound": 50_000.0, + "AGI upper bound": 75_000.0, + "Count": True, + "Taxable only": True, + "Full population": False, + "Value": 2_000.0, + }, + { + "Year": 2023, + "SOI table": "Table 1.4", + "XLSX column": "O", + "XLSX row": 18, + "Variable": "total_pension_income", + "Filing status": "Head of Household", + "AGI lower bound": 30_000.0, + "AGI upper bound": 40_000.0, + "Count": False, + "Taxable only": True, + "Full population": False, + "Value": 3_000_000.0, + }, + { + "Year": 2023, + "SOI table": "Table 1.4", + "XLSX column": "N", + "XLSX row": 18, + "Variable": "total_pension_income", + "Filing status": "Head of Household", + "AGI lower bound": 30_000.0, + "AGI upper bound": 40_000.0, + "Count": True, + "Taxable only": True, + "Full population": False, + "Value": 4_000.0, + }, + { + "Year": 2023, + "SOI table": "Table 1.4", + "XLSX column": "G", + "XLSX row": 12, + "Variable": "employment_income", + "Filing status": "All", + "AGI lower bound": 1.0, + "AGI upper bound": 10_000.0, + "Count": False, + "Taxable only": True, + "Full population": False, + "Value": 999.0, + }, + ] + ) + monkeypatch.setattr( + "policyengine_us_data.db.etl_irs_soi.load_tracked_soi_targets", + lambda: soi_rows, + ) + + with Session(engine) as session: + national_filer_stratum = _create_national_filer_stratum(session) + load_national_taxable_agi_domain_filing_status_targets( + session, + national_filer_stratum.stratum_id, + target_year=2024, + ) + session.commit() + + builder = UnifiedMatrixBuilder(db_uri=db_uri, time_period=2024) + rows = builder._query_targets( + { + "variables": ["irs_employment_income", "pension_income", "tax_unit_count"], + "domain_variables": [ + "adjusted_gross_income,income_tax_before_credits,irs_employment_income", + "adjusted_gross_income,filing_status,income_tax_before_credits,pension_income", + ], + } + ) + + assert set(rows["variable"]) == { + "irs_employment_income", + "pension_income", + "tax_unit_count", + } + assert set(rows["value"].astype(float)) == { + 1_000_000.0, + 2_000.0, + 3_000_000.0, + 4_000.0, + } + assert 999.0 not in set(rows["value"].astype(float)) + + with engine.connect() as conn: + constraints = conn.execute( + text( + """ + SELECT tv.variable, sc.constraint_variable, sc.operation, sc.value + FROM target_overview tv + JOIN stratum_constraints sc ON tv.stratum_id = sc.stratum_id + WHERE tv.variable IN ('irs_employment_income', 'pension_income') + ORDER BY tv.variable, sc.constraint_variable + """ + ) + ).fetchall() + + constraint_set = { + (target_variable, variable, operation, constraint_value) + for target_variable, variable, operation, constraint_value in constraints + } + assert ( + "irs_employment_income", + "irs_employment_income", + ">", + "0", + ) in constraint_set + assert ( + "pension_income", + "filing_status", + "==", + "HEAD_OF_HOUSEHOLD", + ) in constraint_set + assert ( + "pension_income", + "pension_income", + ">", + "0", + ) in constraint_set + + +def test_load_state_eitc_claim_count_targets_creates_state_rows(monkeypatch, tmp_path): + db_uri, engine = _create_test_engine(tmp_path) + calibration_dir = tmp_path / "calibration_targets" + calibration_dir.mkdir() + (calibration_dir / "eitc_claim_controls.csv").write_text( + "year,GEO_ID,Returns,Amount\n" + "2024,0100000US,1000,2000\n" + "2024,0400000US06,123,456\n" + ) + monkeypatch.setattr( + "policyengine_us_data.db.etl_irs_soi.CALIBRATION_FOLDER", + calibration_dir, + ) + + with Session(engine) as session: + state_geo = Stratum(notes="California") + state_geo.constraints_rel = [ + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value="6", + ) + ] + session.add(state_geo) + session.commit() + session.refresh(state_geo) + + state_filer = Stratum( + parent_stratum_id=state_geo.stratum_id, + notes="State FIPS 6 - Tax Filers", + ) + state_filer.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value="6", + ), + ] + session.add(state_filer) + session.commit() + session.refresh(state_filer) + + load_state_eitc_claim_count_targets( + session, + {"state": {6: state_filer.stratum_id}}, + target_year=2024, + ) + session.commit() + + builder = UnifiedMatrixBuilder(db_uri=db_uri, time_period=2024) + rows = builder._query_targets( + { + "variables": ["tax_unit_count"], + "domain_variables": ["eitc"], + } + ) + + assert len(rows) == 1 + assert rows.iloc[0]["geo_level"] == "state" + assert rows.iloc[0]["geographic_id"] == "6" + assert float(rows.iloc[0]["value"]) == 123.0 diff --git a/tests/unit/test_etl_national_targets.py b/tests/unit/test_etl_national_targets.py index e7a504fbe..d7e3ebb41 100644 --- a/tests/unit/test_etl_national_targets.py +++ b/tests/unit/test_etl_national_targets.py @@ -1,6 +1,7 @@ import inspect import pandas as pd +from sqlalchemy import text from sqlmodel import Session, select from policyengine_us_data.db import etl_national_targets @@ -11,8 +12,11 @@ create_database, ) from policyengine_us_data.db.etl_national_targets import ( + MEDICARE_PART_B_AGE_TARGET_YEAR, extract_national_targets, + load_medicare_part_b_age_targets, load_national_targets, + load_state_acs_rent_targets, ) @@ -235,6 +239,112 @@ def test_load_national_targets_supports_liheap_household_counts(tmp_path, monkey assert liheap_target.value == 5_876_646 +def test_load_state_acs_rent_targets_creates_state_rows(tmp_path, monkeypatch): + calibration_dir = tmp_path / "calibration" + calibration_dir.mkdir() + db_uri = f"sqlite:///{calibration_dir / 'policy_data.db'}" + engine = create_database(db_uri) + + with Session(engine) as session: + _make_stratum(session, notes="United States") + _make_stratum( + session, + notes="California", + constraints=[ + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value="6", + ) + ], + ) + + monkeypatch.setattr( + "policyengine_us_data.db.etl_national_targets.STORAGE_FOLDER", + tmp_path, + ) + + targets = pd.DataFrame( + [ + { + "state_code": "CA", + "state_fips": "06", + "annual_contract_rent": 143_291_068_800, + "real_estate_taxes": 52_872_735_400, + } + ] + ) + load_state_acs_rent_targets(targets, year=2024) + + with Session(engine) as session: + target = session.exec( + select(Target).where( + Target.variable == "rent", + Target.period == 2024, + ) + ).first() + assert target is not None + assert target.value == 143_291_068_800 + assert target.source == "PolicyEngine" + assert "Census ACS 2024 1-year table B25060" in target.notes + + +def test_load_medicare_part_b_age_targets_creates_age_domain_rows( + tmp_path, monkeypatch +): + calibration_dir = tmp_path / "calibration" + calibration_dir.mkdir() + db_uri = f"sqlite:///{calibration_dir / 'policy_data.db'}" + engine = create_database(db_uri) + + with Session(engine) as session: + _make_stratum(session, notes="United States") + + monkeypatch.setattr( + "policyengine_us_data.db.etl_national_targets.STORAGE_FOLDER", + tmp_path, + ) + + targets = pd.DataFrame( + [ + { + "age_10_year_lower_bound": 70, + "medicare_part_b_premiums": 54_002_252_445.0, + }, + { + "age_10_year_lower_bound": 80, + "medicare_part_b_premiums": 24_692_726_700.0, + }, + ] + ) + load_medicare_part_b_age_targets(targets) + + with Session(engine) as session: + rows = session.exec( + select(Target).where(Target.variable == "medicare_part_b_premium") + ).all() + assert len(rows) == 2 + assert {row.period for row in rows} == {MEDICARE_PART_B_AGE_TARGET_YEAR} + + with engine.connect() as conn: + overview = conn.execute( + text( + """ + SELECT variable, domain_variable, value + FROM target_overview + WHERE variable = 'medicare_part_b_premium' + ORDER BY value + """ + ) + ).fetchall() + + assert {row.domain_variable for row in overview} == {"age"} + assert {float(row.value) for row in overview} == { + 54_002_252_445.0, + 24_692_726_700.0, + } + + def test_extract_national_targets_drops_survey_spm_targets(): targets = extract_national_targets(year=2024) direct_sum_variables = { From 63feccf605e4fc15006f831f0e23fc2a269ca010 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 14 May 2026 08:16:44 -0400 Subject: [PATCH 2/2] Fix SOI domain target uprating --- policyengine_us_data/db/etl_irs_soi.py | 7 ++++--- tests/unit/test_etl_irs_soi_overlay.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index c17b6aeb0..3d10d8b2d 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -31,6 +31,7 @@ save_bytes, ) from policyengine_us_data.utils.soi import ( + get_soi, get_tracked_soi_row, load_tracked_soi_targets, select_best_tracked_soi_rows, @@ -1086,7 +1087,7 @@ def load_national_taxable_agi_domain_filing_status_targets( target_year: int, ) -> None: """Create positive-domain SOI income targets by AGI band and filing status.""" - soi = select_best_tracked_soi_rows(load_tracked_soi_targets(), target_year) + soi = get_soi(target_year) rows = soi[ soi["Variable"].isin(SOI_TAXABLE_AGI_DOMAIN_TARGET_VARIABLES) & (soi["Taxable only"]) @@ -1113,7 +1114,7 @@ def load_national_taxable_agi_domain_filing_status_targets( session, stratum_id=stratum.stratum_id, variable="tax_unit_count" if bool(row["Count"]) else target_variable, - period=int(row["Year"]), + period=int(target_year), value=float(row["Value"]), source="IRS SOI", notes=notes, @@ -1717,7 +1718,7 @@ def load_soi_data( load_national_taxable_agi_domain_filing_status_targets( session, filer_strata["national"], - national_year, + target_year or national_year, ) load_national_fine_agi_targets(session, filer_strata["national"], national_year) load_national_ltcg_agi_targets(session, filer_strata["national"], national_year) diff --git a/tests/unit/test_etl_irs_soi_overlay.py b/tests/unit/test_etl_irs_soi_overlay.py index b0f12586d..8f80a04a3 100644 --- a/tests/unit/test_etl_irs_soi_overlay.py +++ b/tests/unit/test_etl_irs_soi_overlay.py @@ -942,9 +942,14 @@ def test_load_national_taxable_agi_domain_filing_status_targets_creates_structur }, ] ) + + def fake_get_soi(year: int) -> pd.DataFrame: + assert year == 2024 + return soi_rows + monkeypatch.setattr( - "policyengine_us_data.db.etl_irs_soi.load_tracked_soi_targets", - lambda: soi_rows, + "policyengine_us_data.db.etl_irs_soi.get_soi", + fake_get_soi, ) with Session(engine) as session: @@ -978,6 +983,7 @@ def test_load_national_taxable_agi_domain_filing_status_targets_creates_structur 3_000_000.0, 4_000.0, } + assert set(rows["period"].astype(int)) == {2024} assert 999.0 not in set(rows["value"].astype(float)) with engine.connect() as conn: