Skip to content

Commit 67429d1

Browse files
committed
Adaptive ASM
1 parent b710d29 commit 67429d1

1 file changed

Lines changed: 195 additions & 102 deletions

File tree

  • firedrake/preconditioners

firedrake/preconditioners/asm.py

Lines changed: 195 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22

3+
from itertools import chain
34
from pyop2.datatypes import IntType
45
from firedrake.preconditioners.base import PCBase
56
from firedrake.petsc import PETSc
@@ -140,68 +141,6 @@ def destroy(self, pc):
140141
self.asmpc.destroy()
141142

142143

143-
def get_entity_dofs(V, V_local_ises_indices, points):
144-
"""Extract degrees of freedom associated to mesh entities (points of the DMPlex)."""
145-
indices = []
146-
for (i, W) in enumerate(V):
147-
section = W.dm.getDefaultSection()
148-
for p in points:
149-
dof = section.getDof(p)
150-
if dof <= 0:
151-
continue
152-
off = section.getOffset(p)
153-
# Local indices within W
154-
W_slice = slice(off*W.block_size, W.block_size * (off + dof))
155-
indices.extend(V_local_ises_indices[i][W_slice])
156-
return indices
157-
158-
159-
def build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, prefix, seed_points):
160-
"""Build index sets for star patches."""
161-
points = []
162-
for seed in seed_points:
163-
# Only build patches over owned DoFs
164-
if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1:
165-
continue
166-
# Create point list from mesh DM
167-
star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False)
168-
star = order_points(mesh_dm, star, ordering, prefix)
169-
points.extend(star)
170-
171-
indices = get_entity_dofs(V, V_local_ises_indices, points)
172-
iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF)
173-
return iset
174-
175-
176-
def build_vanka_indices(Z, Z_local_ises_indices, mesh_dm, ordering, prefix, include_star, seed_points):
177-
"""Build index sets for Vanka patches."""
178-
V_points = []
179-
Q_points = []
180-
for seed in seed_points:
181-
# Only build patches over owned DoFs
182-
if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1:
183-
continue
184-
# Create point list from mesh DM
185-
star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False)
186-
star = order_points(mesh_dm, star, ordering, prefix)
187-
if include_star:
188-
Q_points.extend(star)
189-
else:
190-
Q_points.append(seed)
191-
192-
closure = []
193-
for s in reversed(star):
194-
cs, _ = mesh_dm.getTransitiveClosure(s, useCone=True)
195-
closure.extend(cs)
196-
# Grab unique points with stable ordering
197-
V_points.extend(reversed(dict.fromkeys(closure)))
198-
199-
indices = get_entity_dofs(Z[0], Z_local_ises_indices[0], V_points)
200-
indices.extend(get_entity_dofs(Z[1], Z_local_ises_indices[1], Q_points))
201-
iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF)
202-
return iset
203-
204-
205144
class ASMStarPC(ASMPatchPC):
206145
'''Patch-based PC using Star of mesh entities implmented as an
207146
:class:`ASMPatchPC`.
@@ -214,13 +153,13 @@ class ASMStarPC(ASMPatchPC):
214153
_prefix = "pc_star_"
215154

216155
def get_patches(self, V):
217-
mesh = V._mesh
156+
mesh = V.mesh()
218157
if len(set(mesh)) == 1:
219-
mesh_unique = mesh.unique()
158+
mesh = mesh.unique()
220159
else:
221160
raise NotImplementedError("Not implemented for general mixed meshes")
222-
mesh_dm = mesh_unique.topology_dm
223-
if mesh_unique.cell_set._extruded:
161+
mesh_dm = mesh.topology_dm
162+
if mesh.cell_set._extruded:
224163
warning("applying ASMStarPC on an extruded mesh")
225164

226165
# Obtain the topological entities to use to construct the stars
@@ -229,25 +168,31 @@ def get_patches(self, V):
229168
depth = opts.getInt("construct_dim", default=0)
230169
coloring = opts.getBool("coloring", default=False)
231170
ordering = opts.getString("mat_ordering_type", default="natural")
232-
validate_overlap(mesh_unique, depth, "star")
171+
validate_overlap(mesh, depth, "star")
233172

234173
# Accessing .indices causes the allocation of a global array,
235174
# so we need to cache these for efficiency
236175
V_local_ises_indices = tuple(iset.indices for iset in V.dof_dset.local_ises)
237176

238-
(start, end) = mesh_dm.getDepthStratum(depth)
239-
if coloring and end > start:
240-
colors = mesh_dm.createColoring(depth=depth, distance=1)
241-
colors = [color for color in colors if color.getLocalSize() > 0]
242-
if len(colors) == 0:
243-
shift = 0
244-
else:
245-
shift = start - min(color.indices.min() for color in colors)
246-
ises = [build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, prefix, color.indices+shift)
177+
point_subset = None
178+
if hasattr(mesh, "netgen_mesh"):
179+
cell_subset = get_refined_cells(mesh)
180+
point_subset = get_adjacent_stratum(mesh_dm, depth, subset=cell_subset)
181+
182+
if coloring:
183+
colors = get_point_coloring(mesh_dm, depth, 1)
184+
if point_subset is not None:
185+
colors = tuple(numpy.intersect1d(point_subset, color) for color in colors)
186+
187+
ises = [build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, prefix, color)
247188
for color in colors]
248189
else:
190+
(start, end) = mesh_dm.getDepthStratum(depth)
191+
if point_subset is None:
192+
point_subset = range(start, end)
193+
249194
ises = [build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, prefix, (seed,))
250-
for seed in range(start, end)]
195+
for seed in point_subset]
251196
return ises
252197

253198

@@ -263,13 +208,13 @@ class ASMVankaPC(ASMPatchPC):
263208
_prefix = "pc_vanka_"
264209

265210
def get_patches(self, V):
266-
mesh = V._mesh
211+
mesh = V.mesh()
267212
if len(set(mesh)) == 1:
268-
mesh_unique = mesh.unique()
213+
mesh = mesh.unique()
269214
else:
270215
raise NotImplementedError("Not implemented for general mixed meshes")
271-
mesh_dm = mesh_unique.topology_dm
272-
if mesh_unique.layers:
216+
mesh_dm = mesh.topology_dm
217+
if mesh.layers:
273218
warning("applying ASMVankaPC on an extruded mesh")
274219

275220
# Obtain the topological entities to use to construct the stars
@@ -300,29 +245,29 @@ def splitting(V):
300245
Z_local_ises_indices = splitting(V_local_ises_indices)
301246

302247
# Build index sets for the patches
303-
if depth != -1:
304-
(start, end) = mesh_dm.getDepthStratum(depth)
305-
patch_dim = depth
306-
else:
307-
(start, end) = mesh_dm.getHeightStratum(height)
308-
patch_dim = mesh_dm.getDimension() - height
309-
validate_overlap(mesh_unique, patch_dim, "vanka")
310-
311-
if start == end:
312-
ises = []
313-
elif coloring:
314-
colors = mesh_dm.createColoring(depth=patch_dim, distance=2)
315-
colors = [color for color in colors if color.getLocalSize() > 0]
316-
if len(colors) == 0:
317-
shift = start
318-
else:
319-
shift = start - min(color.indices.min() for color in colors)
248+
if depth == -1:
249+
depth = mesh_dm.getDimension() - height
250+
validate_overlap(mesh, depth, "vanka")
251+
252+
point_subset = None
253+
if hasattr(mesh, "netgen_mesh"):
254+
cell_subset = get_refined_cells(mesh)
255+
point_subset = get_adjacent_stratum(mesh_dm, depth, subset=cell_subset)
256+
257+
if coloring:
258+
colors = get_point_coloring(mesh_dm, depth, 2)
259+
if point_subset is not None:
260+
colors = tuple(numpy.intersect1d(point_subset, color) for color in colors)
261+
320262
ises = [build_vanka_indices(Z, Z_local_ises_indices, mesh_dm, ordering, prefix,
321-
include_star, color.indices+shift)
322-
for color in colors]
263+
include_star, color) for color in colors]
323264
else:
265+
(start, end) = mesh_dm.getDepthStratum(depth)
266+
if point_subset is None:
267+
point_subset = range(start, end)
268+
324269
ises = [build_vanka_indices(Z, Z_local_ises_indices, mesh_dm, ordering, prefix, include_star, (seed,))
325-
for seed in range(start, end)]
270+
for seed in point_subset]
326271
return ises
327272

328273

@@ -348,7 +293,7 @@ class ASMLinesmoothPC(ASMPatchPC):
348293
_prefix = "pc_linesmooth_"
349294

350295
def get_patches(self, V):
351-
mesh = V._mesh
296+
mesh = V.mesh()
352297
if len(set(mesh)) == 1:
353298
mesh_unique = mesh.unique()
354299
else:
@@ -593,3 +538,151 @@ def validate_overlap(mesh, patch_dim, patch_type):
593538
if overlap_depth < patch_depth:
594539
warning(f"Mesh overlap depth of {overlap_depth} does not support {patch_type}-patches. "
595540
"Did you forget to set overlap_type in your mesh's distribution_parameters?")
541+
542+
543+
def get_entity_dofs(V, V_local_ises_indices, points):
544+
"""Extract degrees of freedom associated to mesh entities (points of the DMPlex)."""
545+
indices = []
546+
for (i, W) in enumerate(V):
547+
section = W.dm.getDefaultSection()
548+
for p in points:
549+
dof = section.getDof(p)
550+
if dof <= 0:
551+
continue
552+
off = section.getOffset(p)
553+
# Local indices within W
554+
W_slice = slice(off*W.block_size, W.block_size * (off + dof))
555+
indices.extend(V_local_ises_indices[i][W_slice])
556+
return indices
557+
558+
559+
def build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, prefix, seed_points):
560+
"""Build index sets for star patches."""
561+
points = []
562+
for seed in seed_points:
563+
# Only build patches over owned DoFs
564+
if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1:
565+
continue
566+
# Create point list from mesh DM
567+
star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False)
568+
star = order_points(mesh_dm, star, ordering, prefix)
569+
points.extend(star)
570+
571+
indices = get_entity_dofs(V, V_local_ises_indices, points)
572+
iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF)
573+
return iset
574+
575+
576+
def build_vanka_indices(Z, Z_local_ises_indices, mesh_dm, ordering, prefix, include_star, seed_points):
577+
"""Build index sets for Vanka patches."""
578+
V_points = []
579+
Q_points = []
580+
for seed in seed_points:
581+
# Only build patches over owned DoFs
582+
if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1:
583+
continue
584+
# Create point list from mesh DM
585+
star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False)
586+
star = order_points(mesh_dm, star, ordering, prefix)
587+
if include_star:
588+
Q_points.extend(star)
589+
else:
590+
Q_points.append(seed)
591+
592+
closure = []
593+
for s in reversed(star):
594+
cs, _ = mesh_dm.getTransitiveClosure(s, useCone=True)
595+
closure.extend(cs)
596+
# Grab unique points with stable ordering
597+
V_points.extend(reversed(dict.fromkeys(closure)))
598+
599+
indices = get_entity_dofs(Z[0], Z_local_ises_indices[0], V_points)
600+
indices.extend(get_entity_dofs(Z[1], Z_local_ises_indices[1], Q_points))
601+
iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF)
602+
return iset
603+
604+
605+
def get_point_coloring(plex, depth, distance):
606+
"""Return the point subsets for a coloring of the plex."""
607+
start, end = plex.getDepthStratum(depth)
608+
colors = plex.createColoring(depth=depth, distance=distance)
609+
color_indices = [c.indices for c in colors]
610+
if any(ci.size > 0 for ci in color_indices):
611+
offset = start - min(ci.min() for ci in color_indices if ci.size > 0)
612+
for c in color_indices:
613+
c += offset
614+
return color_indices
615+
616+
617+
def get_refined_cells(mesh):
618+
"""Return the cell indices corresponding to fine cells."""
619+
from firedrake.mg.utils import get_level
620+
from firedrake import FunctionSpace
621+
622+
sf = mesh.sfBC_orig
623+
plex = mesh.topology_dm
624+
cellNum = plex.getCellNumbering().indices
625+
cellNum[cellNum < 0] = -cellNum[cellNum < 0]-1
626+
fstart, fend = plex.getHeightStratum(0)
627+
cids = list(map(mesh._cell_numbering.getOffset, range(fstart, fend)))
628+
629+
# Create Netgen to Firedrake reordering
630+
M = FunctionSpace(mesh, "DG", 0)
631+
marked = M.dof_dset.layout_vec.copy()
632+
cstart, cend = marked.getOwnershipRange()
633+
marked.setValues(cellNum[cids[:cend-cstart]], numpy.arange(cstart, cend))
634+
marked.assemble()
635+
marked0 = marked
636+
if sf is not None:
637+
sfBCInv = sf.createInverse()
638+
_, marked0 = plex.distributeField(sfBCInv, mesh._cell_numbering, marked)
639+
perm = marked0.getArray().astype(PETSc.IntType)
640+
641+
# Get refined cells globally on rank 0
642+
if mesh.comm.rank == 0:
643+
tdim = mesh.topological_dimension
644+
if tdim == 2:
645+
parents = mesh.netgen_mesh.parentsurfaceelements.NumPy()
646+
elif tdim == 3:
647+
parents = mesh.netgen_mesh.parentelements.NumPy()
648+
else:
649+
raise ValueError("Need a 2D or 3D mesh")
650+
mh, level = get_level(mesh)
651+
if mh is not None and level > 0:
652+
coarse_ngmesh = mh[level-1].netgen_mesh
653+
num_coarse_cells = len(coarse_ngmesh.Elements2D()) if tdim == 2 else len(coarse_ngmesh.Elements3D())
654+
else:
655+
num_coarse_cells = parents.tolist().count((-1,))
656+
children = [[] for c in range(num_coarse_cells)]
657+
num_fine_cells = parents.shape[0]
658+
for f in range(num_fine_cells):
659+
c = f
660+
while c >= num_coarse_cells:
661+
c = parents[c][0]
662+
children[c].append(f)
663+
664+
cell_subset = list(set(chain.from_iterable(f for c, f in enumerate(children) if len(f) > 1)))
665+
if sf is not None:
666+
cell_subset = perm[cell_subset].tolist()
667+
cell_subset.sort()
668+
else:
669+
cell_subset = []
670+
671+
# Get refined cells locally on this rank
672+
cell_subset = mesh.comm.bcast(cell_subset, root=0)
673+
*_, cell_subset = numpy.intersect1d(cell_subset, cellNum, return_indices=True)
674+
cell_subset += fstart
675+
return cell_subset
676+
677+
678+
def get_adjacent_stratum(plex, depth, subset=None):
679+
"""Return point stratum subset adjacent to another point subset (of different depth)."""
680+
pstart, pend = plex.getDepthStratum(depth)
681+
if subset is None:
682+
points = range(pstart, pend)
683+
else:
684+
points = set()
685+
for s in subset:
686+
points.update(p for p in plex.getAdjacency(s) if pstart <= p < pend)
687+
points = list(points)
688+
return points

0 commit comments

Comments
 (0)