@@ -32,46 +32,48 @@ def main(*, ambient_dim: int) -> None:
3232
3333 from mpi4py import MPI
3434 comm = MPI .COMM_WORLD
35- mpisize = comm .Get_size ()
36- mpirank = comm .Get_rank ()
3735
38- from meshmode .distributed import MPIMeshDistributor
39- dist = MPIMeshDistributor ( comm )
36+ from meshmode .mesh . processing import partition_mesh
37+ from meshmode . distributed import membership_list_to_map
4038
4139 order = 5
4240 nelements = 64 if ambient_dim == 3 else 256
4341
44- logger .info ("[%4d] distributing mesh: started" , mpirank )
42+ logger .info ("[%4d] distributing mesh: started" , comm . rank )
4543
46- if dist . is_mananger_rank () :
44+ if comm . rank == 0 :
4745 mesh = make_example_mesh (ambient_dim , nelements , order = order )
4846 logger .info ("[%4d] mesh: nelements %d nvertices %d" ,
49- mpirank , mesh .nelements , mesh .nvertices )
47+ comm . rank , mesh .nelements , mesh .nvertices )
5048
5149 rng = np .random .default_rng ()
52- part_per_element = rng .integers (mpisize , size = mesh .nelements )
5350
54- local_mesh = dist .send_mesh_parts (mesh , part_per_element , mpisize )
51+ part_id_to_part = partition_mesh (mesh ,
52+ membership_list_to_map (
53+ rng .integers (comm .size , size = mesh .nelements )))
54+ parts = [part_id_to_part [i ] for i in range (comm .size )]
55+ local_mesh = comm .scatter (parts )
5556 else :
56- local_mesh = dist .receive_mesh_part ()
57+ # Reason for type-ignore: presumed faulty type annotation in mpi4py
58+ local_mesh = comm .scatter (None ) # type: ignore[arg-type]
5759
58- logger .info ("[%4d] distributing mesh: finished" , mpirank )
60+ logger .info ("[%4d] distributing mesh: finished" , comm . rank )
5961
6062 from meshmode .discretization import Discretization
6163 from meshmode .discretization .poly_element import default_simplex_group_factory
6264 discr = Discretization (actx , local_mesh ,
6365 default_simplex_group_factory (local_mesh .dim , order = order ))
6466
65- logger .info ("[%4d] discretization: finished" , mpirank )
67+ logger .info ("[%4d] discretization: finished" , comm . rank )
6668
6769 vector_field = actx .thaw (discr .nodes ())
6870 scalar_field = actx .np .sin (vector_field [0 ])
69- part_id = 1.0 + mpirank + discr .zeros (actx ) # type: ignore[operator]
70- logger .info ("[%4d] fields: finished" , mpirank )
71+ part_id = 1.0 + comm . rank + discr .zeros (actx ) # type: ignore[operator]
72+ logger .info ("[%4d] fields: finished" , comm . rank )
7173
7274 from meshmode .discretization .visualization import make_visualizer
7375 vis = make_visualizer (actx , discr , vis_order = order , force_equidistant = False )
74- logger .info ("[%4d] make_visualizer: finished" , mpirank )
76+ logger .info ("[%4d] make_visualizer: finished" , comm . rank )
7577
7678 filename = f"parallel-vtkhdf-example-{ ambient_dim } d.hdf"
7779 vis .write_vtkhdf_file (filename , [
@@ -80,7 +82,7 @@ def main(*, ambient_dim: int) -> None:
8082 ("part_id" , part_id )
8183 ], comm = comm , overwrite = True , use_high_order = False )
8284
83- logger .info ("[%4d] write: finished: %s" , mpirank , filename )
85+ logger .info ("[%4d] write: finished: %s" , comm . rank , filename )
8486
8587
8688if __name__ == "__main__" :
0 commit comments