Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 85 additions & 74 deletions firedrake/mg/interface.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import ufl
from itertools import repeat
from pyop2 import op2

from firedrake import ufl_expr, dmhooks
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.petsc import PETSc
from ufl.duals import is_dual
from ufl.algorithms.analysis import extract_coefficients
from ufl.domain import extract_unique_domain
from . import utils
from . import kernels

Expand All @@ -13,6 +17,10 @@


def check_arguments(coarse, fine, needs_dual=False):
if coarse.ufl_shape != fine.ufl_shape:
raise ValueError("Mismatching function space shapes")
coarse, = extract_coefficients(coarse)
fine, = extract_coefficients(fine)
if is_dual(coarse) != needs_dual:
expected_type = Cofunction if needs_dual else Function
raise TypeError("Coarse argument is a %s, not a %s" % (type(coarse).__name__, expected_type.__name__))
Expand All @@ -29,13 +37,58 @@ def check_arguments(coarse, fine, needs_dual=False):
raise ValueError("Coarse argument must be from coarser space")
if hierarchy is not fhierarchy:
raise ValueError("Can't transfer between functions from different hierarchies")
if coarse.ufl_shape != fine.ufl_shape:
raise ValueError("Mismatching function space shapes")


def multigrid_transfer(ufl_interpolate, tensor=None):
if tensor is None:
tensor = Function(ufl_interpolate.ufl_function_space())

coefficients = extract_coefficients(ufl_interpolate)
if is_dual(ufl_interpolate.ufl_function_space()):
kernel = kernels.restrict_kernel(ufl_interpolate)
access = op2.INC
source, = ufl_interpolate.arguments()
else:
kernel = kernels.prolong_kernel(ufl_interpolate)
access = op2.WRITE
source, = coefficients

dual_arg, operand = ufl_interpolate.argument_slots()
Vtarget = dual_arg.ufl_function_space().dual()
source_mesh = extract_unique_domain(operand)
target_mesh = Vtarget.mesh()
if utils.get_level(target_mesh)[1] > utils.get_level(source_mesh)[1]:
node_map = utils.fine_node_to_coarse_node_map
else:
node_map = utils.coarse_node_to_fine_node_map

# XXX: Should be able to figure out locations by pushing forward
# reference cell node locations to physical space.
# x = \sum_i c_i \phi_i(x_hat)
target_coords = utils.physical_node_locations(Vtarget)
source_coords = get_coordinates(source.ufl_function_space())
# Have to do this, because the node set core size is not right for
# this expanded stencil
for d in [source_coords, *coefficients]:
if d.function_space().mesh() is not target_mesh:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)

def parloop_arg(c, access):
m_ = None if c.function_space().mesh() is target_mesh else node_map(Vtarget, c.function_space())
return c.dat(access, m_)

op2.par_loop(kernel, Vtarget.node_set,
parloop_arg(tensor, access),
*map(parloop_arg, (*coefficients, target_coords, source_coords), repeat(op2.READ)))
return tensor


@PETSc.Log.EventDecorator()
def prolong(coarse, fine):
check_arguments(coarse, fine)
coarse_expr = coarse
coarse, = extract_coefficients(coarse_expr)
Vc = coarse.function_space()
Vf = fine.function_space()
if len(Vc) > 1:
Expand All @@ -56,7 +109,7 @@ def prolong(coarse, fine):
hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse))
_, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine))
refinements_per_level = hierarchy.refinements_per_level
repeat = (fine_level - coarse_level)*refinements_per_level
refine = (fine_level - coarse_level)*refinements_per_level
next_level = coarse_level * refinements_per_level

if needs_quadrature := not Vf.finat_element.has_pointwise_dual_basis:
Expand All @@ -66,40 +119,24 @@ def prolong(coarse, fine):
finest = fine
Vfinest = finest.function_space()
meshes = hierarchy._meshes
for j in range(repeat):
for j in range(refine):
next_level += 1
if j == repeat - 1 and not needs_quadrature:
fine = finest
if j == refine - 1 and not needs_quadrature:
tensor = finest
else:
fine = Function(Vf.reconstruct(mesh=meshes[next_level]))
Vf = fine.function_space()
Vc = coarse.function_space()
tensor = None

coarse_coords = get_coordinates(Vc)
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
kernel = kernels.prolong_kernel(coarse, Vf)

# XXX: Should be able to figure out locations by pushing forward
# reference cell node locations to physical space.
# x = \sum_i c_i \phi_i(x_hat)
node_locations = utils.physical_node_locations(Vf)
# Have to do this, because the node set core size is not right for
# this expanded stencil
for d in [coarse, coarse_coords]:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
op2.par_loop(kernel, fine.node_set,
fine.dat(op2.WRITE),
coarse.dat(op2.READ, fine_to_coarse),
node_locations.dat(op2.READ),
coarse_coords.dat(op2.READ, fine_to_coarse_coords))
fine_dual = ufl.TestFunction(Vf.reconstruct(mesh=meshes[next_level]).dual())
ufl_interpolate = ufl.Interpolate(coarse_expr, fine_dual)
fine = multigrid_transfer(ufl_interpolate, tensor=tensor)

if needs_quadrature:
# Transfer to the actual target space
new_fine = finest if j == repeat-1 else Function(Vfinest.reconstruct(mesh=meshes[next_level]))
new_fine = finest if j == refine-1 else Function(Vfinest.reconstruct(mesh=meshes[next_level]))
fine = new_fine.interpolate(fine)

coarse = fine
coarse_expr = coarse
return fine


Expand All @@ -126,7 +163,7 @@ def restrict(fine_dual, coarse_dual):
hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse_dual))
_, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine_dual))
refinements_per_level = hierarchy.refinements_per_level
repeat = (fine_level - coarse_level)*refinements_per_level
refine = (fine_level - coarse_level)*refinements_per_level
next_level = fine_level * refinements_per_level

if needs_quadrature := not Vf.finat_element.has_pointwise_dual_basis:
Expand All @@ -135,45 +172,31 @@ def restrict(fine_dual, coarse_dual):

coarsest = coarse_dual.zero()
meshes = hierarchy._meshes
for j in range(repeat):
for j in range(refine):
if needs_quadrature:
# Transfer to the quadrature source space
fine_dual = Function(Vq.reconstruct(mesh=meshes[next_level])).interpolate(fine_dual)

next_level -= 1
if j == repeat - 1:
if j == refine - 1:
coarse_dual = coarsest
else:
coarse_dual = Function(Vc.reconstruct(mesh=meshes[next_level]))
Vf = fine_dual.function_space()
Vc = coarse_dual.function_space()

# XXX: Should be able to figure out locations by pushing forward
# reference cell node locations to physical space.
# x = \sum_i c_i \phi_i(x_hat)
node_locations = utils.physical_node_locations(Vf.dual())

coarse_coords = get_coordinates(Vc.dual())
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
# Have to do this, because the node set core size is not right for
# this expanded stencil
for d in [coarse_coords]:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
kernel = kernels.restrict_kernel(Vf, Vc)
op2.par_loop(kernel, fine_dual.node_set,
coarse_dual.dat(op2.INC, fine_to_coarse),
fine_dual.dat(op2.READ),
node_locations.dat(op2.READ),
coarse_coords.dat(op2.READ, fine_to_coarse_coords))
coarse_expr = ufl.TestFunction(Vc.dual())
ufl_interpolate = ufl.Interpolate(coarse_expr, fine_dual)
multigrid_transfer(ufl_interpolate, tensor=coarse_dual)
fine_dual = coarse_dual
return coarse_dual


@PETSc.Log.EventDecorator()
def inject(fine, coarse):
check_arguments(coarse, fine)
fine_expr = fine
fine, = extract_coefficients(fine)
Vf = fine.function_space()
Vc = coarse.function_space()
if len(Vc) > 1:
Expand Down Expand Up @@ -205,45 +228,32 @@ def inject(fine, coarse):
hierarchy, coarse_level = utils.get_level(ufl_expr.extract_unique_domain(coarse))
_, fine_level = utils.get_level(ufl_expr.extract_unique_domain(fine))
refinements_per_level = hierarchy.refinements_per_level
repeat = (fine_level - coarse_level)*refinements_per_level
refine = (fine_level - coarse_level)*refinements_per_level
next_level = fine_level * refinements_per_level

if needs_quadrature := not Vc.finat_element.has_pointwise_dual_basis:
# Introduce an intermediate quadrature target space
Vc = Vc.quadrature_space()

kernel, dg = kernels.inject_kernel(Vf, Vc)
if dg and not hierarchy.nested:
raise NotImplementedError("Sorry, we can't do supermesh projections yet!")

coarsest = coarse.zero()
Vcoarsest = coarsest.function_space()
meshes = hierarchy._meshes
for j in range(repeat):
for j in range(refine):
next_level -= 1
if j == repeat - 1 and not needs_quadrature:
if j == refine - 1 and not needs_quadrature:
coarse = coarsest
else:
coarse = Function(Vc.reconstruct(mesh=meshes[next_level]))
Vc = coarse.function_space()
Vf = fine.function_space()
if not dg:
fine_coords = get_coordinates(Vf)
coarse_to_fine = utils.coarse_node_to_fine_node_map(Vc, Vf)
coarse_to_fine_coords = utils.coarse_node_to_fine_node_map(Vc, fine_coords.function_space())

node_locations = utils.physical_node_locations(Vc)
# Have to do this, because the node set core size is not right for
# this expanded stencil
for d in [fine, fine_coords]:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
op2.par_loop(kernel, coarse.node_set,
coarse.dat(op2.WRITE),
fine.dat(op2.READ, coarse_to_fine),
node_locations.dat(op2.READ),
fine_coords.dat(op2.READ, coarse_to_fine_coords))
ufl_interpolate = ufl.Interpolate(fine_expr, ufl.TestFunction(Vc.dual()))
if not Vf.finat_element.is_dg():
multigrid_transfer(ufl_interpolate, tensor=coarse)
else:
kernel, dg = kernels.inject_kernel(ufl_interpolate)
if dg and not hierarchy.nested:
raise NotImplementedError("Sorry, we can't do supermesh projections yet!")
coarse_coords = get_coordinates(Vc)
fine_coords = get_coordinates(Vf)
coarse_cell_to_fine_nodes = utils.coarse_cell_to_fine_node_map(Vc, Vf)
Expand All @@ -261,9 +271,10 @@ def inject(fine, coarse):

if needs_quadrature:
# Transfer to the actual target space
new_coarse = coarsest if j == repeat - 1 else Function(Vcoarsest.reconstruct(mesh=meshes[next_level]))
new_coarse = coarsest if j == refine - 1 else Function(Vcoarsest.reconstruct(mesh=meshes[next_level]))
coarse = new_coarse.interpolate(coarse)
fine = coarse
fine_expr = fine
return coarse


Expand Down
Loading
Loading