Skip to content

Commit e877b90

Browse files
committed
Fixes
1 parent 574be15 commit e877b90

4 files changed

Lines changed: 134 additions & 184 deletions

File tree

firedrake/mesh.py

Lines changed: 62 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,6 +2926,8 @@ def refine_marked_elements(self, mark, netgen_flags=None):
29262926
"""
29272927
utils.check_netgen_installed()
29282928

2929+
if not hasattr(self, "netgen_mesh"):
2930+
raise ValueError("Adaptive refinement requires a netgen mesh.")
29292931
if netgen_flags is None:
29302932
netgen_flags = self.netgen_flags
29312933
dim = self.geometric_dimension
@@ -2947,57 +2949,64 @@ def refine_marked_elements(self, mark, netgen_flags=None):
29472949
refine_faces = netgen_flags.get("refine_faces", False)
29482950
if self.comm.rank == 0:
29492951
max_refs = int(mark_np.max())
2950-
for _ in range(max_refs):
2951-
netgen_cells = netgen_mesh.Elements3D() if dim == 3 else netgen_mesh.Elements2D()
2952-
netgen_cells.NumPy()["refine"][:mark_np.size] = mark_np > 0
2952+
for r in range(max_refs):
2953+
cells = netgen_mesh.Elements3D() if dim == 3 else netgen_mesh.Elements2D()
2954+
cells.NumPy()["refine"] = (mark_np[:len(cells)] > 0)
29532955
if not refine_faces and dim == 3:
2954-
netgen_mesh.Elements2D().NumPy()["refine"] = 0
2956+
netgen_mesh.Elements2D().NumPy()["refine"] = False
29552957
netgen_mesh.Refine(adaptive=True)
29562958
mark_np -= 1
2957-
2959+
if r < max_refs - 1:
2960+
parents = netgen_mesh.parentelements if dim == 3 else netgen_mesh.parentsurfaceelements
2961+
parents = parents.NumPy().astype(int).flatten()
2962+
cells = netgen_mesh.Elements3D() if dim == 3 else netgen_mesh.Elements2D()
2963+
num_coarse_cells = len(cells)
2964+
num_fine_cells = parents.shape[0]
2965+
indices = np.arange(num_fine_cells, dtype=int)
2966+
fine_cells = indices > num_coarse_cells
2967+
indices[fine_cells] = parents[indices[fine_cells]]
2968+
mark_np = mark_np[indices]
2969+
2970+
self.comm.Barrier()
29582971
return Mesh(netgen_mesh,
29592972
reorder=self._did_reordering,
29602973
distribution_parameters=self._distribution_parameters,
29612974
comm=self.comm,
29622975
netgen_flags=netgen_flags)
29632976

29642977
@PETSc.Log.EventDecorator()
2965-
def curve_field(self, order, permutation_tol=1e-8, location_tol=1e-1, cg_field=False):
2978+
def curve_field(self, order, permutation_tol=1e-8, cg_field=False):
29662979
'''Return a function containing the curved coordinates of the mesh.
29672980
29682981
This method requires that the mesh has been constructed from a
29692982
netgen mesh.
29702983
29712984
:arg order: the order of the curved mesh.
29722985
:arg permutation_tol: tolerance used to construct the permutation of the reference element.
2973-
:arg location_tol: tolerance used to locate the cell a point belongs to.
29742986
:arg cg_field: return a CG function field representing the mesh, rather than the
29752987
default DG field.
29762988
29772989
'''
2978-
import firedrake as fd
2990+
from firedrake.netgen import find_permutation, netgen_to_plex_numbering, netgen_distribute
2991+
from firedrake.functionspace import VectorFunctionSpace, FunctionSpace
2992+
from firedrake.function import Function
29792993

29802994
utils.check_netgen_installed()
2981-
2982-
from firedrake.netgen import find_permutation
2983-
2984-
netgen_mesh = self.netgen_mesh
29852995
# Check if the mesh is a surface mesh or two dimensional mesh
2986-
if len(netgen_mesh.Elements3D()) == 0:
2987-
ng_element = netgen_mesh.Elements2D()
2996+
if self.topological_dimension == 2:
2997+
ng_element = self.netgen_mesh.Elements2D()
29882998
else:
2989-
ng_element = netgen_mesh.Elements3D()
2999+
ng_element = self.netgen_mesh.Elements3D()
29903000
ng_dimension = len(ng_element)
2991-
geom_dim = self.geometric_dimension
29923001

29933002
# Construct the mesh as a Firedrake function
29943003
if cg_field:
2995-
firedrake_space = fd.VectorFunctionSpace(self, "CG", order)
3004+
coords_space = VectorFunctionSpace(self, "CG", order)
29963005
else:
29973006
low_order_element = self.coordinates.function_space().ufl_element().sub_elements[0]
29983007
ufl_element = low_order_element.reconstruct(degree=order)
2999-
firedrake_space = fd.VectorFunctionSpace(self, fd.BrokenElement(ufl_element))
3000-
new_coordinates = fd.assemble(fd.interpolate(self.coordinates, firedrake_space))
3008+
coords_space = VectorFunctionSpace(self, finat.ufl.BrokenElement(ufl_element))
3009+
new_coordinates = Function(coords_space).interpolate(self.coordinates)
30013010

30023011
# Compute reference points using fiat
30033012
fiat_element = new_coordinates.function_space().finat_element.fiat_equivalent
@@ -3012,75 +3021,50 @@ def curve_field(self, order, permutation_tol=1e-8, location_tol=1e-1, cg_field=F
30123021
ref.append(pt)
30133022
reference_space_points = np.array(ref)
30143023

3015-
# Curve the mesh on rank 0 only
3016-
if self.comm.rank == 0:
3017-
# Construct numpy arrays for physical domain data
3018-
physical_space_points = np.zeros(
3019-
(ng_dimension, reference_space_points.shape[0], geom_dim)
3020-
)
3021-
curved_space_points = np.zeros(
3022-
(ng_dimension, reference_space_points.shape[0], geom_dim)
3023-
)
3024-
netgen_mesh.CalcElementMapping(reference_space_points, physical_space_points)
3025-
# NOTE: This will segfault for CSG!
3026-
netgen_mesh.Curve(order)
3027-
netgen_mesh.CalcElementMapping(reference_space_points, curved_space_points)
3028-
curved = ng_element.NumPy()["curved"]
3029-
3030-
# Broadcast a boolean array identifying curved cells
3031-
curved = self.comm.bcast(curved, root=0)
3032-
physical_space_points = physical_space_points[curved]
3033-
curved_space_points = curved_space_points[curved]
3034-
else:
3035-
curved = self.comm.bcast(None, root=0)
3036-
# Construct numpy arrays as buffers to receive physical domain data
3037-
ncurved = np.sum(curved)
3038-
physical_space_points = np.zeros(
3039-
(ncurved, reference_space_points.shape[0], geom_dim)
3040-
)
3041-
curved_space_points = np.zeros(
3042-
(ncurved, reference_space_points.shape[0], geom_dim)
3043-
)
3024+
# Construct numpy arrays for physical domain data
3025+
physical_space_points = np.zeros(
3026+
(ng_dimension, reference_space_points.shape[0], self.geometric_dimension)
3027+
)
3028+
curved_space_points = np.zeros(
3029+
(ng_dimension, reference_space_points.shape[0], self.geometric_dimension)
3030+
)
3031+
self.netgen_mesh.CalcElementMapping(reference_space_points, physical_space_points)
3032+
# NOTE: This will segfault for CSG!
3033+
self.netgen_mesh.Curve(order)
3034+
self.netgen_mesh.CalcElementMapping(reference_space_points, curved_space_points)
3035+
curved = ng_element.NumPy()["curved"]
3036+
3037+
# Get numbering
3038+
DG0 = FunctionSpace(self, "DG", 0)
3039+
rstart, rend = DG0.dof_dset.layout_vec.getOwnershipRange()
3040+
num_cells = rend - rstart
3041+
_, iperm = netgen_to_plex_numbering(self)
3042+
iperm -= rstart
3043+
3044+
# Distribute curved cell data
3045+
own_curved = netgen_distribute(self, curved)
3046+
own_curved = np.flatnonzero(own_curved[:num_cells])
30443047

3045-
# Broadcast curved cell point data
3046-
physical_space_points = self.comm.bcast(physical_space_points, root=0)
3047-
curved_space_points = self.comm.bcast(curved_space_points, root=0)
30483048
cell_node_map = new_coordinates.cell_node_map()
3049+
pyop2_index = cell_node_map.values[iperm[own_curved]]
30493050

3050-
# Select only the points in curved cells
3051-
barycentres = np.average(physical_space_points, axis=1)
3052-
ng_index = list(map(lambda x: self.locate_cell(x, tolerance=location_tol), barycentres))
3053-
3054-
# Select only the indices of cells owned by this rank
3055-
owned = [(0 <= ii < len(cell_node_map.values)) if ii is not None else False for ii in ng_index]
3051+
# Distribute coordinate data
3052+
own_curved_points = netgen_distribute(self, curved_space_points)[own_curved]
3053+
own_physical_points = netgen_distribute(self, physical_space_points)[own_curved]
30563054

3057-
# Get the PyOP2 indices corresponding to the netgen indices
3058-
ng_index = [idx for idx, o in zip(ng_index, owned) if o]
3059-
pyop2_index = cell_node_map.values[ng_index].flatten()
3060-
3061-
# Select only the points owned by this rank
3062-
physical_space_points = physical_space_points[owned]
3063-
curved_space_points = curved_space_points[owned]
3064-
barycentres = barycentres[owned]
3065-
3066-
if any(owned):
3055+
if any(own_curved):
30673056
# Find the correct coordinate permutation for each cell
30683057
permutation = find_permutation(
3069-
physical_space_points,
3070-
new_coordinates.dat.data[pyop2_index].real.reshape(
3071-
physical_space_points.shape
3072-
),
3073-
tol=permutation_tol
3058+
own_physical_points,
3059+
new_coordinates.dat.data[pyop2_index].real,
3060+
tol=permutation_tol,
30743061
)
3075-
30763062
# Apply the permutation to each cell in turn
3077-
for ii, p in enumerate(curved_space_points):
3078-
curved_space_points[ii] = p[permutation[ii]]
3079-
else:
3080-
print("barf")
3063+
for ii, p in enumerate(own_curved_points):
3064+
own_curved_points[ii] = p[permutation[ii]]
30813065

30823066
# Assign the curved coordinates to the dat
3083-
new_coordinates.dat.data[pyop2_index] = curved_space_points.reshape(-1, geom_dim)
3067+
new_coordinates.dat.data[pyop2_index] = own_curved_points
30843068
return new_coordinates
30853069

30863070

firedrake/mg/netgen.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,15 +256,13 @@ def NetgenHierarchy(mesh, levs, flags, distribution_parameters=None):
256256
cg = flags.get("cg", False)
257257
nested = flags.get("nested", snap in ["coarse"])
258258
permutation_tol = flags.get("permutation_tol", 1e-8)
259-
location_tol = flags.get("location_tol", 1e-8)
260259
# Firedrake quantities
261260
meshes = []
262261
lgmaps = []
263262
# Curve the mesh
264263
if mesh.coordinates.function_space().ufl_element().degree() != order[0]:
265264
coordinates = mesh.curve_field(
266265
order=order[0],
267-
location_tol=location_tol,
268266
permutation_tol=permutation_tol,
269267
cg_field=cg,
270268
)
@@ -329,7 +327,6 @@ def NetgenHierarchy(mesh, levs, flags, distribution_parameters=None):
329327
if snap == "geometry":
330328
coordinates = mesh.curve_field(
331329
order=order[l+1],
332-
location_tol=location_tol,
333330
permutation_tol=permutation_tol,
334331
cg_field=cg,
335332
)

firedrake/netgen.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,72 @@ class comp:
2929
Mesh = type(None)
3030

3131

32+
def netgen_distribute(mesh, netgen_data):
33+
from firedrake import FunctionSpace
34+
# Create Netgen to Plex reordering
35+
plex = mesh.topology_dm
36+
sf = mesh.sfBC_orig
37+
perm, iperm = netgen_to_plex_numbering(mesh)
38+
if sf is not None:
39+
netgen_data = np.asarray(netgen_data)
40+
dtype = netgen_data.dtype
41+
dtype = mesh.comm.bcast(dtype, root=0)
42+
43+
netgen_data = netgen_data.transpose()
44+
shp = netgen_data.shape[:-1]
45+
shp = mesh.comm.bcast(shp, root=0)
46+
if mesh.comm.rank != 0:
47+
netgen_data = np.empty((*shp, 0), dtype=dtype)
48+
49+
M = FunctionSpace(mesh, "DG", 0)
50+
marked = M.dof_dset.layout_vec.copy()
51+
marked.set(0)
52+
53+
sfBCInv = sf.createInverse()
54+
section, marked0 = plex.distributeField(sfBCInv, mesh._cell_numbering, marked)
55+
plex_data = None
56+
for i in np.ndindex(shp):
57+
marked0[:netgen_data.shape[-1]] = netgen_data[i]
58+
_, marked = plex.distributeField(sf, section, marked0)
59+
arr = marked.getArray()
60+
if plex_data is None:
61+
plex_data = np.empty(shp + arr.shape, dtype=dtype)
62+
plex_data[i] = arr.astype(dtype)
63+
64+
plex_data = plex_data.transpose()
65+
else:
66+
plex_data = netgen_data
67+
return plex_data
68+
69+
70+
def netgen_to_plex_numbering(mesh):
71+
from firedrake import FunctionSpace
72+
73+
sf = mesh.sfBC_orig
74+
plex = mesh.topology_dm
75+
cellNum = plex.getCellNumbering().indices
76+
cellNum[cellNum < 0] = -cellNum[cellNum < 0]-1
77+
fstart, fend = plex.getHeightStratum(0)
78+
cids = list(map(mesh._cell_numbering.getOffset, range(fstart, fend)))
79+
80+
# Create Netgen to Plex reordering
81+
M = FunctionSpace(mesh, "DG", 0)
82+
marked = M.dof_dset.layout_vec.copy()
83+
marked.set(0)
84+
85+
cstart, cend = marked.getOwnershipRange()
86+
iperm = cellNum[cids[:cend-cstart]]
87+
marked.setValues(iperm, np.arange(cstart, cend))
88+
marked.assemble()
89+
marked0 = marked
90+
if sf is not None:
91+
sfBCInv = sf.createInverse()
92+
_, marked0 = plex.distributeField(sfBCInv, mesh._cell_numbering, marked)
93+
94+
perm = marked0.getArray()[:M.dim()].astype(PETSc.IntType)
95+
return perm, iperm
96+
97+
3298
@PETSc.Log.EventDecorator()
3399
def find_permutation(points_a: npt.NDArray[np.inexact], points_b: npt.NDArray[np.inexact],
34100
tol: float = 1e-5):

0 commit comments

Comments
 (0)