Skip to content

Commit e4b31a4

Browse files
committed
style fix
1 parent c9e9a76 commit e4b31a4

1 file changed

Lines changed: 11 additions & 15 deletions

File tree

firedrake/interpolation.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,6 @@ class CrossMeshInterpolator(Interpolator):
433433
@no_annotations
434434
def __init__(self, expr: Interpolate):
435435
super().__init__(expr)
436-
self.target_mesh = self.target_mesh.unique()
437-
self.source_mesh = self.source_mesh.unique()
438436
if self.access and self.access != op2.WRITE:
439437
raise NotImplementedError(
440438
"Access other than op2.WRITE not implemented for cross-mesh interpolation."
@@ -458,7 +456,7 @@ def __init__(self, expr: Interpolate):
458456
else:
459457
self.missing_points_behaviour = MissingPointsBehaviour.ERROR
460458

461-
if self.source_mesh.geometric_dimension != self.target_mesh.geometric_dimension:
459+
if self.source_mesh.unique().geometric_dimension != self.target_mesh.unique().geometric_dimension:
462460
raise ValueError("Geometric dimensions of source and destination meshes must match.")
463461

464462
dest_element = self.target_space.ufl_element()
@@ -495,18 +493,18 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:
495493
"""
496494
from firedrake.assemble import assemble
497495
# Immerse coordinates of target space point evaluation dofs in src_mesh
498-
target_space_vec = VectorFunctionSpace(self.target_mesh, self.dest_element)
499-
f_dest_node_coords = assemble(interpolate(self.target_mesh.coordinates, target_space_vec))
500-
dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.geometric_dimension)
496+
target_space_vec = VectorFunctionSpace(self.target_mesh.unique(), self.dest_element)
497+
f_dest_node_coords = assemble(interpolate(self.target_mesh.unique().coordinates, target_space_vec))
498+
dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, self.target_mesh.unique().geometric_dimension)
501499
try:
502500
vom = VertexOnlyMesh(
503-
self.source_mesh,
501+
self.source_mesh.unique(),
504502
dest_node_coords,
505503
redundant=False,
506504
missing_points_behaviour=self.missing_points_behaviour,
507505
)
508506
except VertexOnlyMeshMissingPointsError:
509-
raise DofNotDefinedError(self.source_mesh, self.target_mesh)
507+
raise DofNotDefinedError(self.source_mesh.unique(), self.target_mesh.unique())
510508

511509
# Get the correct type of function space
512510
shape = self.target_space.ufl_function_space().value_shape
@@ -621,12 +619,10 @@ class SameMeshInterpolator(Interpolator):
621619
@no_annotations
622620
def __init__(self, expr):
623621
super().__init__(expr)
624-
self.target_mesh = self.target_mesh.unique()
625-
self.source_mesh = self.source_mesh.unique()
626622
subset = self.subset
627623
if subset is None:
628-
target = self.target_mesh.topology
629-
source = self.source_mesh.topology
624+
target = self.target_mesh.unique().topology
625+
source = self.source_mesh.unique().topology
630626
if all(isinstance(m, MeshTopology) for m in [target, source]) and target is not source:
631627
composed_map, result_integral_type = source.trans_mesh_entity_map(target, "cell", "everywhere", None)
632628
if result_integral_type != "cell":
@@ -667,7 +663,7 @@ def _get_tensor(self, mat_type: Literal["aij", "baij"]) -> op2.Mat | Function |
667663
The tensor to interpolate into.
668664
"""
669665
if self.rank == 0:
670-
R = FunctionSpace(self.target_mesh, "Real", 0)
666+
R = FunctionSpace(self.target_mesh.unique(), "Real", 0)
671667
f = Function(R, dtype=ScalarType)
672668
elif self.rank == 1:
673669
f = Function(self.ufl_interpolate.function_space())
@@ -704,8 +700,8 @@ def _get_monolithic_sparsity(self, mat_type: Literal["aij", "baij"]) -> op2.Spar
704700
Vcol = self.interpolate_args[1].function_space()
705701
if len(Vrow) > 1 or len(Vcol) > 1:
706702
raise NotImplementedError("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator")
707-
Vrow_map = get_interp_node_map(self.source_mesh, self.target_mesh, Vrow)
708-
Vcol_map = get_interp_node_map(self.source_mesh, self.target_mesh, Vcol)
703+
Vrow_map = get_interp_node_map(self.source_mesh.unique(), self.target_mesh.unique(), Vrow)
704+
Vcol_map = get_interp_node_map(self.source_mesh.unique(), self.target_mesh.unique(), Vcol)
709705
sparsity = op2.Sparsity((Vrow.dof_dset, Vcol.dof_dset),
710706
[(Vrow_map, Vcol_map, None)], # non-mixed
711707
name=f"{Vrow.name}_{Vcol.name}_sparsity",

0 commit comments

Comments
 (0)