Skip to content

Commit 7578ba2

Browse files
baogorekclaude
andcommitted
Blend entity values on would_file draws; remove wrong entity weights
Matrix builder: precompute entity values with would_file=False alongside the all-True values, then blend per tax unit based on the would_file draw before applying target takeup draws. This ensures X@w matches sim.calculate for targets affected by non-target state variables. Fixes #609 publish_local_area: remove explicit sub-entity weight overrides (tax_unit_weight, spm_unit_weight, family_weight, marital_unit_weight, person_weight) that used incorrect person-count splitting. These are formula variables in policyengine-us that correctly derive from household_weight at runtime. Fixes #610 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2815699 commit 7578ba2

2 files changed

Lines changed: 258 additions & 28 deletions

File tree

policyengine_us_data/calibration/publish_local_area.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -311,22 +311,6 @@ def build_h5(
311311
unique_geo = derive_geography_from_blocks(unique_blocks)
312312
clone_geo = {k: v[block_inv] for k, v in unique_geo.items()}
313313

314-
# === Calculate weights for all entity levels ===
315-
person_weights = np.repeat(clone_weights, persons_per_clone)
316-
per_person_wt = clone_weights / np.maximum(persons_per_clone, 1)
317-
318-
entity_weights = {}
319-
for ek in SUB_ENTITIES:
320-
n_ents = len(entity_clone_idx[ek])
321-
ent_person_counts = np.zeros(n_ents, dtype=np.int32)
322-
np.add.at(
323-
ent_person_counts,
324-
new_person_entity_ids[ek],
325-
1,
326-
)
327-
clone_ids_e = np.repeat(np.arange(n_clones), entities_per_clone[ek])
328-
entity_weights[ek] = per_person_wt[clone_ids_e] * ent_person_counts
329-
330314
# === Determine variables to save ===
331315
vars_to_save = set(sim.input_variables)
332316
vars_to_save.add("county")
@@ -413,16 +397,12 @@ def build_h5(
413397
}
414398

415399
# === Override weights ===
400+
# Only write household_weight; sub-entity weights (tax_unit_weight,
401+
# spm_unit_weight, person_weight, etc.) are formula variables in
402+
# policyengine-us that derive from household_weight at runtime.
416403
data["household_weight"] = {
417404
time_period: clone_weights.astype(np.float32),
418405
}
419-
data["person_weight"] = {
420-
time_period: person_weights.astype(np.float32),
421-
}
422-
for ek in SUB_ENTITIES:
423-
data[f"{ek}_weight"] = {
424-
time_period: entity_weights[ek].astype(np.float32),
425-
}
426406

427407
# === Override geography ===
428408
data["state_fips"] = {

policyengine_us_data/calibration/unified_matrix_builder.py

Lines changed: 255 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,38 @@ def _compute_single_state(
152152
exc,
153153
)
154154

155-
return (state, {"hh": hh, "person": person, "entity": entity_vals})
155+
entity_wf_false = {}
156+
if rerandomize_takeup:
157+
has_tu_target = any(
158+
info["entity"] == "tax_unit" for info in affected_targets.values()
159+
)
160+
if has_tu_target:
161+
n_tu = len(state_sim.calculate("tax_unit_id", map_to="tax_unit").values)
162+
state_sim.set_input(
163+
"would_file_taxes_voluntarily",
164+
time_period,
165+
np.zeros(n_tu, dtype=bool),
166+
)
167+
for var in get_calculated_variables(state_sim):
168+
state_sim.delete_arrays(var)
169+
for tvar, info in affected_targets.items():
170+
if info["entity"] != "tax_unit":
171+
continue
172+
entity_wf_false[tvar] = state_sim.calculate(
173+
tvar,
174+
time_period,
175+
map_to="tax_unit",
176+
).values.astype(np.float32)
177+
178+
return (
179+
state,
180+
{
181+
"hh": hh,
182+
"person": person,
183+
"entity": entity_vals,
184+
"entity_wf_false": entity_wf_false,
185+
},
186+
)
156187

157188

158189
def _compute_single_state_group_counties(
@@ -278,7 +309,40 @@ def _compute_single_state_group_counties(
278309
exc,
279310
)
280311

281-
results.append((county_fips, {"hh": hh, "entity": entity_vals}))
312+
entity_wf_false = {}
313+
if rerandomize_takeup:
314+
has_tu_target = any(
315+
info["entity"] == "tax_unit" for info in affected_targets.values()
316+
)
317+
if has_tu_target:
318+
n_tu = len(state_sim.calculate("tax_unit_id", map_to="tax_unit").values)
319+
state_sim.set_input(
320+
"would_file_taxes_voluntarily",
321+
time_period,
322+
np.zeros(n_tu, dtype=bool),
323+
)
324+
for var in get_calculated_variables(state_sim):
325+
if var != "county":
326+
state_sim.delete_arrays(var)
327+
for tvar, info in affected_targets.items():
328+
if info["entity"] != "tax_unit":
329+
continue
330+
entity_wf_false[tvar] = state_sim.calculate(
331+
tvar,
332+
time_period,
333+
map_to="tax_unit",
334+
).values.astype(np.float32)
335+
336+
results.append(
337+
(
338+
county_fips,
339+
{
340+
"hh": hh,
341+
"entity": entity_vals,
342+
"entity_wf_false": entity_wf_false,
343+
},
344+
)
345+
)
282346

283347
return results
284348

@@ -552,11 +616,37 @@ def _process_single_clone(
552616
# Takeup re-randomisation
553617
if do_takeup and affected_target_info:
554618
from policyengine_us_data.utils.takeup import (
619+
SIMPLE_TAKEUP_VARS,
555620
compute_block_takeup_for_entities,
556621
)
557622

558623
clone_blocks = geo_blocks[col_start:col_end]
559624

625+
# Phase 1: compute non-target draws (would_file) FIRST
626+
wf_draws = {}
627+
for spec in SIMPLE_TAKEUP_VARS:
628+
if spec.get("target") is not None:
629+
continue
630+
var_name = spec["variable"]
631+
entity = spec["entity"]
632+
rate_key = spec["rate_key"]
633+
if rate_key not in precomputed_rates:
634+
continue
635+
ent_hh = entity_hh_idx_map[entity]
636+
ent_blocks = clone_blocks[ent_hh]
637+
ent_hh_ids = household_ids[ent_hh]
638+
draws = compute_block_takeup_for_entities(
639+
var_name,
640+
precomputed_rates[rate_key],
641+
ent_blocks,
642+
ent_hh_ids,
643+
)
644+
wf_draws[entity] = draws
645+
if var_name in person_vars:
646+
pidx = entity_to_person_idx[entity]
647+
person_vars[var_name] = draws[pidx].astype(np.float32)
648+
649+
# Phase 2: target loop with would_file blending
560650
for tvar, info in affected_target_info.items():
561651
if tvar.endswith("_count"):
562652
continue
@@ -586,6 +676,34 @@ def _process_single_clone(
586676
if tvar in sv:
587677
ent_eligible[m] = sv[tvar][m]
588678

679+
# Blend: for tax_unit targets, select between
680+
# all-takeup-true and would_file=false values
681+
if entity_level == "tax_unit" and "tax_unit" in wf_draws:
682+
ent_wf_false = np.zeros(n_ent, dtype=np.float32)
683+
if tvar in county_dep_targets and county_values:
684+
ent_counties = clone_counties[ent_hh]
685+
for cfips in np.unique(ent_counties):
686+
m = ent_counties == cfips
687+
cv = county_values.get(cfips, {}).get("entity_wf_false", {})
688+
if tvar in cv:
689+
ent_wf_false[m] = cv[tvar][m]
690+
else:
691+
st = int(cfips[:2])
692+
sv = state_values[st].get("entity_wf_false", {})
693+
if tvar in sv:
694+
ent_wf_false[m] = sv[tvar][m]
695+
else:
696+
for st in np.unique(ent_states):
697+
m = ent_states == st
698+
sv = state_values[int(st)].get("entity_wf_false", {})
699+
if tvar in sv:
700+
ent_wf_false[m] = sv[tvar][m]
701+
ent_eligible = np.where(
702+
wf_draws["tax_unit"],
703+
ent_eligible,
704+
ent_wf_false,
705+
)
706+
589707
ent_blocks = clone_blocks[ent_hh]
590708
ent_hh_ids = household_ids[ent_hh]
591709

@@ -950,10 +1068,43 @@ def _build_state_values(
9501068
exc,
9511069
)
9521070

1071+
entity_wf_false = {}
1072+
if rerandomize_takeup:
1073+
has_tu_target = any(
1074+
info["entity"] == "tax_unit"
1075+
for info in affected_targets.values()
1076+
)
1077+
if has_tu_target:
1078+
n_tu = len(
1079+
state_sim.calculate(
1080+
"tax_unit_id",
1081+
map_to="tax_unit",
1082+
).values
1083+
)
1084+
state_sim.set_input(
1085+
"would_file_taxes_voluntarily",
1086+
self.time_period,
1087+
np.zeros(n_tu, dtype=bool),
1088+
)
1089+
for var in get_calculated_variables(state_sim):
1090+
state_sim.delete_arrays(var)
1091+
for (
1092+
tvar,
1093+
info,
1094+
) in affected_targets.items():
1095+
if info["entity"] != "tax_unit":
1096+
continue
1097+
entity_wf_false[tvar] = state_sim.calculate(
1098+
tvar,
1099+
self.time_period,
1100+
map_to="tax_unit",
1101+
).values.astype(np.float32)
1102+
9531103
state_values[state] = {
9541104
"hh": hh,
9551105
"person": person,
9561106
"entity": entity_vals,
1107+
"entity_wf_false": entity_wf_false,
9571108
}
9581109
if (i + 1) % 10 == 0 or i == 0:
9591110
logger.info(
@@ -1216,9 +1367,43 @@ def _build_county_values(
12161367
exc,
12171368
)
12181369

1370+
entity_wf_false = {}
1371+
if rerandomize_takeup:
1372+
has_tu_target = any(
1373+
info["entity"] == "tax_unit"
1374+
for info in affected_targets.values()
1375+
)
1376+
if has_tu_target:
1377+
n_tu = len(
1378+
state_sim.calculate(
1379+
"tax_unit_id",
1380+
map_to="tax_unit",
1381+
).values
1382+
)
1383+
state_sim.set_input(
1384+
"would_file_taxes_voluntarily",
1385+
self.time_period,
1386+
np.zeros(n_tu, dtype=bool),
1387+
)
1388+
for var in get_calculated_variables(state_sim):
1389+
if var != "county":
1390+
state_sim.delete_arrays(var)
1391+
for (
1392+
tvar,
1393+
info,
1394+
) in affected_targets.items():
1395+
if info["entity"] != "tax_unit":
1396+
continue
1397+
entity_wf_false[tvar] = state_sim.calculate(
1398+
tvar,
1399+
self.time_period,
1400+
map_to="tax_unit",
1401+
).values.astype(np.float32)
1402+
12191403
county_values[county_fips] = {
12201404
"hh": hh,
12211405
"entity": entity_vals,
1406+
"entity_wf_false": entity_wf_false,
12221407
}
12231408
county_count += 1
12241409
if county_count % 500 == 0 or county_count == 1:
@@ -1928,10 +2113,14 @@ def build_matrix(
19282113
len(affected_target_info),
19292114
)
19302115

1931-
# Pre-compute takeup rates (constant across clones)
2116+
# Pre-compute takeup rates for ALL takeup vars
2117+
from policyengine_us_data.utils.takeup import (
2118+
SIMPLE_TAKEUP_VARS as _ALL_TAKEUP,
2119+
)
2120+
19322121
precomputed_rates = {}
1933-
for tvar, info in affected_target_info.items():
1934-
rk = info["rate_key"]
2122+
for spec in _ALL_TAKEUP:
2123+
rk = spec["rate_key"]
19352124
if rk not in precomputed_rates:
19362125
precomputed_rates[rk] = load_take_up_rate(rk, self.time_period)
19372126

@@ -2083,6 +2272,36 @@ def build_matrix(
20832272
# for affected target variables
20842273
if rerandomize_takeup and affected_target_info:
20852274
clone_blocks = geography.block_geoid[col_start:col_end]
2275+
2276+
from policyengine_us_data.utils.takeup import (
2277+
SIMPLE_TAKEUP_VARS as _SEQ_TAKEUP,
2278+
)
2279+
2280+
# Phase 1: non-target draws (would_file) FIRST
2281+
wf_draws = {}
2282+
for spec in _SEQ_TAKEUP:
2283+
if spec.get("target") is not None:
2284+
continue
2285+
var_name = spec["variable"]
2286+
entity = spec["entity"]
2287+
rate_key = spec["rate_key"]
2288+
if rate_key not in precomputed_rates:
2289+
continue
2290+
ent_hh = entity_hh_idx_map[entity]
2291+
ent_blocks = clone_blocks[ent_hh]
2292+
ent_hh_ids = household_ids[ent_hh]
2293+
draws = compute_block_takeup_for_entities(
2294+
var_name,
2295+
precomputed_rates[rate_key],
2296+
ent_blocks,
2297+
ent_hh_ids,
2298+
)
2299+
wf_draws[entity] = draws
2300+
if var_name in person_vars:
2301+
pidx = entity_to_person_idx[entity]
2302+
person_vars[var_name] = draws[pidx].astype(np.float32)
2303+
2304+
# Phase 2: target loop with would_file blending
20862305
for (
20872306
tvar,
20882307
info,
@@ -2116,6 +2335,37 @@ def build_matrix(
21162335
if tvar in sv:
21172336
ent_eligible[m] = sv[tvar][m]
21182337

2338+
# Blend for tax_unit targets
2339+
if entity_level == "tax_unit" and "tax_unit" in wf_draws:
2340+
ent_wf_false = np.zeros(n_ent, dtype=np.float32)
2341+
if tvar in county_dep_targets and county_values:
2342+
ent_counties = clone_counties[ent_hh]
2343+
for cfips in np.unique(ent_counties):
2344+
m = ent_counties == cfips
2345+
cv = county_values.get(cfips, {}).get(
2346+
"entity_wf_false", {}
2347+
)
2348+
if tvar in cv:
2349+
ent_wf_false[m] = cv[tvar][m]
2350+
else:
2351+
st = int(cfips[:2])
2352+
sv = state_values[st].get("entity_wf_false", {})
2353+
if tvar in sv:
2354+
ent_wf_false[m] = sv[tvar][m]
2355+
else:
2356+
for st in np.unique(ent_states):
2357+
m = ent_states == st
2358+
sv = state_values[int(st)].get(
2359+
"entity_wf_false", {}
2360+
)
2361+
if tvar in sv:
2362+
ent_wf_false[m] = sv[tvar][m]
2363+
ent_eligible = np.where(
2364+
wf_draws["tax_unit"],
2365+
ent_eligible,
2366+
ent_wf_false,
2367+
)
2368+
21192369
ent_blocks = clone_blocks[ent_hh]
21202370
ent_hh_ids = household_ids[ent_hh]
21212371

0 commit comments

Comments
 (0)