|
15 | 15 | ExtendedCPS_2024_Half, |
16 | 16 | CPS_2024, |
17 | 17 | ) |
| 18 | +from policyengine_us_data.utils.randomness import seeded_rng |
| 19 | +from policyengine_us_data.utils.takeup import ( |
| 20 | + ACA_POST_CALIBRATION_PERSON_TARGETS, |
| 21 | + extend_aca_takeup_to_match_target, |
| 22 | +) |
18 | 23 | import logging |
19 | 24 |
|
20 | 25 | try: |
|
23 | 28 | torch = None |
24 | 29 |
|
25 | 30 |
|
| 31 | +def _get_period_array(period_values: dict, period: int) -> np.ndarray: |
| 32 | + """Get a period array from a TIME_PERIOD_ARRAYS variable dict.""" |
| 33 | + value = period_values.get(period) |
| 34 | + if value is None: |
| 35 | + value = period_values.get(str(period)) |
| 36 | + if value is None: |
| 37 | + raise KeyError(f"Missing period {period}") |
| 38 | + return np.asarray(value) |
| 39 | + |
| 40 | + |
| 41 | +def create_aca_2025_takeup_override( |
| 42 | + base_takeup: np.ndarray, |
| 43 | + person_enrolled_if_takeup: np.ndarray, |
| 44 | + person_weights: np.ndarray, |
| 45 | + person_tax_unit_ids: np.ndarray, |
| 46 | + tax_unit_ids: np.ndarray, |
| 47 | + target_people: float = ACA_POST_CALIBRATION_PERSON_TARGETS[2025], |
| 48 | +) -> np.ndarray: |
| 49 | + """Add 2025 ACA takers until weighted APTC enrollment hits target.""" |
| 50 | + tax_unit_id_to_idx = { |
| 51 | + int(tax_unit_id): idx for idx, tax_unit_id in enumerate(tax_unit_ids) |
| 52 | + } |
| 53 | + person_tax_unit_idx = np.array( |
| 54 | + [tax_unit_id_to_idx[int(tax_unit_id)] for tax_unit_id in person_tax_unit_ids], |
| 55 | + dtype=np.int64, |
| 56 | + ) |
| 57 | + enrolled_person_weights = np.zeros(len(tax_unit_ids), dtype=np.float64) |
| 58 | + np.add.at( |
| 59 | + enrolled_person_weights, |
| 60 | + person_tax_unit_idx, |
| 61 | + person_enrolled_if_takeup.astype(np.float64) * person_weights, |
| 62 | + ) |
| 63 | + draws = seeded_rng("takes_up_aca_if_eligible").random(len(tax_unit_ids)) |
| 64 | + |
| 65 | + return extend_aca_takeup_to_match_target( |
| 66 | + base_takeup=np.asarray(base_takeup, dtype=bool), |
| 67 | + entity_draws=draws, |
| 68 | + enrolled_person_weights=enrolled_person_weights, |
| 69 | + target_people=target_people, |
| 70 | + ) |
| 71 | + |
| 72 | + |
26 | 73 | def reweight( |
27 | 74 | original_weights, |
28 | 75 | loss_matrix, |
@@ -142,6 +189,7 @@ def generate(self): |
142 | 189 |
|
143 | 190 | sim = Microsimulation(dataset=self.input_dataset) |
144 | 191 | data = sim.dataset.load_dataset() |
| 192 | + base_year = int(sim.default_calculation_period) |
145 | 193 | data["household_weight"] = {} |
146 | 194 | original_weights = sim.calculate("household_weight") |
147 | 195 | original_weights = original_weights.values + np.random.normal( |
@@ -216,6 +264,52 @@ def generate(self): |
216 | 264 | f"{int(np.sum(w > 0))} non-zero" |
217 | 265 | ) |
218 | 266 |
|
| 267 | + if 2025 in ACA_POST_CALIBRATION_PERSON_TARGETS: |
| 268 | + sim.set_input( |
| 269 | + "household_weight", |
| 270 | + base_year, |
| 271 | + _get_period_array(data["household_weight"], base_year).astype( |
| 272 | + np.float32 |
| 273 | + ), |
| 274 | + ) |
| 275 | + sim.set_input( |
| 276 | + "takes_up_aca_if_eligible", |
| 277 | + 2025, |
| 278 | + np.ones( |
| 279 | + len(_get_period_array(data["tax_unit_id"], base_year)), |
| 280 | + dtype=bool, |
| 281 | + ), |
| 282 | + ) |
| 283 | + sim.delete_arrays("aca_ptc") |
| 284 | + |
| 285 | + data["takes_up_aca_if_eligible"][2025] = create_aca_2025_takeup_override( |
| 286 | + base_takeup=_get_period_array( |
| 287 | + data["takes_up_aca_if_eligible"], |
| 288 | + base_year, |
| 289 | + ), |
| 290 | + person_enrolled_if_takeup=np.asarray( |
| 291 | + sim.calculate( |
| 292 | + "aca_ptc", |
| 293 | + map_to="person", |
| 294 | + period=2025, |
| 295 | + use_weights=False, |
| 296 | + ) |
| 297 | + ) |
| 298 | + > 0, |
| 299 | + person_weights=np.asarray( |
| 300 | + sim.calculate( |
| 301 | + "person_weight", |
| 302 | + period=2025, |
| 303 | + use_weights=False, |
| 304 | + ) |
| 305 | + ), |
| 306 | + person_tax_unit_ids=_get_period_array( |
| 307 | + data["person_tax_unit_id"], |
| 308 | + base_year, |
| 309 | + ), |
| 310 | + tax_unit_ids=_get_period_array(data["tax_unit_id"], base_year), |
| 311 | + ) |
| 312 | + |
219 | 313 | logging.info("Post-generation weight validation passed") |
220 | 314 |
|
221 | 315 | self.save_dataset(data) |
|
0 commit comments