Commit f85dece

Deprecate MPIMeshDistributor
1 parent d35aa57 commit f85dece

3 files changed: 35 additions & 30 deletions
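In short, the deprecation message's recipe replaces the distributor object with an explicit partition-and-scatter, as below (a minimal sketch following the example diff in this commit; `make_my_mesh` is a hypothetical stand-in for whatever builds the mesh on the root rank):

import numpy as np
from mpi4py import MPI

from meshmode.distributed import membership_list_to_map
from meshmode.mesh.processing import partition_mesh

comm = MPI.COMM_WORLD

if comm.rank == 0:
    mesh = make_my_mesh()  # hypothetical: any meshmode Mesh built on rank 0
    rng = np.random.default_rng()
    # Assign each element a part number, split the mesh into per-part meshes,
    # and hand one part to each rank.
    part_id_to_part = partition_mesh(mesh,
            membership_list_to_map(
                rng.integers(comm.size, size=mesh.nelements)))
    local_mesh = comm.scatter([part_id_to_part[i] for i in range(comm.size)])
else:
    local_mesh = comm.scatter(None)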

examples/parallel-vtkhdf.py

Lines changed: 18 additions & 16 deletions
@@ -32,46 +32,48 @@ def main(*, ambient_dim: int) -> None:
 
     from mpi4py import MPI
     comm = MPI.COMM_WORLD
-    mpisize = comm.Get_size()
-    mpirank = comm.Get_rank()
 
-    from meshmode.distributed import MPIMeshDistributor
-    dist = MPIMeshDistributor(comm)
+    from meshmode.mesh.processing import partition_mesh
+    from meshmode.distributed import membership_list_to_map
 
     order = 5
     nelements = 64 if ambient_dim == 3 else 256
 
-    logger.info("[%4d] distributing mesh: started", mpirank)
+    logger.info("[%4d] distributing mesh: started", comm.rank)
 
-    if dist.is_mananger_rank():
+    if comm.rank == 0:
         mesh = make_example_mesh(ambient_dim, nelements, order=order)
         logger.info("[%4d] mesh: nelements %d nvertices %d",
-                    mpirank, mesh.nelements, mesh.nvertices)
+                    comm.rank, mesh.nelements, mesh.nvertices)
 
         rng = np.random.default_rng()
-        part_per_element = rng.integers(mpisize, size=mesh.nelements)
 
-        local_mesh = dist.send_mesh_parts(mesh, part_per_element, mpisize)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    rng.integers(comm.size, size=mesh.nelements)))
+        parts = [part_id_to_part[i] for i in range(comm.size)]
+        local_mesh = comm.scatter(parts)
     else:
-        local_mesh = dist.receive_mesh_part()
+        # Reason for type-ignore: presumed faulty type annotation in mpi4py
+        local_mesh = comm.scatter(None)  # type: ignore[arg-type]
 
-    logger.info("[%4d] distributing mesh: finished", mpirank)
+    logger.info("[%4d] distributing mesh: finished", comm.rank)
 
     from meshmode.discretization import Discretization
     from meshmode.discretization.poly_element import default_simplex_group_factory
     discr = Discretization(actx, local_mesh,
             default_simplex_group_factory(local_mesh.dim, order=order))
 
-    logger.info("[%4d] discretization: finished", mpirank)
+    logger.info("[%4d] discretization: finished", comm.rank)
 
     vector_field = actx.thaw(discr.nodes())
     scalar_field = actx.np.sin(vector_field[0])
-    part_id = 1.0 + mpirank + discr.zeros(actx)  # type: ignore[operator]
-    logger.info("[%4d] fields: finished", mpirank)
+    part_id = 1.0 + comm.rank + discr.zeros(actx)  # type: ignore[operator]
+    logger.info("[%4d] fields: finished", comm.rank)
 
     from meshmode.discretization.visualization import make_visualizer
     vis = make_visualizer(actx, discr, vis_order=order, force_equidistant=False)
-    logger.info("[%4d] make_visualizer: finished", mpirank)
+    logger.info("[%4d] make_visualizer: finished", comm.rank)
 
     filename = f"parallel-vtkhdf-example-{ambient_dim}d.hdf"
     vis.write_vtkhdf_file(filename, [
@@ -80,7 +82,7 @@ def main(*, ambient_dim: int) -> None:
         ("part_id", part_id)
         ], comm=comm, overwrite=True, use_high_order=False)
 
-    logger.info("[%4d] write: finished: %s", mpirank, filename)
+    logger.info("[%4d] write: finished: %s", comm.rank, filename)
 
 
 if __name__ == "__main__":
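For orientation, membership_list_to_map turns the per-element array of part numbers into the part-number-to-element-indices mapping that partition_mesh consumes. A numpy-only illustration of that contract (not the library's implementation, just the shape of the data as used above):

import numpy as np

# Part number for each of five elements...
membership = np.array([0, 1, 0, 2, 1])
# ...grouped into {part number: indices of its elements}.
part_to_elements = {
    part: np.where(membership == part)[0]
    for part in np.unique(membership)
}
# -> {0: array([0, 2]), 1: array([1, 4]), 2: array([3])}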

meshmode/distributed.py

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,4 @@
 """
-.. autoclass:: MPIMeshDistributor
 .. autoclass:: InterRankBoundaryInfo
 .. autoclass:: MPIBoundaryCommSetupHelper
 
@@ -85,6 +84,10 @@ def __init__(self, mpi_comm, manager_rank=0):
         self.mpi_comm = mpi_comm
         self.manager_rank = manager_rank
 
+        warn("MPIMeshDistributor is deprecated and will be removed in 2024. "
+             "Directly call partition_mesh and use mpi_comm.scatter instead.",
+             DeprecationWarning, stacklevel=2)
+
     def is_mananger_rank(self):
         return self.mpi_comm.Get_rank() == self.manager_rank
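A note on stacklevel=2 in the new warning: it makes Python report the DeprecationWarning at the line that constructs MPIMeshDistributor, not at the warn() call inside __init__. A standard-library sketch of the same pattern:

import warnings

def deprecated_init():
    # stacklevel=2 attributes the warning to the caller's line, not this one.
    warnings.warn("deprecated; call partition_mesh and use comm.scatter",
                  DeprecationWarning, stacklevel=2)

deprecated_init()  # the warning is reported at this call site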

test/test_partition.py

Lines changed: 13 additions & 13 deletions
@@ -368,16 +368,14 @@ def count_tags(mesh, tag):
 # {{{ MPI test boundary swap
 
 def _test_mpi_boundary_swap(dim, order, num_groups):
-    from meshmode.distributed import MPIMeshDistributor, MPIBoundaryCommSetupHelper
+    from meshmode.distributed import (MPIBoundaryCommSetupHelper,
+                                      membership_list_to_map)
+    from meshmode.mesh.processing import partition_mesh
 
     from mpi4py import MPI
     mpi_comm = MPI.COMM_WORLD
-    i_local_part = mpi_comm.Get_rank()
-    num_parts = mpi_comm.Get_size()
 
-    mesh_dist = MPIMeshDistributor(mpi_comm)
-
-    if mesh_dist.is_mananger_rank():
+    if mpi_comm.rank == 0:
         np.random.seed(42)
         from meshmode.mesh.generation import generate_warped_rect_mesh
         meshes = [generate_warped_rect_mesh(dim, order=order, nelements_side=4)
@@ -389,11 +387,14 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
         else:
             mesh = meshes[0]
 
-        part_per_element = np.random.randint(num_parts, size=mesh.nelements)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    np.random.randint(mpi_comm.size, size=mesh.nelements)))
+        parts = [part_id_to_part[i] for i in range(mpi_comm.size)]
 
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        local_mesh = mpi_comm.scatter(parts)
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = mpi_comm.scatter(None)
 
     group_factory = default_simplex_group_factory(base_dim=dim, order=order)
 
@@ -436,14 +437,13 @@ def _test_mpi_boundary_swap(dim, order, num_groups):
             remote_to_local_bdry_conns,
             connected_parts)
 
-    logger.debug("Rank %d exiting", i_local_part)
+    logger.debug("Rank %d exiting", mpi_comm.rank)
 
 
 def _test_connected_parts(mpi_comm, connected_parts):
     num_parts = mpi_comm.Get_size()
-    i_local_part = mpi_comm.Get_rank()
 
-    assert i_local_part not in connected_parts
+    assert mpi_comm.rank not in connected_parts
 
     # Get the full adjacency
     connected_mask = np.empty(num_parts, dtype=bool)
@@ -456,7 +456,7 @@ def _test_connected_parts(mpi_comm, connected_parts):
     # make sure it agrees with connected_parts
     parts_connected_to_me = set()
     for i_remote_part in range(num_parts):
-        if all_connected_masks[i_remote_part][i_local_part]:
+        if all_connected_masks[i_remote_part][mpi_comm.rank]:
             parts_connected_to_me.add(i_remote_part)
     assert parts_connected_to_me == connected_parts
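Incidentally, the test changes also swap the Get_rank()/Get_size() locals for mpi4py's comm.rank/comm.size properties; the two spellings are equivalent accessors, as this quick check shows:

from mpi4py import MPI

comm = MPI.COMM_WORLD
# .rank and .size are property shorthands for the MPI-style getters.
assert comm.rank == comm.Get_rank()
assert comm.size == comm.Get_size()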