@@ -263,7 +263,18 @@ class Interpolator(abc.ABC):
263263
264264 def __new__ (cls , expr , V , ** kwargs ):
265265 if isinstance (expr , ufl .Interpolate ):
266+ # Mixed spaces are handled well only by the primal 1-form.
267+ # Are we a 2-form or a dual 1-form?
268+ arguments = expr .arguments ()
269+ if any (not isinstance (a , Coargument ) for a in arguments ):
270+ # Do we have mixed source or target spaces?
271+ spaces = [a .function_space () for a in arguments ]
272+ if len (spaces ) < 2 :
273+ spaces .append (V )
274+ if any (len (space ) > 1 for space in spaces ):
275+ return object .__new__ (MixedInterpolator )
266276 expr , = expr .ufl_operands
277+
267278 target_mesh = as_domain (V )
268279 source_mesh = extract_unique_domain (expr ) or target_mesh
269280 submesh_interp_implemented = \
@@ -369,7 +380,7 @@ def _interpolate(self, *args, **kwargs):
369380 """
370381 pass
371382
372- def assemble (self , tensor = None , default_missing_val = None ):
383+ def assemble (self , tensor = None , ** kwargs ):
373384 """Assemble the operator (or its action)."""
374385 from firedrake .assemble import assemble
375386 needs_adjoint = self .ufl_interpolate_renumbered != self .ufl_interpolate
@@ -383,13 +394,11 @@ def assemble(self, tensor=None, default_missing_val=None):
383394 if needs_adjoint :
384395 # Out-of-place Hermitian transpose
385396 petsc_mat .hermitianTranspose (out = res )
386- elif res :
387- petsc_mat .copy (res )
397+ elif tensor :
398+ petsc_mat .copy (tensor . petscmat )
388399 else :
389400 res = petsc_mat
390- if tensor is None :
391- tensor = firedrake .AssembledMatrix (arguments , self .bcs , res )
392- return tensor
401+ return tensor or firedrake .AssembledMatrix (arguments , self .bcs , res )
393402 else :
394403 # Assembling the action
395404 cofunctions = ()
@@ -401,11 +410,11 @@ def assemble(self, tensor=None, default_missing_val=None):
401410 cofunctions = (dual_arg ,)
402411
403412 if needs_adjoint and len (arguments ) == 0 :
404- Iu = self ._interpolate (default_missing_val = default_missing_val )
413+ Iu = self ._interpolate (** kwargs )
405414 return assemble (ufl .Action (* cofunctions , Iu ), tensor = tensor )
406415 else :
407416 return self ._interpolate (* cofunctions , output = tensor , adjoint = needs_adjoint ,
408- default_missing_val = default_missing_val )
417+ ** kwargs )
409418
410419
411420class DofNotDefinedError (Exception ):
@@ -975,33 +984,10 @@ def callable():
975984 return callable
976985 else :
977986 loops = []
978- if len (V ) == 1 :
979- expressions = (expr ,)
980- else :
981- if (hasattr (operand , "subfunctions" ) and len (operand .subfunctions ) == len (V )
982- and all (sub_op .ufl_shape == Vsub .value_shape for Vsub , sub_op in zip (V , operand .subfunctions ))):
983- # Use subfunctions if they match the target shapes
984- operands = operand .subfunctions
985- else :
986- # Unflatten the expression into the shapes of the mixed components
987- offset = 0
988- operands = []
989- for Vsub in V :
990- if len (Vsub .value_shape ) == 0 :
991- operands .append (operand [offset ])
992- else :
993- components = [operand [offset + j ] for j in range (Vsub .value_size )]
994- operands .append (ufl .as_tensor (numpy .reshape (components , Vsub .value_shape )))
995- offset += Vsub .value_size
996-
997- # Split the dual argument
998- if isinstance (dual_arg , Cofunction ):
999- duals = dual_arg .subfunctions
1000- elif isinstance (dual_arg , Coargument ):
1001- duals = [Coargument (Vsub , number = dual_arg .number ()) for Vsub in dual_arg .function_space ()]
1002- else :
1003- duals = [v for _ , v in sorted (firedrake .formmanipulation .split_form (dual_arg ))]
1004- expressions = map (expr ._ufl_expr_reconstruct_ , operands , duals )
987+ expressions = split_interpolate_target (expr )
988+
989+ if access == op2 .INC :
990+ loops .append (tensor .zero )
1005991
1006992 # Interpolate each sub expression into each function space
1007993 for Vsub , sub_tensor , sub_expr in zip (V , tensor , expressions ):
@@ -1074,8 +1060,6 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
10741060 parameters ['scalar_type' ] = utils .ScalarType
10751061
10761062 callables = ()
1077- if access == op2 .INC :
1078- callables += (tensor .zero ,)
10791063
10801064 # For the matfree adjoint 1-form and the 0-form, the cellwise kernel will add multiple
10811065 # contributions from the facet DOFs of the dual argument.
@@ -1720,3 +1704,115 @@ def _wrap_dummy_mat(self):
17201704
17211705 def duplicate (self , mat = None , op = None ):
17221706 return self ._wrap_dummy_mat ()
1707+
1708+
1709+ def split_interpolate_target (expr : ufl .Interpolate ):
1710+ """Split an Interpolate into the components (subfunctions) of the target space."""
1711+ dual_arg , operand = expr .argument_slots ()
1712+ V = dual_arg .function_space ().dual ()
1713+ if len (V ) == 1 :
1714+ return (expr ,)
1715+ # Split the target (dual) argument
1716+ if isinstance (dual_arg , Cofunction ):
1717+ duals = dual_arg .subfunctions
1718+ elif isinstance (dual_arg , ufl .Coargument ):
1719+ duals = [Coargument (Vsub , dual_arg .number ()) for Vsub in dual_arg .function_space ()]
1720+ else :
1721+ duals = [vi for _ , vi in sorted (firedrake .formmanipulation .split_form (dual_arg ))]
1722+ # Split the operand into the target shapes
1723+ if (isinstance (operand , firedrake .Function ) and len (operand .subfunctions ) == len (V )
1724+ and all (fsub .ufl_shape == Vsub .value_shape for Vsub , fsub in zip (V , operand .subfunctions ))):
1725+ # Use subfunctions if they match the target shapes
1726+ operands = operand .subfunctions
1727+ else :
1728+ # Unflatten the expression into the target shapes
1729+ cur = 0
1730+ operands = []
1731+ components = numpy .reshape (operand , (- 1 ,))
1732+ for Vi in V :
1733+ operands .append (ufl .as_tensor (components [cur :cur + Vi .value_size ].reshape (Vi .value_shape )))
1734+ cur += Vi .value_size
1735+ expressions = tuple (map (expr ._ufl_expr_reconstruct_ , operands , duals ))
1736+ return expressions
1737+
1738+
1739+ class MixedInterpolator (Interpolator ):
1740+ """A reusable interpolation object between MixedFunctionSpaces.
1741+
1742+ Parameters
1743+ ----------
1744+ expr
1745+ The underlying ufl.Interpolate or the operand to the ufl.Interpolate.
1746+ V
1747+ The :class:`.FunctionSpace` or :class:`.Function` to
1748+ interpolate into.
1749+ bcs
1750+ A list of boundary conditions.
1751+ **kwargs
1752+ Any extra kwargs are passed on to the sub Interpolators.
1753+ For details see :class:`firedrake.interpolation.Interpolator`.
1754+ """
1755+ def __init__ (self , expr , V , bcs = None , ** kwargs ):
1756+ if not isinstance (expr , ufl .Interpolate ):
1757+ fs = V if isinstance (V , ufl .FunctionSpace ) else V .function_space ()
1758+ expr = interpolate (expr , fs )
1759+ if bcs is None :
1760+ bcs = ()
1761+ self .expr = expr
1762+ self .V = V
1763+ self .bcs = bcs
1764+ self .arguments = expr .arguments ()
1765+
1766+ # Split the target (dual) argument
1767+ dual_split = split_interpolate_target (expr )
1768+ self .sub_interpolators = {}
1769+ for i , form in enumerate (dual_split ):
1770+ # Split the source (primal) argument
1771+ for j , sub_interp in firedrake .formmanipulation .split_form (form ):
1772+ j = max (j ) if j else 0
1773+ # Ensure block sparsity
1774+ vi , operand = sub_interp .argument_slots ()
1775+ if not isinstance (operand , ufl .classes .Zero ):
1776+ Vtarget = vi .function_space ().dual ()
1777+ adjoint = vi .number () == 1 if isinstance (vi , Coargument ) else True
1778+
1779+ args = sub_interp .arguments ()
1780+ Vsource = args [0 if adjoint else 1 ].function_space ()
1781+ sub_bcs = [bc for bc in bcs if bc .function_space () in {Vsource , Vtarget }]
1782+
1783+ indices = (j , i ) if adjoint else (i , j )
1784+ Isub = Interpolator (sub_interp , Vtarget , bcs = sub_bcs , ** kwargs )
1785+ self .sub_interpolators [indices ] = Isub
1786+
1787+ def assemble (self , tensor = None , ** kwargs ):
1788+ """Assemble the operator (or its action)."""
1789+ rank = len (self .arguments )
1790+ if rank == 2 :
1791+ # Assemble the operator
1792+ sub_tensors = {}
1793+ for ij , Isub in self .sub_interpolators .items ():
1794+ block = tensor .petscmat .getNestSubMatrix (* ij ) if tensor else PETSc .Mat ()
1795+ sub_tensors [ij ] = firedrake .AssembledMatrix (Isub .arguments , Isub .bcs , block )
1796+ Isub .assemble (tensor = sub_tensors [ij ], ** kwargs )
1797+ if tensor is None :
1798+ shape = tuple (len (a .function_space ()) for a in self .arguments )
1799+ blocks = numpy .reshape ([sub_tensors [ij ].petscmat if ij in sub_tensors else PETSc .Mat ()
1800+ for ij in numpy .ndindex (shape )], shape )
1801+ petscmat = PETSc .Mat ().createNest (blocks )
1802+ tensor = firedrake .AssembledMatrix (self .arguments , self .bcs , petscmat )
1803+ elif rank == 1 :
1804+ # Assemble the action
1805+ if tensor is None :
1806+ V_dest = self .arguments [0 ].function_space ().dual ()
1807+ tensor = firedrake .Function (V_dest )
1808+ for k , fsub in enumerate (tensor .subfunctions ):
1809+ fsub .assign (sum (Isub .assemble (** kwargs ) for (i , j ), Isub in self .sub_interpolators .items () if i == k ))
1810+ elif rank == 0 :
1811+ # Assemble the double action
1812+ result = sum (Isub .assemble (** kwargs ) for (i , j ), Isub in self .sub_interpolators .items ())
1813+ return tensor .assign (result ) if tensor else result
1814+ return tensor
1815+
1816+ def _interpolate (self , output = None , ** kwargs ):
1817+ """Assemble the action."""
1818+ return self .assemble (tensor = output , ** kwargs )
0 commit comments