Skip to content

Commit 7380319

Browse files
committed
Use comm.scatter for mesh distribution
1 parent 2d37bfa commit 7380319

3 files changed

Lines changed: 44 additions & 35 deletions

File tree

examples/wave/wave-min-mpi.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class WaveTag:
4949

5050
def main(ctx_factory, dim=2, order=4, visualize=False):
5151
comm = MPI.COMM_WORLD
52-
num_parts = comm.Get_size()
52+
num_parts = comm.size
5353

5454
cl_ctx = cl.create_some_context()
5555
queue = cl.CommandQueue(cl_ctx)
@@ -60,10 +60,10 @@ def main(ctx_factory, dim=2, order=4, visualize=False):
6060
force_device_scalars=True,
6161
)
6262

63-
from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
64-
mesh_dist = MPIMeshDistributor(comm)
63+
from meshmode.distributed import get_partition_by_pymetis, membership_list_to_map
64+
from meshmode.mesh.processing import partition_mesh
6565

66-
if mesh_dist.is_mananger_rank():
66+
if comm.rank == 0:
6767
from meshmode.mesh.generation import generate_regular_rect_mesh
6868
mesh = generate_regular_rect_mesh(
6969
a=(-0.5,)*dim,
@@ -72,14 +72,16 @@ def main(ctx_factory, dim=2, order=4, visualize=False):
7272

7373
logger.info("%d elements", mesh.nelements)
7474

75-
part_per_element = get_partition_by_pymetis(mesh, num_parts)
76-
77-
local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
75+
part_id_to_part = partition_mesh(mesh,
76+
membership_list_to_map(
77+
get_partition_by_pymetis(mesh, num_parts)))
78+
parts = [part_id_to_part[i] for i in range(num_parts)]
79+
local_mesh = comm.scatter(parts)
7880

7981
del mesh
8082

8183
else:
82-
local_mesh = mesh_dist.receive_mesh_part()
84+
local_mesh = comm.scatter(None)
8385

8486
dcoll = DiscretizationCollection(actx, local_mesh, order=order)
8587

examples/wave/wave-op-mpi.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def main(ctx_factory, dim=2, order=3,
184184
queue = cl.CommandQueue(cl_ctx)
185185

186186
comm = MPI.COMM_WORLD
187-
num_parts = comm.Get_size()
187+
num_parts = comm.size
188188

189189
from grudge.array_context import get_reasonable_array_context_class
190190
actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
@@ -195,12 +195,12 @@ def main(ctx_factory, dim=2, order=3,
195195
allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)),
196196
force_device_scalars=True)
197197

198-
from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
199-
mesh_dist = MPIMeshDistributor(comm)
198+
from meshmode.distributed import get_partition_by_pymetis, membership_list_to_map
199+
from meshmode.mesh.processing import partition_mesh
200200

201201
nel_1d = 16
202202

203-
if mesh_dist.is_mananger_rank():
203+
if comm.rank == 0:
204204
if use_nonaffine_mesh:
205205
from meshmode.mesh.generation import generate_warped_rect_mesh
206206
# FIXME: *generate_warped_rect_mesh* in meshmode warps a
@@ -218,14 +218,17 @@ def main(ctx_factory, dim=2, order=3,
218218

219219
logger.info("%d elements", mesh.nelements)
220220

221-
part_per_element = get_partition_by_pymetis(mesh, num_parts)
221+
part_id_to_part = partition_mesh(mesh,
222+
membership_list_to_map(
223+
get_partition_by_pymetis(mesh, num_parts)))
224+
parts = [part_id_to_part[i] for i in range(num_parts)]
222225

223-
local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
226+
local_mesh = comm.scatter(parts)
224227

225228
del mesh
226229

227230
else:
228-
local_mesh = mesh_dist.receive_mesh_part()
231+
local_mesh = comm.scatter(None)
229232

230233
from meshmode.discretization.poly_element import \
231234
QuadratureSimplexGroupFactory, \

test/test_mpi_communication.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -115,24 +115,26 @@ def _test_func_comparison_mpi_communication_entrypoint(actx):
115115

116116
comm = actx.mpi_communicator
117117

118-
from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
118+
from meshmode.distributed import (
119+
get_partition_by_pymetis, membership_list_to_map)
119120
from meshmode.mesh import BTAG_ALL
121+
from meshmode.mesh.processing import partition_mesh
120122

121-
num_parts = comm.Get_size()
123+
num_parts = comm.size
122124

123-
mesh_dist = MPIMeshDistributor(comm)
124-
125-
if mesh_dist.is_mananger_rank():
125+
if comm.rank == 0:
126126
from meshmode.mesh.generation import generate_regular_rect_mesh
127127
mesh = generate_regular_rect_mesh(a=(-1,)*2,
128128
b=(1,)*2,
129129
nelements_per_axis=(2,)*2)
130130

131-
part_per_element = get_partition_by_pymetis(mesh, num_parts)
132-
133-
local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
131+
part_id_to_part = partition_mesh(mesh,
132+
membership_list_to_map(
133+
get_partition_by_pymetis(mesh, num_parts)))
134+
parts = [part_id_to_part[i] for i in range(num_parts)]
135+
local_mesh = comm.scatter(parts)
134136
else:
135-
local_mesh = mesh_dist.receive_mesh_part()
137+
local_mesh = comm.scatter(None)
136138

137139
dcoll = DiscretizationCollection(actx, local_mesh, order=5)
138140

@@ -188,28 +190,30 @@ def test_mpi_wave_op(actx_class, num_ranks):
188190

189191
def _test_mpi_wave_op_entrypoint(actx, visualize=False):
190192
comm = actx.mpi_communicator
191-
i_local_rank = comm.Get_rank()
192-
num_parts = comm.Get_size()
193+
num_parts = comm.size
193194

194-
from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
195-
mesh_dist = MPIMeshDistributor(comm)
195+
from meshmode.distributed import (
196+
get_partition_by_pymetis, membership_list_to_map)
197+
from meshmode.mesh.processing import partition_mesh
196198

197199
dim = 2
198200
order = 4
199201

200-
if mesh_dist.is_mananger_rank():
202+
if comm.rank == 0:
201203
from meshmode.mesh.generation import generate_regular_rect_mesh
202204
mesh = generate_regular_rect_mesh(a=(-0.5,)*dim,
203205
b=(0.5,)*dim,
204206
nelements_per_axis=(16,)*dim)
205207

206-
part_per_element = get_partition_by_pymetis(mesh, num_parts)
207-
208-
local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
208+
part_id_to_part = partition_mesh(mesh,
209+
membership_list_to_map(
210+
get_partition_by_pymetis(mesh, num_parts)))
211+
parts = [part_id_to_part[i] for i in range(num_parts)]
212+
local_mesh = comm.scatter(parts)
209213

210214
del mesh
211215
else:
212-
local_mesh = mesh_dist.receive_mesh_part()
216+
local_mesh = comm.scatter(None)
213217

214218
dcoll = DiscretizationCollection(actx, local_mesh, order=order)
215219

@@ -270,7 +274,7 @@ def rhs(t, w):
270274

271275
final_t = 4
272276
nsteps = int(final_t/dt)
273-
logger.info("[%04d] dt %.5e nsteps %4d", i_local_rank, dt, nsteps)
277+
logger.info("[%04d] dt %.5e nsteps %4d", comm.rank, dt, nsteps)
274278

275279
step = 0
276280

@@ -308,7 +312,7 @@ def rhs(t, w):
308312

309313
logmgr.tick_after()
310314
logmgr.close()
311-
logger.info("Rank %d exiting", i_local_rank)
315+
logger.info("Rank %d exiting", comm.rank)
312316

313317
# }}}
314318

0 commit comments

Comments
 (0)