@@ -115,24 +115,26 @@ def _test_func_comparison_mpi_communication_entrypoint(actx):
115115
116116 comm = actx .mpi_communicator
117117
118- from meshmode .distributed import MPIMeshDistributor , get_partition_by_pymetis
118+ from meshmode .distributed import (
119+ get_partition_by_pymetis , membership_list_to_map )
119120 from meshmode .mesh import BTAG_ALL
121+ from meshmode .mesh .processing import partition_mesh
120122
121- num_parts = comm .Get_size ()
123+ num_parts = comm .size
122124
123- mesh_dist = MPIMeshDistributor (comm )
124-
125- if mesh_dist .is_mananger_rank ():
125+ if comm .rank == 0 :
126126 from meshmode .mesh .generation import generate_regular_rect_mesh
127127 mesh = generate_regular_rect_mesh (a = (- 1 ,)* 2 ,
128128 b = (1 ,)* 2 ,
129129 nelements_per_axis = (2 ,)* 2 )
130130
131- part_per_element = get_partition_by_pymetis (mesh , num_parts )
132-
133- local_mesh = mesh_dist .send_mesh_parts (mesh , part_per_element , num_parts )
131+ part_id_to_part = partition_mesh (mesh ,
132+ membership_list_to_map (
133+ get_partition_by_pymetis (mesh , num_parts )))
134+ parts = [part_id_to_part [i ] for i in range (num_parts )]
135+ local_mesh = comm .scatter (parts )
134136 else :
135- local_mesh = mesh_dist . receive_mesh_part ( )
137+ local_mesh = comm . scatter ( None )
136138
137139 dcoll = DiscretizationCollection (actx , local_mesh , order = 5 )
138140
@@ -188,28 +190,30 @@ def test_mpi_wave_op(actx_class, num_ranks):
188190
189191def _test_mpi_wave_op_entrypoint (actx , visualize = False ):
190192 comm = actx .mpi_communicator
191- i_local_rank = comm .Get_rank ()
192- num_parts = comm .Get_size ()
193+ num_parts = comm .size
193194
194- from meshmode .distributed import MPIMeshDistributor , get_partition_by_pymetis
195- mesh_dist = MPIMeshDistributor (comm )
195+ from meshmode .distributed import (
196+ get_partition_by_pymetis , membership_list_to_map )
197+ from meshmode .mesh .processing import partition_mesh
196198
197199 dim = 2
198200 order = 4
199201
200- if mesh_dist . is_mananger_rank () :
202+ if comm . rank == 0 :
201203 from meshmode .mesh .generation import generate_regular_rect_mesh
202204 mesh = generate_regular_rect_mesh (a = (- 0.5 ,)* dim ,
203205 b = (0.5 ,)* dim ,
204206 nelements_per_axis = (16 ,)* dim )
205207
206- part_per_element = get_partition_by_pymetis (mesh , num_parts )
207-
208- local_mesh = mesh_dist .send_mesh_parts (mesh , part_per_element , num_parts )
208+ part_id_to_part = partition_mesh (mesh ,
209+ membership_list_to_map (
210+ get_partition_by_pymetis (mesh , num_parts )))
211+ parts = [part_id_to_part [i ] for i in range (num_parts )]
212+ local_mesh = comm .scatter (parts )
209213
210214 del mesh
211215 else :
212- local_mesh = mesh_dist . receive_mesh_part ( )
216+ local_mesh = comm . scatter ( None )
213217
214218 dcoll = DiscretizationCollection (actx , local_mesh , order = order )
215219
@@ -270,7 +274,7 @@ def rhs(t, w):
270274
271275 final_t = 4
272276 nsteps = int (final_t / dt )
273- logger .info ("[%04d] dt %.5e nsteps %4d" , i_local_rank , dt , nsteps )
277+ logger .info ("[%04d] dt %.5e nsteps %4d" , comm . rank , dt , nsteps )
274278
275279 step = 0
276280
@@ -308,7 +312,7 @@ def rhs(t, w):
308312
309313 logmgr .tick_after ()
310314 logmgr .close ()
311- logger .info ("Rank %d exiting" , i_local_rank )
315+ logger .info ("Rank %d exiting" , comm . rank )
312316
313317# }}}
314318
0 commit comments