Skip to content

Commit 9805342

Browse files
committed
WIP
1 parent 574be15 commit 9805342

3 files changed

Lines changed: 139 additions & 24 deletions

File tree

firedrake/mesh.py

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2926,6 +2926,8 @@ def refine_marked_elements(self, mark, netgen_flags=None):
29262926
"""
29272927
utils.check_netgen_installed()
29282928

2929+
if not hasattr(self, "netgen_mesh"):
2930+
raise ValueError("Adaptive refinement requires a netgen mesh.")
29292931
if netgen_flags is None:
29302932
netgen_flags = self.netgen_flags
29312933
dim = self.geometric_dimension
@@ -2947,13 +2949,22 @@ def refine_marked_elements(self, mark, netgen_flags=None):
29472949
refine_faces = netgen_flags.get("refine_faces", False)
29482950
if self.comm.rank == 0:
29492951
max_refs = int(mark_np.max())
2950-
for _ in range(max_refs):
2951-
netgen_cells = netgen_mesh.Elements3D() if dim == 3 else netgen_mesh.Elements2D()
2952-
netgen_cells.NumPy()["refine"][:mark_np.size] = mark_np > 0
2952+
for r in range(max_refs):
2953+
cells = netgen_mesh.Elements3D() if dim == 3 else netgen_mesh.Elements2D()
2954+
cells.NumPy()["refine"] = (mark_np > 0)
29532955
if not refine_faces and dim == 3:
2954-
netgen_mesh.Elements2D().NumPy()["refine"] = 0
2956+
netgen_mesh.Elements2D().NumPy()["refine"] = False
29552957
netgen_mesh.Refine(adaptive=True)
29562958
mark_np -= 1
2959+
if r < max_refs - 1:
2960+
parents = netgen_mesh.parentelements if dim == 3 else netgen_mesh.parentsurfaceelements
2961+
parents = parents.NumPy().astype(int).flatten()
2962+
num_coarse_cells = mark_np.size
2963+
num_fine_cells = parents.shape[0]
2964+
fine_cells = indices > num_coarse_cells
2965+
indices = np.arange(num_fine_cells, dtype=int)
2966+
indices[fine_cells] = parents[indices[fine_cells]]
2967+
mark_np = mark_np[indices]
29572968

29582969
return Mesh(netgen_mesh,
29592970
reorder=self._did_reordering,
@@ -2979,7 +2990,7 @@ def curve_field(self, order, permutation_tol=1e-8, location_tol=1e-1, cg_field=F
29792990

29802991
utils.check_netgen_installed()
29812992

2982-
from firedrake.netgen import find_permutation
2993+
from firedrake.netgen import find_permutation, netgen_to_plex_numbering, netgen_distribute
29832994

29842995
netgen_mesh = self.netgen_mesh
29852996
# Check if the mesh is a surface mesh or two dimensional mesh
@@ -3012,6 +3023,12 @@ def curve_field(self, order, permutation_tol=1e-8, location_tol=1e-1, cg_field=F
30123023
ref.append(pt)
30133024
reference_space_points = np.array(ref)
30143025

3026+
# Get numbering
3027+
cell_node_map = new_coordinates.cell_node_map()
3028+
DG0 = fd.FunctionSpace(self, "DG", 0)
3029+
rstart, rend = DG0.dof_dset.layout_vec.getOwnershipRange()
3030+
perm, iperm = netgen_to_plex_numbering(self)
3031+
30153032
# Curve the mesh on rank 0 only
30163033
if self.comm.rank == 0:
30173034
# Construct numpy arrays for physical domain data
@@ -3027,43 +3044,72 @@ def curve_field(self, order, permutation_tol=1e-8, location_tol=1e-1, cg_field=F
30273044
netgen_mesh.CalcElementMapping(reference_space_points, curved_space_points)
30283045
curved = ng_element.NumPy()["curved"]
30293046

3030-
# Broadcast a boolean array identifying curved cells
3031-
curved = self.comm.bcast(curved, root=0)
30323047
physical_space_points = physical_space_points[curved]
30333048
curved_space_points = curved_space_points[curved]
30343049
else:
3035-
curved = self.comm.bcast(None, root=0)
3036-
# Construct numpy arrays as buffers to receive physical domain data
3037-
ncurved = np.sum(curved)
3038-
physical_space_points = np.zeros(
3039-
(ncurved, reference_space_points.shape[0], geom_dim)
3040-
)
3041-
curved_space_points = np.zeros(
3042-
(ncurved, reference_space_points.shape[0], geom_dim)
3043-
)
3050+
curved = []
3051+
physical_space_points = []
3052+
curved_space_points = []
30443053

30453054
# Broadcast curved cell point data
3055+
iperm -= rstart
3056+
num_cells = rend - rstart
3057+
own_curved = netgen_distribute(self, curved)
3058+
own_curved = np.flatnonzero(own_curved[:num_cells])
3059+
3060+
own_curved_points = netgen_distribute(self, curved_space_points)[own_curved]
3061+
own_physical_points = netgen_distribute(self, physical_space_points)[own_curved]
3062+
pyop2_index = cell_node_map.values[iperm[own_curved]].flatten().tolist()
3063+
3064+
if False:
3065+
if any(own_curved):
3066+
# Find the correct coordinate permutation for each cell
3067+
permutation = find_permutation(
3068+
own_physical_points,
3069+
new_coordinates.dat.data[pyop2_index].real.reshape(
3070+
own_physical_points.shape
3071+
),
3072+
tol=permutation_tol,
3073+
)
3074+
3075+
# Apply the permutation to each cell in turn
3076+
for ii, p in enumerate(own_curved_points):
3077+
own_curved_points[ii] = p[permutation[ii]]
3078+
3079+
# Assign the curved coordinates to the dat
3080+
new_coordinates.dat.data[pyop2_index] = own_curved_points
3081+
return new_coordinates
3082+
3083+
# Broadcast curved cell point data
3084+
curved = self.comm.bcast(curved, root=0)
30463085
physical_space_points = self.comm.bcast(physical_space_points, root=0)
30473086
curved_space_points = self.comm.bcast(curved_space_points, root=0)
3048-
cell_node_map = new_coordinates.cell_node_map()
3087+
3088+
print("", self.comm.rank, "distr", own_curved.tolist(), "\n",
3089+
self.comm.rank, "bcast", np.flatnonzero(curved[rstart:rend]).tolist(), flush=True)
30493090

30503091
# Select only the points in curved cells
30513092
barycentres = np.average(physical_space_points, axis=1)
30523093
ng_index = list(map(lambda x: self.locate_cell(x, tolerance=location_tol), barycentres))
3094+
ng_index = np.array(ng_index)
30533095

30543096
# Select only the indices of cells owned by this rank
30553097
owned = [(0 <= ii < len(cell_node_map.values)) if ii is not None else False for ii in ng_index]
30563098

30573099
# Get the PyOP2 indices corresponding to the netgen indices
3058-
ng_index = [idx for idx, o in zip(ng_index, owned) if o]
3100+
ng_index = ng_index[owned].tolist()
3101+
old_index = pyop2_index
30593102
pyop2_index = cell_node_map.values[ng_index].flatten()
30603103

3061-
# Select only the points owned by this rank
3062-
physical_space_points = physical_space_points[owned]
3063-
curved_space_points = curved_space_points[owned]
3064-
barycentres = barycentres[owned]
3104+
#print("", self.comm.rank, "ng_index", list(map(int, ng_index)), "\n",
3105+
# self.comm.rank, "iperm ", iperm.tolist(), flush=True)
30653106

30663107
if any(owned):
3108+
# Select only the points owned by this rank
3109+
physical_space_points = physical_space_points[owned]
3110+
curved_space_points = curved_space_points[owned]
3111+
# print(physical_space_points - own_physical_points)
3112+
30673113
# Find the correct coordinate permutation for each cell
30683114
permutation = find_permutation(
30693115
physical_space_points,
@@ -3076,8 +3122,6 @@ def curve_field(self, order, permutation_tol=1e-8, location_tol=1e-1, cg_field=F
30763122
# Apply the permutation to each cell in turn
30773123
for ii, p in enumerate(curved_space_points):
30783124
curved_space_points[ii] = p[permutation[ii]]
3079-
else:
3080-
print("barf")
30813125

30823126
# Assign the curved coordinates to the dat
30833127
new_coordinates.dat.data[pyop2_index] = curved_space_points.reshape(-1, geom_dim)

firedrake/netgen.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,73 @@ class comp:
2929
Mesh = type(None)
3030

3131

32+
def netgen_distribute(mesh, netgen_data):
33+
from firedrake import FunctionSpace
34+
# Create Netgen to Plex reordering
35+
plex = mesh.topology_dm
36+
sf = mesh.sfBC_orig
37+
perm, iperm = netgen_to_plex_numbering(mesh)
38+
if sf is not None:
39+
netgen_data = np.asarray(netgen_data)
40+
dtype = netgen_data.dtype
41+
dtype = mesh.comm.bcast(dtype, root=0)
42+
43+
netgen_data = netgen_data.transpose()
44+
shp = netgen_data.shape[:-1]
45+
shp = mesh.comm.bcast(shp, root=0)
46+
if mesh.comm.rank != 0:
47+
netgen_data = np.empty((*shp, 0), dtype=dtype)
48+
49+
M = FunctionSpace(mesh, "DG", 0)
50+
marked = M.dof_dset.layout_vec.copy()
51+
marked.set(0)
52+
53+
sfBCInv = sf.createInverse()
54+
section, marked0 = plex.distributeField(sfBCInv, mesh._cell_numbering, marked)
55+
plex_data = None
56+
for i in np.ndindex(shp):
57+
marked0[:netgen_data.shape[-1]] = netgen_data[i]
58+
_, marked = plex.distributeField(sf, section, marked0)
59+
arr = marked.getArray()
60+
if plex_data is None:
61+
plex_data = np.empty(shp + arr.shape, dtype=dtype)
62+
plex_data[i] = arr.astype(dtype)
63+
64+
plex_data = plex_data.transpose()
65+
else:
66+
plex_data = netgen_data
67+
return plex_data
68+
69+
70+
def netgen_to_plex_numbering(mesh):
71+
from firedrake import FunctionSpace
72+
73+
sf = mesh.sfBC_orig
74+
plex = mesh.topology_dm
75+
cellNum = plex.getCellNumbering().indices
76+
cellNum[cellNum < 0] = -cellNum[cellNum < 0]-1
77+
fstart, fend = plex.getHeightStratum(0)
78+
cids = list(map(mesh._cell_numbering.getOffset, range(fstart, fend)))
79+
num_cells = fend - fstart
80+
81+
# Create Netgen to Plex reordering
82+
M = FunctionSpace(mesh, "DG", 0)
83+
marked = M.dof_dset.layout_vec.copy()
84+
marked.set(0)
85+
86+
cstart, cend = marked.getOwnershipRange()
87+
iperm = cellNum[cids[:cend-cstart]]
88+
marked.setValues(iperm, np.arange(cstart, cend))
89+
marked.assemble()
90+
marked0 = marked
91+
if sf is not None:
92+
sfBCInv = sf.createInverse()
93+
_, marked0 = plex.distributeField(sfBCInv, mesh._cell_numbering, marked)
94+
95+
perm = marked0.getArray()[:M.dim()].astype(PETSc.IntType)
96+
return perm, iperm
97+
98+
3299
@PETSc.Log.EventDecorator()
33100
def find_permutation(points_a: npt.NDArray[np.inexact], points_b: npt.NDArray[np.inexact],
34101
tol: float = 1e-5):

tests/firedrake/regression/test_netgen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,8 @@ def adapt(mesh, eta):
326326
(eta, error_est) = estimate_error(mesh, uh)
327327
error_estimators.append(error_est)
328328
dofs.append(uh.function_space().dim())
329+
if error_est < 0.05:
330+
break
329331
mesh = adapt(mesh, eta)
330332
assert error_estimators[-1] < 0.05
331333

@@ -402,6 +404,8 @@ def adapt(mesh, eta):
402404
(eta, error_est) = estimate_error(mesh, uh)
403405
error_estimators.append(error_est)
404406
dofs.append(uh.function_space().dim())
407+
if error_est < 0.05:
408+
break
405409
mesh = adapt(mesh, eta)
406410
assert error_estimators[-1] < 0.06
407411

0 commit comments

Comments
 (0)