@@ -433,8 +433,6 @@ class CrossMeshInterpolator(Interpolator):
433433 @no_annotations
434434 def __init__ (self , expr : Interpolate ):
435435 super ().__init__ (expr )
436- self .target_mesh = self .target_mesh .unique ()
437- self .source_mesh = self .source_mesh .unique ()
438436 if self .access and self .access != op2 .WRITE :
439437 raise NotImplementedError (
440438 "Access other than op2.WRITE not implemented for cross-mesh interpolation."
@@ -458,7 +456,7 @@ def __init__(self, expr: Interpolate):
458456 else :
459457 self .missing_points_behaviour = MissingPointsBehaviour .ERROR
460458
461- if self .source_mesh .geometric_dimension != self .target_mesh .geometric_dimension :
459+ if self .source_mesh .unique (). geometric_dimension != self .target_mesh . unique () .geometric_dimension :
462460 raise ValueError ("Geometric dimensions of source and destination meshes must match." )
463461
464462 dest_element = self .target_space .ufl_element ()
@@ -495,18 +493,18 @@ def _get_symbolic_expressions(self) -> tuple[Interpolate, Interpolate]:
495493 """
496494 from firedrake .assemble import assemble
497495 # Immerse coordinates of target space point evaluation dofs in src_mesh
498- target_space_vec = VectorFunctionSpace (self .target_mesh , self .dest_element )
499- f_dest_node_coords = assemble (interpolate (self .target_mesh .coordinates , target_space_vec ))
500- dest_node_coords = f_dest_node_coords .dat .data_ro .reshape (- 1 , self .target_mesh .geometric_dimension )
496+ target_space_vec = VectorFunctionSpace (self .target_mesh . unique () , self .dest_element )
497+ f_dest_node_coords = assemble (interpolate (self .target_mesh .unique (). coordinates , target_space_vec ))
498+ dest_node_coords = f_dest_node_coords .dat .data_ro .reshape (- 1 , self .target_mesh .unique (). geometric_dimension )
501499 try :
502500 vom = VertexOnlyMesh (
503- self .source_mesh ,
501+ self .source_mesh . unique () ,
504502 dest_node_coords ,
505503 redundant = False ,
506504 missing_points_behaviour = self .missing_points_behaviour ,
507505 )
508506 except VertexOnlyMeshMissingPointsError :
509- raise DofNotDefinedError (self .source_mesh , self .target_mesh )
507+ raise DofNotDefinedError (self .source_mesh . unique () , self .target_mesh . unique () )
510508
511509 # Get the correct type of function space
512510 shape = self .target_space .ufl_function_space ().value_shape
@@ -621,12 +619,10 @@ class SameMeshInterpolator(Interpolator):
621619 @no_annotations
622620 def __init__ (self , expr ):
623621 super ().__init__ (expr )
624- self .target_mesh = self .target_mesh .unique ()
625- self .source_mesh = self .source_mesh .unique ()
626622 subset = self .subset
627623 if subset is None :
628- target = self .target_mesh .topology
629- source = self .source_mesh .topology
624+ target = self .target_mesh .unique (). topology
625+ source = self .source_mesh .unique (). topology
630626 if all (isinstance (m , MeshTopology ) for m in [target , source ]) and target is not source :
631627 composed_map , result_integral_type = source .trans_mesh_entity_map (target , "cell" , "everywhere" , None )
632628 if result_integral_type != "cell" :
@@ -667,7 +663,7 @@ def _get_tensor(self, mat_type: Literal["aij", "baij"]) -> op2.Mat | Function |
667663 The tensor to interpolate into.
668664 """
669665 if self .rank == 0 :
670- R = FunctionSpace (self .target_mesh , "Real" , 0 )
666+ R = FunctionSpace (self .target_mesh . unique () , "Real" , 0 )
671667 f = Function (R , dtype = ScalarType )
672668 elif self .rank == 1 :
673669 f = Function (self .ufl_interpolate .function_space ())
@@ -704,8 +700,8 @@ def _get_monolithic_sparsity(self, mat_type: Literal["aij", "baij"]) -> op2.Spar
704700 Vcol = self .interpolate_args [1 ].function_space ()
705701 if len (Vrow ) > 1 or len (Vcol ) > 1 :
706702 raise NotImplementedError ("Interpolation matrix with MixedFunctionSpace requires MixedInterpolator" )
707- Vrow_map = get_interp_node_map (self .source_mesh , self .target_mesh , Vrow )
708- Vcol_map = get_interp_node_map (self .source_mesh , self .target_mesh , Vcol )
703+ Vrow_map = get_interp_node_map (self .source_mesh . unique () , self .target_mesh . unique () , Vrow )
704+ Vcol_map = get_interp_node_map (self .source_mesh . unique () , self .target_mesh . unique () , Vcol )
709705 sparsity = op2 .Sparsity ((Vrow .dof_dset , Vcol .dof_dset ),
710706 [(Vrow_map , Vcol_map , None )], # non-mixed
711707 name = f"{ Vrow .name } _{ Vcol .name } _sparsity" ,
0 commit comments