Skip to content

Commit c260ac8

Browse files
committed
refactor: move some functions from driver to kernel_interface/common
1 parent 401bc2a commit c260ac8

2 files changed

Lines changed: 113 additions & 119 deletions

File tree

tsfc/driver.py

Lines changed: 0 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,18 @@
44
from itertools import chain
55
from finat.physically_mapped import DirectlyDefinedElement, PhysicallyMappedElement
66

7-
from numpy import asarray
8-
97
import ufl
108
from ufl.algorithms import extract_arguments, extract_coefficients
119
from ufl.algorithms.analysis import has_type
1210
from ufl.classes import Form, GeometricQuantity
1311
from ufl.log import GREEN
14-
from ufl.utils.sequences import max_degree
1512

1613
import gem
1714
import gem.impero_utils as impero_utils
1815

19-
from FIAT.reference_element import TensorProductCell
20-
2116
import finat
22-
from finat.quadrature import AbstractQuadratureRule, make_quadrature
2317

2418
from tsfc import fem, ufl_utils
25-
from tsfc.finatinterface import as_fiat_cell
2619
from tsfc.logging import logger
2720
from tsfc.parameters import default_parameters, is_complex
2821
from tsfc.ufl_utils import apply_mapping
@@ -96,7 +89,6 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
9689
:returns: a kernel constructed by the kernel interface
9790
"""
9891
parameters = preprocess_parameters(parameters)
99-
10092
if interface is None:
10193
if coffee:
10294
import tsfc.kernel_interface.firedrake as firedrake_interface_coffee
@@ -112,7 +104,6 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
112104
kernel_name = "%s_%s_integral_%s" % (prefix, integral_type, integral_data.subdomain_id)
113105
# Handle negative subdomain_id
114106
kernel_name = kernel_name.replace("-", "_")
115-
116107
# Dict mapping domains to index in original_form.ufl_domains()
117108
domain_numbering = form_data.original_form.domain_numbering()
118109
domain_number = domain_numbering[integral_data.domain]
@@ -134,22 +125,17 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
134125
builder = interface(integral_data_info,
135126
scalar_type,
136127
diagonal=diagonal)
137-
138128
builder.set_coordinates(mesh)
139129
builder.set_cell_sizes(mesh)
140-
141130
builder.set_coefficients(integral_data, form_data)
142-
143131
ctx = builder.create_context()
144132
for integral in integral_data.integrals:
145133
params = parameters.copy()
146134
params.update(integral.metadata()) # integral metadata overrides
147-
148135
integrand = ufl.replace(integral.integrand(), form_data.function_replace_map)
149136
integrand_exprs = builder.compile_integrand(integrand, params, ctx)
150137
integral_exprs = builder.construct_integrals(integrand_exprs, params)
151138
builder.stash_integrals(integral_exprs, params, ctx)
152-
153139
return builder.construct_kernel(kernel_name, ctx)
154140

155141

@@ -331,107 +317,3 @@ def __call__(self, ps):
331317
assert set(gem_expr.free_indices) <= set(chain(ps.indices, *argument_multiindices))
332318

333319
return gem_expr
334-
335-
336-
def set_quad_rule(params, cell, integral_type, functions):
337-
# Check if the integral has a quad degree attached, otherwise use
338-
# the estimated polynomial degree attached by compute_form_data
339-
try:
340-
quadrature_degree = params["quadrature_degree"]
341-
except KeyError:
342-
quadrature_degree = params["estimated_polynomial_degree"]
343-
function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions]
344-
if all((asarray(quadrature_degree) > 10 * asarray(degree)).all()
345-
for degree in function_degrees):
346-
logger.warning("Estimated quadrature degree %s more "
347-
"than tenfold greater than any "
348-
"argument/coefficient degree (max %s)",
349-
quadrature_degree, max_degree(function_degrees))
350-
if params.get("quadrature_rule") == "default":
351-
del params["quadrature_rule"]
352-
try:
353-
quad_rule = params["quadrature_rule"]
354-
except KeyError:
355-
fiat_cell = as_fiat_cell(cell)
356-
integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
357-
integration_cell = fiat_cell.construct_subelement(integration_dim)
358-
quad_rule = make_quadrature(integration_cell, quadrature_degree)
359-
params["quadrature_rule"] = quad_rule
360-
361-
if not isinstance(quad_rule, AbstractQuadratureRule):
362-
raise ValueError("Expected to find a QuadratureRule object, not a %s" %
363-
type(quad_rule))
364-
365-
366-
def get_index_ordering(quadrature_indices, return_variables):
367-
split_argument_indices = tuple(chain(*[var.index_ordering()
368-
for var in return_variables]))
369-
return tuple(quadrature_indices) + split_argument_indices
370-
371-
372-
def lower_integral_type(fiat_cell, integral_type):
373-
"""Lower integral type into the dimension of the integration
374-
subentity and a list of entity numbers for that dimension.
375-
376-
:arg fiat_cell: FIAT reference cell
377-
:arg integral_type: integral type (string)
378-
"""
379-
vert_facet_types = ['exterior_facet_vert', 'interior_facet_vert']
380-
horiz_facet_types = ['exterior_facet_bottom', 'exterior_facet_top', 'interior_facet_horiz']
381-
382-
dim = fiat_cell.get_dimension()
383-
if integral_type == 'cell':
384-
integration_dim = dim
385-
elif integral_type in ['exterior_facet', 'interior_facet']:
386-
if isinstance(fiat_cell, TensorProductCell):
387-
raise ValueError("{} integral cannot be used with a TensorProductCell; need to distinguish between vertical and horizontal contributions.".format(integral_type))
388-
integration_dim = dim - 1
389-
elif integral_type == 'vertex':
390-
integration_dim = 0
391-
elif integral_type in vert_facet_types + horiz_facet_types:
392-
# Extrusion case
393-
if not isinstance(fiat_cell, TensorProductCell):
394-
raise ValueError("{} integral requires a TensorProductCell.".format(integral_type))
395-
basedim, extrdim = dim
396-
assert extrdim == 1
397-
398-
if integral_type in vert_facet_types:
399-
integration_dim = (basedim - 1, 1)
400-
elif integral_type in horiz_facet_types:
401-
integration_dim = (basedim, 0)
402-
else:
403-
raise NotImplementedError("integral type %s not supported" % integral_type)
404-
405-
if integral_type == 'exterior_facet_bottom':
406-
entity_ids = [0]
407-
elif integral_type == 'exterior_facet_top':
408-
entity_ids = [1]
409-
else:
410-
entity_ids = list(range(len(fiat_cell.get_topology()[integration_dim])))
411-
412-
return integration_dim, entity_ids
413-
414-
415-
def pick_mode(mode):
416-
"Return one of the specialized optimisation modules from a mode string."
417-
try:
418-
from firedrake_citations import Citations
419-
cites = {"vanilla": ("Homolya2017", ),
420-
"coffee": ("Luporini2016", "Homolya2017", ),
421-
"spectral": ("Luporini2016", "Homolya2017", "Homolya2017a"),
422-
"tensor": ("Kirby2006", "Homolya2017", )}
423-
for c in cites[mode]:
424-
Citations().register(c)
425-
except ImportError:
426-
pass
427-
if mode == "vanilla":
428-
import tsfc.vanilla as m
429-
elif mode == "coffee":
430-
import tsfc.coffee_mode as m
431-
elif mode == "spectral":
432-
import tsfc.spectral as m
433-
elif mode == "tensor":
434-
import tsfc.tensor as m
435-
else:
436-
raise ValueError("Unknown mode: {}".format(mode))
437-
return m

tsfc/kernel_interface/common.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,28 @@
22
import string
33
import operator
44
from functools import reduce
5+
from itertools import chain
56

67
import numpy
8+
from numpy import asarray
9+
10+
from ufl.utils.sequences import max_degree
711

812
import coffee.base as coffee
913

14+
from FIAT.reference_element import TensorProductCell
15+
16+
from finat.quadrature import AbstractQuadratureRule, make_quadrature
17+
1018
import gem
1119

1220
from gem.utils import cached_property
1321
import gem.impero_utils as impero_utils
1422

15-
from tsfc.driver import lower_integral_type, set_quad_rule, pick_mode, get_index_ordering
1623
from tsfc import fem, ufl_utils
1724
from tsfc.kernel_interface import KernelInterface
1825
from tsfc.finatinterface import as_fiat_cell
26+
from tsfc.logging import logger
1927

2028

2129
class KernelBuilderBase(KernelInterface):
@@ -290,6 +298,42 @@ def create_context(self):
290298
'mode_irs': collections.OrderedDict()}
291299

292300

301+
def set_quad_rule(params, cell, integral_type, functions):
302+
# Check if the integral has a quad degree attached, otherwise use
303+
# the estimated polynomial degree attached by compute_form_data
304+
try:
305+
quadrature_degree = params["quadrature_degree"]
306+
except KeyError:
307+
quadrature_degree = params["estimated_polynomial_degree"]
308+
function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions]
309+
if all((asarray(quadrature_degree) > 10 * asarray(degree)).all()
310+
for degree in function_degrees):
311+
logger.warning("Estimated quadrature degree %s more "
312+
"than tenfold greater than any "
313+
"argument/coefficient degree (max %s)",
314+
quadrature_degree, max_degree(function_degrees))
315+
if params.get("quadrature_rule") == "default":
316+
del params["quadrature_rule"]
317+
try:
318+
quad_rule = params["quadrature_rule"]
319+
except KeyError:
320+
fiat_cell = as_fiat_cell(cell)
321+
integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
322+
integration_cell = fiat_cell.construct_subelement(integration_dim)
323+
quad_rule = make_quadrature(integration_cell, quadrature_degree)
324+
params["quadrature_rule"] = quad_rule
325+
326+
if not isinstance(quad_rule, AbstractQuadratureRule):
327+
raise ValueError("Expected to find a QuadratureRule object, not a %s" %
328+
type(quad_rule))
329+
330+
331+
def get_index_ordering(quadrature_indices, return_variables):
332+
split_argument_indices = tuple(chain(*[var.index_ordering()
333+
for var in return_variables]))
334+
return tuple(quadrature_indices) + split_argument_indices
335+
336+
293337
def get_index_names(quadrature_indices, argument_multiindices, index_cache):
294338
index_names = []
295339

@@ -311,3 +355,71 @@ def name_multiindex(multiindex, name):
311355
for multiindex, name in zip(argument_multiindices, ['j', 'k']):
312356
name_multiindex(multiindex, name)
313357
return index_names
358+
359+
360+
def lower_integral_type(fiat_cell, integral_type):
361+
"""Lower integral type into the dimension of the integration
362+
subentity and a list of entity numbers for that dimension.
363+
364+
:arg fiat_cell: FIAT reference cell
365+
:arg integral_type: integral type (string)
366+
"""
367+
vert_facet_types = ['exterior_facet_vert', 'interior_facet_vert']
368+
horiz_facet_types = ['exterior_facet_bottom', 'exterior_facet_top', 'interior_facet_horiz']
369+
370+
dim = fiat_cell.get_dimension()
371+
if integral_type == 'cell':
372+
integration_dim = dim
373+
elif integral_type in ['exterior_facet', 'interior_facet']:
374+
if isinstance(fiat_cell, TensorProductCell):
375+
raise ValueError("{} integral cannot be used with a TensorProductCell; need to distinguish between vertical and horizontal contributions.".format(integral_type))
376+
integration_dim = dim - 1
377+
elif integral_type == 'vertex':
378+
integration_dim = 0
379+
elif integral_type in vert_facet_types + horiz_facet_types:
380+
# Extrusion case
381+
if not isinstance(fiat_cell, TensorProductCell):
382+
raise ValueError("{} integral requires a TensorProductCell.".format(integral_type))
383+
basedim, extrdim = dim
384+
assert extrdim == 1
385+
386+
if integral_type in vert_facet_types:
387+
integration_dim = (basedim - 1, 1)
388+
elif integral_type in horiz_facet_types:
389+
integration_dim = (basedim, 0)
390+
else:
391+
raise NotImplementedError("integral type %s not supported" % integral_type)
392+
393+
if integral_type == 'exterior_facet_bottom':
394+
entity_ids = [0]
395+
elif integral_type == 'exterior_facet_top':
396+
entity_ids = [1]
397+
else:
398+
entity_ids = list(range(len(fiat_cell.get_topology()[integration_dim])))
399+
400+
return integration_dim, entity_ids
401+
402+
403+
def pick_mode(mode):
404+
"Return one of the specialized optimisation modules from a mode string."
405+
try:
406+
from firedrake_citations import Citations
407+
cites = {"vanilla": ("Homolya2017", ),
408+
"coffee": ("Luporini2016", "Homolya2017", ),
409+
"spectral": ("Luporini2016", "Homolya2017", "Homolya2017a"),
410+
"tensor": ("Kirby2006", "Homolya2017", )}
411+
for c in cites[mode]:
412+
Citations().register(c)
413+
except ImportError:
414+
pass
415+
if mode == "vanilla":
416+
import tsfc.vanilla as m
417+
elif mode == "coffee":
418+
import tsfc.coffee_mode as m
419+
elif mode == "spectral":
420+
import tsfc.spectral as m
421+
elif mode == "tensor":
422+
import tsfc.tensor as m
423+
else:
424+
raise ValueError("Unknown mode: {}".format(mode))
425+
return m

0 commit comments

Comments
 (0)