11import abc
22
3+ from itertools import chain
34from pyop2 .datatypes import IntType
45from firedrake .preconditioners .base import PCBase
56from firedrake .petsc import PETSc
@@ -140,68 +141,6 @@ def destroy(self, pc):
140141 self .asmpc .destroy ()
141142
142143
143- def get_entity_dofs (V , V_local_ises_indices , points ):
144- """Extract degrees of freedom associated to mesh entities (points of the DMPlex)."""
145- indices = []
146- for (i , W ) in enumerate (V ):
147- section = W .dm .getDefaultSection ()
148- for p in points :
149- dof = section .getDof (p )
150- if dof <= 0 :
151- continue
152- off = section .getOffset (p )
153- # Local indices within W
154- W_slice = slice (off * W .block_size , W .block_size * (off + dof ))
155- indices .extend (V_local_ises_indices [i ][W_slice ])
156- return indices
157-
158-
159- def build_star_indices (V , V_local_ises_indices , mesh_dm , ordering , prefix , seed_points ):
160- """Build index sets for star patches."""
161- points = []
162- for seed in seed_points :
163- # Only build patches over owned DoFs
164- if mesh_dm .getLabelValue ("pyop2_ghost" , seed ) != - 1 :
165- continue
166- # Create point list from mesh DM
167- star , _ = mesh_dm .getTransitiveClosure (seed , useCone = False )
168- star = order_points (mesh_dm , star , ordering , prefix )
169- points .extend (star )
170-
171- indices = get_entity_dofs (V , V_local_ises_indices , points )
172- iset = PETSc .IS ().createGeneral (indices , comm = PETSc .COMM_SELF )
173- return iset
174-
175-
176- def build_vanka_indices (Z , Z_local_ises_indices , mesh_dm , ordering , prefix , include_star , seed_points ):
177- """Build index sets for Vanka patches."""
178- V_points = []
179- Q_points = []
180- for seed in seed_points :
181- # Only build patches over owned DoFs
182- if mesh_dm .getLabelValue ("pyop2_ghost" , seed ) != - 1 :
183- continue
184- # Create point list from mesh DM
185- star , _ = mesh_dm .getTransitiveClosure (seed , useCone = False )
186- star = order_points (mesh_dm , star , ordering , prefix )
187- if include_star :
188- Q_points .extend (star )
189- else :
190- Q_points .append (seed )
191-
192- closure = []
193- for s in reversed (star ):
194- cs , _ = mesh_dm .getTransitiveClosure (s , useCone = True )
195- closure .extend (cs )
196- # Grab unique points with stable ordering
197- V_points .extend (reversed (dict .fromkeys (closure )))
198-
199- indices = get_entity_dofs (Z [0 ], Z_local_ises_indices [0 ], V_points )
200- indices .extend (get_entity_dofs (Z [1 ], Z_local_ises_indices [1 ], Q_points ))
201- iset = PETSc .IS ().createGeneral (indices , comm = PETSc .COMM_SELF )
202- return iset
203-
204-
205144class ASMStarPC (ASMPatchPC ):
206145 '''Patch-based PC using Star of mesh entities implmented as an
207146 :class:`ASMPatchPC`.
@@ -214,13 +153,13 @@ class ASMStarPC(ASMPatchPC):
214153 _prefix = "pc_star_"
215154
216155 def get_patches (self , V ):
217- mesh = V ._mesh
156+ mesh = V .mesh ()
218157 if len (set (mesh )) == 1 :
219- mesh_unique = mesh .unique ()
158+ mesh = mesh .unique ()
220159 else :
221160 raise NotImplementedError ("Not implemented for general mixed meshes" )
222- mesh_dm = mesh_unique .topology_dm
223- if mesh_unique .cell_set ._extruded :
161+ mesh_dm = mesh .topology_dm
162+ if mesh .cell_set ._extruded :
224163 warning ("applying ASMStarPC on an extruded mesh" )
225164
226165 # Obtain the topological entities to use to construct the stars
@@ -229,25 +168,31 @@ def get_patches(self, V):
229168 depth = opts .getInt ("construct_dim" , default = 0 )
230169 coloring = opts .getBool ("coloring" , default = False )
231170 ordering = opts .getString ("mat_ordering_type" , default = "natural" )
232- validate_overlap (mesh_unique , depth , "star" )
171+ validate_overlap (mesh , depth , "star" )
233172
234173 # Accessing .indices causes the allocation of a global array,
235174 # so we need to cache these for efficiency
236175 V_local_ises_indices = tuple (iset .indices for iset in V .dof_dset .local_ises )
237176
238- (start , end ) = mesh_dm .getDepthStratum (depth )
239- if coloring and end > start :
240- colors = mesh_dm .createColoring (depth = depth , distance = 1 )
241- colors = [color for color in colors if color .getLocalSize () > 0 ]
242- if len (colors ) == 0 :
243- shift = 0
244- else :
245- shift = start - min (color .indices .min () for color in colors )
246- ises = [build_star_indices (V , V_local_ises_indices , mesh_dm , ordering , prefix , color .indices + shift )
177+ point_subset = None
178+ if hasattr (mesh , "netgen_mesh" ):
179+ cell_subset = get_refined_cells (mesh )
180+ point_subset = get_adjacent_stratum (mesh_dm , depth , subset = cell_subset )
181+
182+ if coloring :
183+ colors = get_point_coloring (mesh_dm , depth , 1 )
184+ if point_subset is not None :
185+ colors = tuple (numpy .intersect1d (point_subset , color ) for color in colors )
186+
187+ ises = [build_star_indices (V , V_local_ises_indices , mesh_dm , ordering , prefix , color )
247188 for color in colors ]
248189 else :
190+ (start , end ) = mesh_dm .getDepthStratum (depth )
191+ if point_subset is None :
192+ point_subset = range (start , end )
193+
249194 ises = [build_star_indices (V , V_local_ises_indices , mesh_dm , ordering , prefix , (seed ,))
250- for seed in range ( start , end ) ]
195+ for seed in point_subset ]
251196 return ises
252197
253198
@@ -263,13 +208,13 @@ class ASMVankaPC(ASMPatchPC):
263208 _prefix = "pc_vanka_"
264209
265210 def get_patches (self , V ):
266- mesh = V ._mesh
211+ mesh = V .mesh ()
267212 if len (set (mesh )) == 1 :
268- mesh_unique = mesh .unique ()
213+ mesh = mesh .unique ()
269214 else :
270215 raise NotImplementedError ("Not implemented for general mixed meshes" )
271- mesh_dm = mesh_unique .topology_dm
272- if mesh_unique .layers :
216+ mesh_dm = mesh .topology_dm
217+ if mesh .layers :
273218 warning ("applying ASMVankaPC on an extruded mesh" )
274219
275220 # Obtain the topological entities to use to construct the stars
@@ -300,29 +245,29 @@ def splitting(V):
300245 Z_local_ises_indices = splitting (V_local_ises_indices )
301246
302247 # Build index sets for the patches
303- if depth != - 1 :
304- (start , end ) = mesh_dm .getDepthStratum (depth )
305- patch_dim = depth
306- else :
307- (start , end ) = mesh_dm .getHeightStratum (height )
308- patch_dim = mesh_dm .getDimension () - height
309- validate_overlap (mesh_unique , patch_dim , "vanka" )
310-
311- if start == end :
312- ises = []
313- elif coloring :
314- colors = mesh_dm .createColoring (depth = patch_dim , distance = 2 )
315- colors = [color for color in colors if color .getLocalSize () > 0 ]
316- if len (colors ) == 0 :
317- shift = start
318- else :
319- shift = start - min (color .indices .min () for color in colors )
248+ if depth == - 1 :
249+ depth = mesh_dm .getDimension () - height
250+ validate_overlap (mesh , depth , "vanka" )
251+
252+ point_subset = None
253+ if hasattr (mesh , "netgen_mesh" ):
254+ cell_subset = get_refined_cells (mesh )
255+ point_subset = get_adjacent_stratum (mesh_dm , depth , subset = cell_subset )
256+
257+ if coloring :
258+ colors = get_point_coloring (mesh_dm , depth , 2 )
259+ if point_subset is not None :
260+ colors = tuple (numpy .intersect1d (point_subset , color ) for color in colors )
261+
320262 ises = [build_vanka_indices (Z , Z_local_ises_indices , mesh_dm , ordering , prefix ,
321- include_star , color .indices + shift )
322- for color in colors ]
263+ include_star , color ) for color in colors ]
323264 else :
265+ (start , end ) = mesh_dm .getDepthStratum (depth )
266+ if point_subset is None :
267+ point_subset = range (start , end )
268+
324269 ises = [build_vanka_indices (Z , Z_local_ises_indices , mesh_dm , ordering , prefix , include_star , (seed ,))
325- for seed in range ( start , end ) ]
270+ for seed in point_subset ]
326271 return ises
327272
328273
@@ -348,7 +293,7 @@ class ASMLinesmoothPC(ASMPatchPC):
348293 _prefix = "pc_linesmooth_"
349294
350295 def get_patches (self , V ):
351- mesh = V ._mesh
296+ mesh = V .mesh ()
352297 if len (set (mesh )) == 1 :
353298 mesh_unique = mesh .unique ()
354299 else :
@@ -593,3 +538,151 @@ def validate_overlap(mesh, patch_dim, patch_type):
593538 if overlap_depth < patch_depth :
594539 warning (f"Mesh overlap depth of { overlap_depth } does not support { patch_type } -patches. "
595540 "Did you forget to set overlap_type in your mesh's distribution_parameters?" )
541+
542+
543+ def get_entity_dofs (V , V_local_ises_indices , points ):
544+ """Extract degrees of freedom associated to mesh entities (points of the DMPlex)."""
545+ indices = []
546+ for (i , W ) in enumerate (V ):
547+ section = W .dm .getDefaultSection ()
548+ for p in points :
549+ dof = section .getDof (p )
550+ if dof <= 0 :
551+ continue
552+ off = section .getOffset (p )
553+ # Local indices within W
554+ W_slice = slice (off * W .block_size , W .block_size * (off + dof ))
555+ indices .extend (V_local_ises_indices [i ][W_slice ])
556+ return indices
557+
558+
559+ def build_star_indices (V , V_local_ises_indices , mesh_dm , ordering , prefix , seed_points ):
560+ """Build index sets for star patches."""
561+ points = []
562+ for seed in seed_points :
563+ # Only build patches over owned DoFs
564+ if mesh_dm .getLabelValue ("pyop2_ghost" , seed ) != - 1 :
565+ continue
566+ # Create point list from mesh DM
567+ star , _ = mesh_dm .getTransitiveClosure (seed , useCone = False )
568+ star = order_points (mesh_dm , star , ordering , prefix )
569+ points .extend (star )
570+
571+ indices = get_entity_dofs (V , V_local_ises_indices , points )
572+ iset = PETSc .IS ().createGeneral (indices , comm = PETSc .COMM_SELF )
573+ return iset
574+
575+
576+ def build_vanka_indices (Z , Z_local_ises_indices , mesh_dm , ordering , prefix , include_star , seed_points ):
577+ """Build index sets for Vanka patches."""
578+ V_points = []
579+ Q_points = []
580+ for seed in seed_points :
581+ # Only build patches over owned DoFs
582+ if mesh_dm .getLabelValue ("pyop2_ghost" , seed ) != - 1 :
583+ continue
584+ # Create point list from mesh DM
585+ star , _ = mesh_dm .getTransitiveClosure (seed , useCone = False )
586+ star = order_points (mesh_dm , star , ordering , prefix )
587+ if include_star :
588+ Q_points .extend (star )
589+ else :
590+ Q_points .append (seed )
591+
592+ closure = []
593+ for s in reversed (star ):
594+ cs , _ = mesh_dm .getTransitiveClosure (s , useCone = True )
595+ closure .extend (cs )
596+ # Grab unique points with stable ordering
597+ V_points .extend (reversed (dict .fromkeys (closure )))
598+
599+ indices = get_entity_dofs (Z [0 ], Z_local_ises_indices [0 ], V_points )
600+ indices .extend (get_entity_dofs (Z [1 ], Z_local_ises_indices [1 ], Q_points ))
601+ iset = PETSc .IS ().createGeneral (indices , comm = PETSc .COMM_SELF )
602+ return iset
603+
604+
605+ def get_point_coloring (plex , depth , distance ):
606+ """Return the point subsets for a coloring of the plex."""
607+ start , end = plex .getDepthStratum (depth )
608+ colors = plex .createColoring (depth = depth , distance = distance )
609+ color_indices = [c .indices for c in colors ]
610+ if any (ci .size > 0 for ci in color_indices ):
611+ offset = start - min (ci .min () for ci in color_indices if ci .size > 0 )
612+ for c in color_indices :
613+ c += offset
614+ return color_indices
615+
616+
617+ def get_refined_cells (mesh ):
618+ """Return the cell indices corresponding to fine cells."""
619+ from firedrake .mg .utils import get_level
620+ from firedrake import FunctionSpace
621+
622+ sf = mesh .sfBC_orig
623+ plex = mesh .topology_dm
624+ cellNum = plex .getCellNumbering ().indices
625+ cellNum [cellNum < 0 ] = - cellNum [cellNum < 0 ]- 1
626+ fstart , fend = plex .getHeightStratum (0 )
627+ cids = list (map (mesh ._cell_numbering .getOffset , range (fstart , fend )))
628+
629+ # Create Netgen to Firedrake reordering
630+ M = FunctionSpace (mesh , "DG" , 0 )
631+ marked = M .dof_dset .layout_vec .copy ()
632+ cstart , cend = marked .getOwnershipRange ()
633+ marked .setValues (cellNum [cids [:cend - cstart ]], numpy .arange (cstart , cend ))
634+ marked .assemble ()
635+ marked0 = marked
636+ if sf is not None :
637+ sfBCInv = sf .createInverse ()
638+ _ , marked0 = plex .distributeField (sfBCInv , mesh ._cell_numbering , marked )
639+ perm = marked0 .getArray ().astype (PETSc .IntType )
640+
641+ # Get refined cells globally on rank 0
642+ if mesh .comm .rank == 0 :
643+ tdim = mesh .topological_dimension
644+ if tdim == 2 :
645+ parents = mesh .netgen_mesh .parentsurfaceelements .NumPy ()
646+ elif tdim == 3 :
647+ parents = mesh .netgen_mesh .parentelements .NumPy ()
648+ else :
649+ raise ValueError ("Need a 2D or 3D mesh" )
650+ mh , level = get_level (mesh )
651+ if mh is not None and level > 0 :
652+ coarse_ngmesh = mh [level - 1 ].netgen_mesh
653+ num_coarse_cells = len (coarse_ngmesh .Elements2D ()) if tdim == 2 else len (coarse_ngmesh .Elements3D ())
654+ else :
655+ num_coarse_cells = parents .tolist ().count ((- 1 ,))
656+ children = [[] for c in range (num_coarse_cells )]
657+ num_fine_cells = parents .shape [0 ]
658+ for f in range (num_fine_cells ):
659+ c = f
660+ while c >= num_coarse_cells :
661+ c = parents [c ][0 ]
662+ children [c ].append (f )
663+
664+ cell_subset = list (set (chain .from_iterable (f for c , f in enumerate (children ) if len (f ) > 1 )))
665+ if sf is not None :
666+ cell_subset = perm [cell_subset ].tolist ()
667+ cell_subset .sort ()
668+ else :
669+ cell_subset = []
670+
671+ # Get refined cells locally on this rank
672+ cell_subset = mesh .comm .bcast (cell_subset , root = 0 )
673+ * _ , cell_subset = numpy .intersect1d (cell_subset , cellNum , return_indices = True )
674+ cell_subset += fstart
675+ return cell_subset
676+
677+
678+ def get_adjacent_stratum (plex , depth , subset = None ):
679+ """Return point stratum subset adjacent to another point subset (of different depth)."""
680+ pstart , pend = plex .getDepthStratum (depth )
681+ if subset is None :
682+ points = range (pstart , pend )
683+ else :
684+ points = set ()
685+ for s in subset :
686+ points .update (p for p in plex .getAdjacency (s ) if pstart <= p < pend )
687+ points = list (points )
688+ return points
0 commit comments