Skip to content

Commit dfe2b48

Browse files
MaxGhenis authored and baogorek committed
Fix AGI-weighted geography targeting
1 parent b5f3e0e commit dfe2b48

7 files changed

Lines changed: 138 additions & 25 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Fix district AGI geography assignment to match target shares and use the requested calibration database when loading district AGI targets.

policyengine_us_data/calibration/clone_and_assign.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -68,12 +68,28 @@ def load_global_block_distribution():
6868

6969

7070
def _build_agi_block_probs(cds, pop_probs, cd_agi_targets):
71-
"""Multiply population block probs by CD AGI target weights."""
71+
"""Reweight block probabilities to match district AGI target shares.
72+
73+
District totals should be proportional to ``cd_agi_targets``, while
74+
block shares within each district should preserve the original
75+
population-weighted distribution.
76+
"""
7277
agi_weights = np.array([cd_agi_targets.get(cd, 0.0) for cd in cds])
7378
agi_weights = np.maximum(agi_weights, 0.0)
7479
if agi_weights.sum() == 0:
7580
return pop_probs
76-
agi_probs = pop_probs * agi_weights
81+
82+
district_pop_mass = (
83+
pd.Series(pop_probs, copy=False).groupby(cds).transform("sum").to_numpy()
84+
)
85+
agi_probs = np.divide(
86+
pop_probs * agi_weights,
87+
district_pop_mass,
88+
out=np.zeros_like(pop_probs, dtype=np.float64),
89+
where=district_pop_mass > 0,
90+
)
91+
if agi_probs.sum() == 0:
92+
return pop_probs
7793
return agi_probs / agi_probs.sum()
7894

7995

policyengine_us_data/calibration/unified_calibration.py

Lines changed: 8 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -931,28 +931,19 @@ def run_calibration(
931931
time_period,
932932
)
933933

934+
db_uri = f"sqlite:///{db_path}"
935+
builder = UnifiedMatrixBuilder(
936+
db_uri=db_uri,
937+
time_period=time_period,
938+
)
939+
934940
# Compute base household AGI for conditional geographic assignment
935941
base_agi = sim.calculate("adjusted_gross_income", map_to="household").values.astype(
936942
np.float64
937943
)
938944

939945
# Load CD-level AGI targets from database
940-
import sqlite3
941-
942-
from policyengine_us_data.storage import STORAGE_FOLDER
943-
944-
db_path = str(STORAGE_FOLDER / "calibration" / "policy_data.db")
945-
conn = sqlite3.connect(db_path)
946-
rows = conn.execute(
947-
"SELECT sc.value, t.value "
948-
"FROM targets t "
949-
"JOIN stratum_constraints sc ON t.stratum_id = sc.stratum_id "
950-
"WHERE t.variable = 'adjusted_gross_income' "
951-
"AND sc.constraint_variable = 'congressional_district_geoid' "
952-
"AND t.active = 1"
953-
).fetchall()
954-
conn.close()
955-
cd_agi_targets = {str(row[0]): float(row[1]) for row in rows}
946+
cd_agi_targets = builder.get_district_agi_targets()
956947
logger.info(
957948
"Loaded %d CD AGI targets for conditional assignment",
958949
len(cd_agi_targets),
@@ -1033,12 +1024,7 @@ def run_calibration(
10331024
# Step 6: Build sparse calibration matrix
10341025
do_rerandomize = not skip_takeup_rerandomize
10351026
t_matrix = time.time()
1036-
db_uri = f"sqlite:///{db_path}"
1037-
builder = UnifiedMatrixBuilder(
1038-
db_uri=db_uri,
1039-
time_period=time_period,
1040-
dataset_path=dataset_for_matrix,
1041-
)
1027+
builder.dataset_path = dataset_for_matrix
10421028
targets_df, X_sparse, target_names = builder.build_matrix(
10431029
geography=geography,
10441030
sim=sim,

policyengine_us_data/calibration/unified_matrix_builder.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,6 +1621,19 @@ def _query_targets(self, target_filter: dict) -> pd.DataFrame:
16211621
params={"time_period": self.time_period},
16221622
)
16231623

1624+
def get_district_agi_targets(self) -> Dict[str, float]:
1625+
"""Return current-law district AGI targets for geography assignment."""
1626+
targets_df = self._query_targets({"variables": ["adjusted_gross_income"]})
1627+
district_rows = targets_df[
1628+
(targets_df["geo_level"] == "district")
1629+
& (targets_df["reform_id"] == 0)
1630+
& (targets_df["domain_variable"].fillna("") == "")
1631+
]
1632+
return {
1633+
str(row["geographic_id"]): float(row["value"])
1634+
for _, row in district_rows.iterrows()
1635+
}
1636+
16241637
# ---------------------------------------------------------------
16251638
# Uprating
16261639
# ---------------------------------------------------------------

tests/unit/calibration/test_clone_and_assign.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from policyengine_us_data.calibration.clone_and_assign import (
1313
GeographyAssignment,
14+
_build_agi_block_probs,
1415
load_global_block_distribution,
1516
assign_random_geography,
1617
double_geography_for_puf,
@@ -87,6 +88,25 @@ def test_state_fips_extracted(self, tmp_path):
8788

8889

8990
class TestAssignRandomGeography:
91+
def test_build_agi_block_probs_matches_district_target_shares(self):
92+
cds = np.array(["101", "101", "102", "102"])
93+
pop_probs = np.array([0.45, 0.45, 0.05, 0.05], dtype=np.float64)
94+
agi_targets = {"101": 1.0, "102": 3.0}
95+
96+
agi_probs = _build_agi_block_probs(cds, pop_probs, agi_targets)
97+
98+
by_cd = {cd: agi_probs[cds == cd].sum() for cd in np.unique(cds)}
99+
np.testing.assert_allclose(by_cd["101"], 0.25)
100+
np.testing.assert_allclose(by_cd["102"], 0.75)
101+
np.testing.assert_allclose(
102+
agi_probs[cds == "101"] / agi_probs[cds == "101"].sum(),
103+
pop_probs[cds == "101"] / pop_probs[cds == "101"].sum(),
104+
)
105+
np.testing.assert_allclose(
106+
agi_probs[cds == "102"] / agi_probs[cds == "102"].sum(),
107+
pop_probs[cds == "102"] / pop_probs[cds == "102"].sum(),
108+
)
109+
90110
@patch(
91111
"policyengine_us_data.calibration.clone_and_assign"
92112
".load_global_block_distribution"

tests/unit/calibration/test_unified_calibration.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
"""
77

88
import numpy as np
9+
import pytest
10+
from types import SimpleNamespace
11+
from unittest.mock import patch
912

1013
from policyengine_us_data.utils.randomness import seeded_rng
1114
from policyengine_us_data.utils.takeup import (
@@ -352,6 +355,68 @@ def test_county_fips_length(self):
352355
assert all(len(c) == 5 for c in ga.county_fips)
353356

354357

358+
class TestRunCalibrationAgiTargets:
359+
def test_uses_requested_db_for_district_agi_targets(self):
360+
from policyengine_us_data.calibration.unified_calibration import (
361+
run_calibration,
362+
)
363+
364+
captured = {}
365+
366+
class StopAfterAssignment(RuntimeError):
367+
pass
368+
369+
class FakeMicrosimulation:
370+
def __init__(self, dataset, reform=None):
371+
self.dataset = SimpleNamespace(
372+
load_dataset=lambda: {"household_id": {2024: np.array([1, 2])}}
373+
)
374+
375+
def calculate(self, variable, *args, **kwargs):
376+
if variable == "household_id":
377+
return SimpleNamespace(values=np.array([1, 2], dtype=np.int64))
378+
if variable == "adjusted_gross_income":
379+
return SimpleNamespace(
380+
values=np.array([100.0, 200.0], dtype=np.float64)
381+
)
382+
raise AssertionError(f"Unexpected calculate({variable!r})")
383+
384+
class FakeBuilder:
385+
def __init__(self, db_uri, time_period, dataset_path=None):
386+
captured["db_uri"] = db_uri
387+
captured["time_period"] = time_period
388+
captured["dataset_path_at_init"] = dataset_path
389+
390+
def get_district_agi_targets(self):
391+
return {"601": 123.0}
392+
393+
def fake_assign_random_geography(**kwargs):
394+
captured["assign_kwargs"] = kwargs
395+
raise StopAfterAssignment
396+
397+
with (
398+
patch("policyengine_us.Microsimulation", FakeMicrosimulation),
399+
patch(
400+
"policyengine_us_data.calibration.unified_matrix_builder.UnifiedMatrixBuilder",
401+
FakeBuilder,
402+
),
403+
patch(
404+
"policyengine_us_data.calibration.clone_and_assign.assign_random_geography",
405+
fake_assign_random_geography,
406+
),
407+
):
408+
with pytest.raises(StopAfterAssignment):
409+
run_calibration(
410+
dataset_path="input.h5",
411+
db_path="/tmp/custom-policy-data.db",
412+
n_clones=2,
413+
)
414+
415+
assert captured["db_uri"] == "sqlite:////tmp/custom-policy-data.db"
416+
assert captured["time_period"] == 2024
417+
assert captured["assign_kwargs"]["cd_agi_targets"] == {"601": 123.0}
418+
419+
355420
class TestBlockTakeupSeeding:
356421
"""Verify compute_block_takeup_for_entities is
357422
reproducible and clone-dependent."""

tests/unit/calibration/test_unified_matrix_builder.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _create_legacy_target_overview(engine):
116116

117117
def _insert_aca_ptc_data(engine):
118118
with engine.connect() as conn:
119-
strata = [1, 2, 3, 4, 5, 6, 7, 8, 9]
119+
strata = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
120120
for sid in strata:
121121
conn.execute(
122122
text(
@@ -147,6 +147,8 @@ def _insert_aca_ptc_data(engine):
147147
(14, 8, "aca_ptc", ">", "0"),
148148
(15, 8, "congressional_district_geoid", "=", "3702"),
149149
(16, 9, "aca_ptc", ">", "0"),
150+
(17, 10, "congressional_district_geoid", "=", "601"),
151+
(18, 11, "congressional_district_geoid", "=", "602"),
150152
]
151153
for cid, sid, var, op, val in constraints:
152154
conn.execute(
@@ -183,6 +185,9 @@ def _insert_aca_ptc_data(engine):
183185
(17, 9, "person_count", 0, 19743689.0, 2024, 1),
184186
(18, 1, "aca_ptc", 1, 999.0, 2022, 1),
185187
(19, 1, "aca_ptc", 0, 12345.0, 2024, 0),
188+
(20, 10, "adjusted_gross_income", 0, 1000.0, 2021, 1),
189+
(21, 10, "adjusted_gross_income", 0, 1500.0, 2022, 1),
190+
(22, 11, "adjusted_gross_income", 0, 800.0, 2022, 1),
186191
]
187192
for tid, sid, var, reform_id, val, period, active in targets:
188193
conn.execute(
@@ -297,6 +302,13 @@ def test_target_name_adds_expenditure_suffix_for_reforms(self):
297302
)
298303
self.assertEqual(name, "national/salt_deduction_expenditure")
299304

305+
def test_get_district_agi_targets_uses_requested_db_periods(self):
306+
b = self._make_builder(time_period=2024)
307+
self.assertEqual(
308+
b.get_district_agi_targets(),
309+
{"601": 1500.0, "602": 800.0},
310+
)
311+
300312

301313
class TestHierarchicalUprating(unittest.TestCase):
302314
@classmethod

0 commit comments

Comments
 (0)