@@ -140,35 +140,27 @@ def get_version() -> str:
140140
141141
142142def partition_work (
143- states : List [str ],
144- districts : List [str ],
145- cities : List [str ],
143+ work_items : List [Dict ],
146144 num_workers : int ,
147145 completed : set ,
148146) -> List [List [Dict ]]:
149- """Partition work items across N workers."""
150- remaining = []
151-
152- for s in states :
153- item_id = f"state:{ s } "
154- if item_id not in completed :
155- remaining .append ({"type" : "state" , "id" : s , "weight" : 5 })
156-
157- for d in districts :
158- item_id = f"district:{ d } "
159- if item_id not in completed :
160- remaining .append ({"type" : "district" , "id" : d , "weight" : 1 })
147+ """Partition work items across N workers using LPT scheduling."""
148+ remaining = [
149+ item for item in work_items if f"{ item ['type' ]} :{ item ['id' ]} " not in completed
150+ ]
151+ remaining .sort (key = lambda x : - x ["weight" ])
161152
162- for c in cities :
163- item_id = f"city:{ c } "
164- if item_id not in completed :
165- remaining .append ({"type" : "city" , "id" : c , "weight" : 3 })
153+ n_workers = min (num_workers , len (remaining ))
154+ if n_workers == 0 :
155+ return []
166156
167- remaining .sort (key = lambda x : - x ["weight" ])
157+ heap = [(0 , i ) for i in range (n_workers )]
158+ chunks = [[] for _ in range (n_workers )]
168159
169- chunks = [[] for _ in range (num_workers )]
170- for i , item in enumerate (remaining ):
171- chunks [i % num_workers ].append (item )
160+ for item in remaining :
161+ load , idx = heapq .heappop (heap )
162+ chunks [idx ].append (item )
163+ heapq .heappush (heap , (load + item ["weight" ], idx ))
172164
173165 return [c for c in chunks if c ]
174166
@@ -197,9 +189,7 @@ def get_completed_from_volume(version_dir: Path) -> set:
197189
198190def run_phase (
199191 phase_name : str ,
200- states : List [str ],
201- districts : List [str ],
202- cities : List [str ],
192+ work_items : List [Dict ],
203193 num_workers : int ,
204194 completed : set ,
205195 branch : str ,
@@ -216,7 +206,7 @@ def run_phase(
216206 and crashes, and validation_rows is a list of per-target
217207 validation result dicts.
218208 """
219- work_chunks = partition_work (states , districts , cities , num_workers , completed )
209+ work_chunks = partition_work (work_items , num_workers , completed )
220210 total_remaining = sum (len (c ) for c in work_chunks )
221211
222212 print (f"\n --- Phase: { phase_name } ---" )
@@ -228,7 +218,8 @@ def run_phase(
228218
229219 handles = []
230220 for i , chunk in enumerate (work_chunks ):
231- print (f" Worker { i } : { len (chunk )} items" )
221+ total_weight = sum (item ["weight" ] for item in chunk )
222+ print (f" Worker { i } : { len (chunk )} items, weight { total_weight } " )
232223 handle = build_areas_worker .spawn (
233224 branch = branch ,
234225 version = version ,
@@ -753,7 +744,7 @@ def coordinate_publish(
753744cds = get_all_cds_from_database(db_uri)
754745states = list(STATE_CODES.values())
755746districts = [get_district_friendly_name(cd) for cd in cds]
756- print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"]}}))
747+ print(json.dumps({{"states": states, "districts": districts, "cities": ["NYC"], "cds": cds }}))
757748""" ,
758749 ],
759750 capture_output = True ,
@@ -769,6 +760,22 @@ def coordinate_publish(
769760 districts = work_info ["districts" ]
770761 cities = work_info ["cities" ]
771762
763+ from collections import Counter
764+ from policyengine_us_data .calibration .calibration_utils import STATE_CODES
765+
766+ raw_cds = work_info ["cds" ]
767+ cds_per_state = Counter (STATE_CODES .get (int (cd ) // 100 , "??" ) for cd in raw_cds )
768+
769+ CITY_WEIGHTS = {"NYC" : 11 }
770+
771+ work_items = []
772+ for s in states :
773+ work_items .append ({"type" : "state" , "id" : s , "weight" : cds_per_state .get (s , 1 )})
774+ for d in districts :
775+ work_items .append ({"type" : "district" , "id" : d , "weight" : 1 })
776+ for c in cities :
777+ work_items .append ({"type" : "city" , "id" : c , "weight" : CITY_WEIGHTS .get (c , 3 )})
778+
772779 staging_volume .reload ()
773780 completed = get_completed_from_volume (version_dir )
774781 print (f"Found { len (completed )} already-completed items on volume" )
@@ -786,32 +793,8 @@ def coordinate_publish(
786793 accumulated_validation_rows = []
787794
788795 completed , phase_errors , v_rows = run_phase (
789- "States" ,
790- states = states ,
791- districts = [],
792- cities = [],
793- completed = completed ,
794- ** phase_args ,
795- )
796- accumulated_errors .extend (phase_errors )
797- accumulated_validation_rows .extend (v_rows )
798-
799- completed , phase_errors , v_rows = run_phase (
800- "Districts" ,
801- states = [],
802- districts = districts ,
803- cities = [],
804- completed = completed ,
805- ** phase_args ,
806- )
807- accumulated_errors .extend (phase_errors )
808- accumulated_validation_rows .extend (v_rows )
809-
810- completed , phase_errors , v_rows = run_phase (
811- "Cities" ,
812- states = [],
813- districts = [],
814- cities = cities ,
796+ "All areas" ,
797+ work_items = work_items ,
815798 completed = completed ,
816799 ** phase_args ,
817800 )
0 commit comments