@@ -152,7 +152,38 @@ def _compute_single_state(
152152 exc ,
153153 )
154154
155- return (state , {"hh" : hh , "person" : person , "entity" : entity_vals })
155+ entity_wf_false = {}
156+ if rerandomize_takeup :
157+ has_tu_target = any (
158+ info ["entity" ] == "tax_unit" for info in affected_targets .values ()
159+ )
160+ if has_tu_target :
161+ n_tu = len (state_sim .calculate ("tax_unit_id" , map_to = "tax_unit" ).values )
162+ state_sim .set_input (
163+ "would_file_taxes_voluntarily" ,
164+ time_period ,
165+ np .zeros (n_tu , dtype = bool ),
166+ )
167+ for var in get_calculated_variables (state_sim ):
168+ state_sim .delete_arrays (var )
169+ for tvar , info in affected_targets .items ():
170+ if info ["entity" ] != "tax_unit" :
171+ continue
172+ entity_wf_false [tvar ] = state_sim .calculate (
173+ tvar ,
174+ time_period ,
175+ map_to = "tax_unit" ,
176+ ).values .astype (np .float32 )
177+
178+ return (
179+ state ,
180+ {
181+ "hh" : hh ,
182+ "person" : person ,
183+ "entity" : entity_vals ,
184+ "entity_wf_false" : entity_wf_false ,
185+ },
186+ )
156187
157188
158189def _compute_single_state_group_counties (
@@ -278,7 +309,40 @@ def _compute_single_state_group_counties(
278309 exc ,
279310 )
280311
281- results .append ((county_fips , {"hh" : hh , "entity" : entity_vals }))
312+ entity_wf_false = {}
313+ if rerandomize_takeup :
314+ has_tu_target = any (
315+ info ["entity" ] == "tax_unit" for info in affected_targets .values ()
316+ )
317+ if has_tu_target :
318+ n_tu = len (state_sim .calculate ("tax_unit_id" , map_to = "tax_unit" ).values )
319+ state_sim .set_input (
320+ "would_file_taxes_voluntarily" ,
321+ time_period ,
322+ np .zeros (n_tu , dtype = bool ),
323+ )
324+ for var in get_calculated_variables (state_sim ):
325+ if var != "county" :
326+ state_sim .delete_arrays (var )
327+ for tvar , info in affected_targets .items ():
328+ if info ["entity" ] != "tax_unit" :
329+ continue
330+ entity_wf_false [tvar ] = state_sim .calculate (
331+ tvar ,
332+ time_period ,
333+ map_to = "tax_unit" ,
334+ ).values .astype (np .float32 )
335+
336+ results .append (
337+ (
338+ county_fips ,
339+ {
340+ "hh" : hh ,
341+ "entity" : entity_vals ,
342+ "entity_wf_false" : entity_wf_false ,
343+ },
344+ )
345+ )
282346
283347 return results
284348
@@ -552,11 +616,37 @@ def _process_single_clone(
552616 # Takeup re-randomisation
553617 if do_takeup and affected_target_info :
554618 from policyengine_us_data .utils .takeup import (
619+ SIMPLE_TAKEUP_VARS ,
555620 compute_block_takeup_for_entities ,
556621 )
557622
558623 clone_blocks = geo_blocks [col_start :col_end ]
559624
625+ # Phase 1: compute non-target draws (would_file) FIRST
626+ wf_draws = {}
627+ for spec in SIMPLE_TAKEUP_VARS :
628+ if spec .get ("target" ) is not None :
629+ continue
630+ var_name = spec ["variable" ]
631+ entity = spec ["entity" ]
632+ rate_key = spec ["rate_key" ]
633+ if rate_key not in precomputed_rates :
634+ continue
635+ ent_hh = entity_hh_idx_map [entity ]
636+ ent_blocks = clone_blocks [ent_hh ]
637+ ent_hh_ids = household_ids [ent_hh ]
638+ draws = compute_block_takeup_for_entities (
639+ var_name ,
640+ precomputed_rates [rate_key ],
641+ ent_blocks ,
642+ ent_hh_ids ,
643+ )
644+ wf_draws [entity ] = draws
645+ if var_name in person_vars :
646+ pidx = entity_to_person_idx [entity ]
647+ person_vars [var_name ] = draws [pidx ].astype (np .float32 )
648+
649+ # Phase 2: target loop with would_file blending
560650 for tvar , info in affected_target_info .items ():
561651 if tvar .endswith ("_count" ):
562652 continue
@@ -586,6 +676,34 @@ def _process_single_clone(
586676 if tvar in sv :
587677 ent_eligible [m ] = sv [tvar ][m ]
588678
679+ # Blend: for tax_unit targets, select between
680+ # all-takeup-true and would_file=false values
681+ if entity_level == "tax_unit" and "tax_unit" in wf_draws :
682+ ent_wf_false = np .zeros (n_ent , dtype = np .float32 )
683+ if tvar in county_dep_targets and county_values :
684+ ent_counties = clone_counties [ent_hh ]
685+ for cfips in np .unique (ent_counties ):
686+ m = ent_counties == cfips
687+ cv = county_values .get (cfips , {}).get ("entity_wf_false" , {})
688+ if tvar in cv :
689+ ent_wf_false [m ] = cv [tvar ][m ]
690+ else :
691+ st = int (cfips [:2 ])
692+ sv = state_values [st ].get ("entity_wf_false" , {})
693+ if tvar in sv :
694+ ent_wf_false [m ] = sv [tvar ][m ]
695+ else :
696+ for st in np .unique (ent_states ):
697+ m = ent_states == st
698+ sv = state_values [int (st )].get ("entity_wf_false" , {})
699+ if tvar in sv :
700+ ent_wf_false [m ] = sv [tvar ][m ]
701+ ent_eligible = np .where (
702+ wf_draws ["tax_unit" ],
703+ ent_eligible ,
704+ ent_wf_false ,
705+ )
706+
589707 ent_blocks = clone_blocks [ent_hh ]
590708 ent_hh_ids = household_ids [ent_hh ]
591709
@@ -950,10 +1068,43 @@ def _build_state_values(
9501068 exc ,
9511069 )
9521070
1071+ entity_wf_false = {}
1072+ if rerandomize_takeup :
1073+ has_tu_target = any (
1074+ info ["entity" ] == "tax_unit"
1075+ for info in affected_targets .values ()
1076+ )
1077+ if has_tu_target :
1078+ n_tu = len (
1079+ state_sim .calculate (
1080+ "tax_unit_id" ,
1081+ map_to = "tax_unit" ,
1082+ ).values
1083+ )
1084+ state_sim .set_input (
1085+ "would_file_taxes_voluntarily" ,
1086+ self .time_period ,
1087+ np .zeros (n_tu , dtype = bool ),
1088+ )
1089+ for var in get_calculated_variables (state_sim ):
1090+ state_sim .delete_arrays (var )
1091+ for (
1092+ tvar ,
1093+ info ,
1094+ ) in affected_targets .items ():
1095+ if info ["entity" ] != "tax_unit" :
1096+ continue
1097+ entity_wf_false [tvar ] = state_sim .calculate (
1098+ tvar ,
1099+ self .time_period ,
1100+ map_to = "tax_unit" ,
1101+ ).values .astype (np .float32 )
1102+
9531103 state_values [state ] = {
9541104 "hh" : hh ,
9551105 "person" : person ,
9561106 "entity" : entity_vals ,
1107+ "entity_wf_false" : entity_wf_false ,
9571108 }
9581109 if (i + 1 ) % 10 == 0 or i == 0 :
9591110 logger .info (
@@ -1216,9 +1367,43 @@ def _build_county_values(
12161367 exc ,
12171368 )
12181369
1370+ entity_wf_false = {}
1371+ if rerandomize_takeup :
1372+ has_tu_target = any (
1373+ info ["entity" ] == "tax_unit"
1374+ for info in affected_targets .values ()
1375+ )
1376+ if has_tu_target :
1377+ n_tu = len (
1378+ state_sim .calculate (
1379+ "tax_unit_id" ,
1380+ map_to = "tax_unit" ,
1381+ ).values
1382+ )
1383+ state_sim .set_input (
1384+ "would_file_taxes_voluntarily" ,
1385+ self .time_period ,
1386+ np .zeros (n_tu , dtype = bool ),
1387+ )
1388+ for var in get_calculated_variables (state_sim ):
1389+ if var != "county" :
1390+ state_sim .delete_arrays (var )
1391+ for (
1392+ tvar ,
1393+ info ,
1394+ ) in affected_targets .items ():
1395+ if info ["entity" ] != "tax_unit" :
1396+ continue
1397+ entity_wf_false [tvar ] = state_sim .calculate (
1398+ tvar ,
1399+ self .time_period ,
1400+ map_to = "tax_unit" ,
1401+ ).values .astype (np .float32 )
1402+
12191403 county_values [county_fips ] = {
12201404 "hh" : hh ,
12211405 "entity" : entity_vals ,
1406+ "entity_wf_false" : entity_wf_false ,
12221407 }
12231408 county_count += 1
12241409 if county_count % 500 == 0 or county_count == 1 :
@@ -1928,10 +2113,14 @@ def build_matrix(
19282113 len (affected_target_info ),
19292114 )
19302115
1931- # Pre-compute takeup rates (constant across clones)
2116+ # Pre-compute takeup rates for ALL takeup vars
2117+ from policyengine_us_data .utils .takeup import (
2118+ SIMPLE_TAKEUP_VARS as _ALL_TAKEUP ,
2119+ )
2120+
19322121 precomputed_rates = {}
1933- for tvar , info in affected_target_info . items () :
1934- rk = info ["rate_key" ]
2122+ for spec in _ALL_TAKEUP :
2123+ rk = spec ["rate_key" ]
19352124 if rk not in precomputed_rates :
19362125 precomputed_rates [rk ] = load_take_up_rate (rk , self .time_period )
19372126
@@ -2083,6 +2272,36 @@ def build_matrix(
20832272 # for affected target variables
20842273 if rerandomize_takeup and affected_target_info :
20852274 clone_blocks = geography .block_geoid [col_start :col_end ]
2275+
2276+ from policyengine_us_data .utils .takeup import (
2277+ SIMPLE_TAKEUP_VARS as _SEQ_TAKEUP ,
2278+ )
2279+
2280+ # Phase 1: non-target draws (would_file) FIRST
2281+ wf_draws = {}
2282+ for spec in _SEQ_TAKEUP :
2283+ if spec .get ("target" ) is not None :
2284+ continue
2285+ var_name = spec ["variable" ]
2286+ entity = spec ["entity" ]
2287+ rate_key = spec ["rate_key" ]
2288+ if rate_key not in precomputed_rates :
2289+ continue
2290+ ent_hh = entity_hh_idx_map [entity ]
2291+ ent_blocks = clone_blocks [ent_hh ]
2292+ ent_hh_ids = household_ids [ent_hh ]
2293+ draws = compute_block_takeup_for_entities (
2294+ var_name ,
2295+ precomputed_rates [rate_key ],
2296+ ent_blocks ,
2297+ ent_hh_ids ,
2298+ )
2299+ wf_draws [entity ] = draws
2300+ if var_name in person_vars :
2301+ pidx = entity_to_person_idx [entity ]
2302+ person_vars [var_name ] = draws [pidx ].astype (np .float32 )
2303+
2304+ # Phase 2: target loop with would_file blending
20862305 for (
20872306 tvar ,
20882307 info ,
@@ -2116,6 +2335,37 @@ def build_matrix(
21162335 if tvar in sv :
21172336 ent_eligible [m ] = sv [tvar ][m ]
21182337
2338+ # Blend for tax_unit targets
2339+ if entity_level == "tax_unit" and "tax_unit" in wf_draws :
2340+ ent_wf_false = np .zeros (n_ent , dtype = np .float32 )
2341+ if tvar in county_dep_targets and county_values :
2342+ ent_counties = clone_counties [ent_hh ]
2343+ for cfips in np .unique (ent_counties ):
2344+ m = ent_counties == cfips
2345+ cv = county_values .get (cfips , {}).get (
2346+ "entity_wf_false" , {}
2347+ )
2348+ if tvar in cv :
2349+ ent_wf_false [m ] = cv [tvar ][m ]
2350+ else :
2351+ st = int (cfips [:2 ])
2352+ sv = state_values [st ].get ("entity_wf_false" , {})
2353+ if tvar in sv :
2354+ ent_wf_false [m ] = sv [tvar ][m ]
2355+ else :
2356+ for st in np .unique (ent_states ):
2357+ m = ent_states == st
2358+ sv = state_values [int (st )].get (
2359+ "entity_wf_false" , {}
2360+ )
2361+ if tvar in sv :
2362+ ent_wf_false [m ] = sv [tvar ][m ]
2363+ ent_eligible = np .where (
2364+ wf_draws ["tax_unit" ],
2365+ ent_eligible ,
2366+ ent_wf_false ,
2367+ )
2368+
21192369 ent_blocks = clone_blocks [ent_hh ]
21202370 ent_hh_ids = household_ids [ent_hh ]
21212371
0 commit comments