@@ -2926,6 +2926,8 @@ def refine_marked_elements(self, mark, netgen_flags=None):
29262926 """
29272927 utils .check_netgen_installed ()
29282928
2929+ if not hasattr (self , "netgen_mesh" ):
2930+ raise ValueError ("Adaptive refinement requires a netgen mesh." )
29292931 if netgen_flags is None :
29302932 netgen_flags = self .netgen_flags
29312933 dim = self .geometric_dimension
@@ -2947,57 +2949,64 @@ def refine_marked_elements(self, mark, netgen_flags=None):
29472949 refine_faces = netgen_flags .get ("refine_faces" , False )
29482950 if self .comm .rank == 0 :
29492951 max_refs = int (mark_np .max ())
2950- for _ in range (max_refs ):
2951- netgen_cells = netgen_mesh .Elements3D () if dim == 3 else netgen_mesh .Elements2D ()
2952- netgen_cells .NumPy ()["refine" ][: mark_np . size ] = mark_np > 0
2952+ for r in range (max_refs ):
2953+ cells = netgen_mesh .Elements3D () if dim == 3 else netgen_mesh .Elements2D ()
2954+ cells .NumPy ()["refine" ] = ( mark_np [: len ( cells )] > 0 )
29532955 if not refine_faces and dim == 3 :
2954- netgen_mesh .Elements2D ().NumPy ()["refine" ] = 0
2956+ netgen_mesh .Elements2D ().NumPy ()["refine" ] = False
29552957 netgen_mesh .Refine (adaptive = True )
29562958 mark_np -= 1
2957-
2959+ if r < max_refs - 1 :
2960+ parents = netgen_mesh .parentelements if dim == 3 else netgen_mesh .parentsurfaceelements
2961+ parents = parents .NumPy ().astype (int ).flatten ()
2962+ cells = netgen_mesh .Elements3D () if dim == 3 else netgen_mesh .Elements2D ()
2963+ num_coarse_cells = len (cells )
2964+ num_fine_cells = parents .shape [0 ]
2965+ indices = np .arange (num_fine_cells , dtype = int )
2966+ fine_cells = indices > num_coarse_cells
2967+ indices [fine_cells ] = parents [indices [fine_cells ]]
2968+ mark_np = mark_np [indices ]
2969+
2970+ self .comm .Barrier ()
29582971 return Mesh (netgen_mesh ,
29592972 reorder = self ._did_reordering ,
29602973 distribution_parameters = self ._distribution_parameters ,
29612974 comm = self .comm ,
29622975 netgen_flags = netgen_flags )
29632976
29642977 @PETSc .Log .EventDecorator ()
2965- def curve_field (self , order , permutation_tol = 1e-8 , location_tol = 1e-1 , cg_field = False ):
2978+ def curve_field (self , order , permutation_tol = 1e-8 , cg_field = False ):
29662979 '''Return a function containing the curved coordinates of the mesh.
29672980
29682981 This method requires that the mesh has been constructed from a
29692982 netgen mesh.
29702983
29712984 :arg order: the order of the curved mesh.
29722985 :arg permutation_tol: tolerance used to construct the permutation of the reference element.
2973- :arg location_tol: tolerance used to locate the cell a point belongs to.
29742986 :arg cg_field: return a CG function field representing the mesh, rather than the
29752987 default DG field.
29762988
29772989 '''
2978- import firedrake as fd
2990+ from firedrake .netgen import find_permutation , netgen_to_plex_numbering , netgen_distribute
2991+ from firedrake .functionspace import VectorFunctionSpace , FunctionSpace
2992+ from firedrake .function import Function
29792993
29802994 utils .check_netgen_installed ()
2981-
2982- from firedrake .netgen import find_permutation
2983-
2984- netgen_mesh = self .netgen_mesh
29852995 # Check if the mesh is a surface mesh or two dimensional mesh
2986- if len ( netgen_mesh . Elements3D ()) == 0 :
2987- ng_element = netgen_mesh .Elements2D ()
2996+ if self . topological_dimension == 2 :
2997+ ng_element = self . netgen_mesh .Elements2D ()
29882998 else :
2989- ng_element = netgen_mesh .Elements3D ()
2999+ ng_element = self . netgen_mesh .Elements3D ()
29903000 ng_dimension = len (ng_element )
2991- geom_dim = self .geometric_dimension
29923001
29933002 # Construct the mesh as a Firedrake function
29943003 if cg_field :
2995- firedrake_space = fd . VectorFunctionSpace (self , "CG" , order )
3004+ coords_space = VectorFunctionSpace (self , "CG" , order )
29963005 else :
29973006 low_order_element = self .coordinates .function_space ().ufl_element ().sub_elements [0 ]
29983007 ufl_element = low_order_element .reconstruct (degree = order )
2999- firedrake_space = fd . VectorFunctionSpace (self , fd .BrokenElement (ufl_element ))
3000- new_coordinates = fd . assemble ( fd .interpolate (self .coordinates , firedrake_space ) )
3008+ coords_space = VectorFunctionSpace (self , finat . ufl .BrokenElement (ufl_element ))
3009+ new_coordinates = Function ( coords_space ) .interpolate (self .coordinates )
30013010
30023011 # Compute reference points using fiat
30033012 fiat_element = new_coordinates .function_space ().finat_element .fiat_equivalent
@@ -3012,75 +3021,50 @@ def curve_field(self, order, permutation_tol=1e-8, location_tol=1e-1, cg_field=F
30123021 ref .append (pt )
30133022 reference_space_points = np .array (ref )
30143023
3015- # Curve the mesh on rank 0 only
3016- if self .comm .rank == 0 :
3017- # Construct numpy arrays for physical domain data
3018- physical_space_points = np .zeros (
3019- (ng_dimension , reference_space_points .shape [0 ], geom_dim )
3020- )
3021- curved_space_points = np .zeros (
3022- (ng_dimension , reference_space_points .shape [0 ], geom_dim )
3023- )
3024- netgen_mesh .CalcElementMapping (reference_space_points , physical_space_points )
3025- # NOTE: This will segfault for CSG!
3026- netgen_mesh .Curve (order )
3027- netgen_mesh .CalcElementMapping (reference_space_points , curved_space_points )
3028- curved = ng_element .NumPy ()["curved" ]
3029-
3030- # Broadcast a boolean array identifying curved cells
3031- curved = self .comm .bcast (curved , root = 0 )
3032- physical_space_points = physical_space_points [curved ]
3033- curved_space_points = curved_space_points [curved ]
3034- else :
3035- curved = self .comm .bcast (None , root = 0 )
3036- # Construct numpy arrays as buffers to receive physical domain data
3037- ncurved = np .sum (curved )
3038- physical_space_points = np .zeros (
3039- (ncurved , reference_space_points .shape [0 ], geom_dim )
3040- )
3041- curved_space_points = np .zeros (
3042- (ncurved , reference_space_points .shape [0 ], geom_dim )
3043- )
3024+ # Construct numpy arrays for physical domain data
3025+ physical_space_points = np .zeros (
3026+ (ng_dimension , reference_space_points .shape [0 ], self .geometric_dimension )
3027+ )
3028+ curved_space_points = np .zeros (
3029+ (ng_dimension , reference_space_points .shape [0 ], self .geometric_dimension )
3030+ )
3031+ self .netgen_mesh .CalcElementMapping (reference_space_points , physical_space_points )
3032+ # NOTE: This will segfault for CSG!
3033+ self .netgen_mesh .Curve (order )
3034+ self .netgen_mesh .CalcElementMapping (reference_space_points , curved_space_points )
3035+ curved = ng_element .NumPy ()["curved" ]
3036+
3037+ # Get numbering
3038+ DG0 = FunctionSpace (self , "DG" , 0 )
3039+ rstart , rend = DG0 .dof_dset .layout_vec .getOwnershipRange ()
3040+ num_cells = rend - rstart
3041+ _ , iperm = netgen_to_plex_numbering (self )
3042+ iperm -= rstart
3043+
3044+ # Distribute curved cell data
3045+ own_curved = netgen_distribute (self , curved )
3046+ own_curved = np .flatnonzero (own_curved [:num_cells ])
30443047
3045- # Broadcast curved cell point data
3046- physical_space_points = self .comm .bcast (physical_space_points , root = 0 )
3047- curved_space_points = self .comm .bcast (curved_space_points , root = 0 )
30483048 cell_node_map = new_coordinates .cell_node_map ()
3049+ pyop2_index = cell_node_map .values [iperm [own_curved ]]
30493050
3050- # Select only the points in curved cells
3051- barycentres = np .average (physical_space_points , axis = 1 )
3052- ng_index = list (map (lambda x : self .locate_cell (x , tolerance = location_tol ), barycentres ))
3053-
3054- # Select only the indices of cells owned by this rank
3055- owned = [(0 <= ii < len (cell_node_map .values )) if ii is not None else False for ii in ng_index ]
3051+ # Distribute coordinate data
3052+ own_curved_points = netgen_distribute (self , curved_space_points )[own_curved ]
3053+ own_physical_points = netgen_distribute (self , physical_space_points )[own_curved ]
30563054
3057- # Get the PyOP2 indices corresponding to the netgen indices
3058- ng_index = [idx for idx , o in zip (ng_index , owned ) if o ]
3059- pyop2_index = cell_node_map .values [ng_index ].flatten ()
3060-
3061- # Select only the points owned by this rank
3062- physical_space_points = physical_space_points [owned ]
3063- curved_space_points = curved_space_points [owned ]
3064- barycentres = barycentres [owned ]
3065-
3066- if any (owned ):
3055+ if any (own_curved ):
30673056 # Find the correct coordinate permutation for each cell
30683057 permutation = find_permutation (
3069- physical_space_points ,
3070- new_coordinates .dat .data [pyop2_index ].real .reshape (
3071- physical_space_points .shape
3072- ),
3073- tol = permutation_tol
3058+ own_physical_points ,
3059+ new_coordinates .dat .data [pyop2_index ].real ,
3060+ tol = permutation_tol ,
30743061 )
3075-
30763062 # Apply the permutation to each cell in turn
3077- for ii , p in enumerate (curved_space_points ):
3078- curved_space_points [ii ] = p [permutation [ii ]]
3079- else :
3080- print ("barf" )
3063+ for ii , p in enumerate (own_curved_points ):
3064+ own_curved_points [ii ] = p [permutation [ii ]]
30813065
30823066 # Assign the curved coordinates to the dat
3083- new_coordinates .dat .data [pyop2_index ] = curved_space_points . reshape ( - 1 , geom_dim )
3067+ new_coordinates .dat .data [pyop2_index ] = own_curved_points
30843068 return new_coordinates
30853069
30863070
0 commit comments