Skip to content

Commit 2395192

Browse files
committed
refactor: move some functions from driver to kernel_interface/common
1 parent fc38841 commit 2395192

2 files changed

Lines changed: 113 additions & 120 deletions

File tree

tsfc/driver.py

Lines changed: 0 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,18 @@
33
from itertools import chain
44
from finat.physically_mapped import DirectlyDefinedElement, PhysicallyMappedElement
55

6-
from numpy import asarray
7-
86
import ufl
97
from ufl.algorithms import extract_arguments, extract_coefficients
108
from ufl.algorithms.analysis import has_type
119
from ufl.classes import Form, GeometricQuantity
1210
from ufl.log import GREEN
13-
from ufl.utils.sequences import max_degree
1411

1512
import gem
1613
import gem.impero_utils as impero_utils
1714

18-
from FIAT.reference_element import TensorProductCell
19-
2015
import finat
21-
from finat.quadrature import AbstractQuadratureRule, make_quadrature
2216

2317
from tsfc import fem, ufl_utils
24-
from tsfc.finatinterface import as_fiat_cell
2518
from tsfc.logging import logger
2619
from tsfc.parameters import default_parameters, is_complex
2720
from tsfc.ufl_utils import apply_mapping
@@ -105,7 +98,6 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
10598
:returns: a kernel constructed by the kernel interface
10699
"""
107100
parameters = preprocess_parameters(parameters)
108-
109101
if interface is None:
110102
if coffee:
111103
import tsfc.kernel_interface.firedrake as firedrake_interface_coffee
@@ -118,14 +110,12 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
118110
scalar_type = parameters["scalar_type_c"]
119111
else:
120112
scalar_type = parameters["scalar_type"]
121-
122113
integral_type = integral_data.integral_type
123114
mesh = integral_data.domain
124115
arguments = form_data.preprocessed_form.arguments()
125116
kernel_name = "%s_%s_integral_%s" % (prefix, integral_type, integral_data.subdomain_id)
126117
# Handle negative subdomain_id
127118
kernel_name = kernel_name.replace("-", "_")
128-
129119
# Dict mapping domains to index in original_form.ufl_domains()
130120
domain_numbering = form_data.original_form.domain_numbering()
131121
domain_number = domain_numbering[integral_data.domain]
@@ -145,22 +135,17 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
145135
scalar_type,
146136
parameters["scalar_type"],
147137
diagonal=diagonal)
148-
149138
builder.set_coordinates(mesh)
150139
builder.set_cell_sizes(mesh)
151-
152140
builder.set_coefficients(integral_data, form_data)
153-
154141
ctx = builder.create_context()
155142
for integral in integral_data.integrals:
156143
params = parameters.copy()
157144
params.update(integral.metadata()) # integral metadata overrides
158-
159145
integrand = ufl.replace(integral.integrand(), form_data.function_replace_map)
160146
integrand_exprs = builder.compile_ufl(integrand, params, ctx)
161147
integral_exprs = builder.construct_integrals(integrand_exprs, params)
162148
builder.stash_integrals(integral_exprs, params, ctx)
163-
164149
return builder.construct_kernel(kernel_name, ctx)
165150

166151

@@ -342,107 +327,3 @@ def __call__(self, ps):
342327
assert set(gem_expr.free_indices) <= set(chain(ps.indices, *argument_multiindices))
343328

344329
return gem_expr
345-
346-
347-
def set_quad_rule(params, cell, integral_type, functions):
348-
# Check if the integral has a quad degree attached, otherwise use
349-
# the estimated polynomial degree attached by compute_form_data
350-
try:
351-
quadrature_degree = params["quadrature_degree"]
352-
except KeyError:
353-
quadrature_degree = params["estimated_polynomial_degree"]
354-
function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions]
355-
if all((asarray(quadrature_degree) > 10 * asarray(degree)).all()
356-
for degree in function_degrees):
357-
logger.warning("Estimated quadrature degree %s more "
358-
"than tenfold greater than any "
359-
"argument/coefficient degree (max %s)",
360-
quadrature_degree, max_degree(function_degrees))
361-
if params.get("quadrature_rule") == "default":
362-
del params["quadrature_rule"]
363-
try:
364-
quad_rule = params["quadrature_rule"]
365-
except KeyError:
366-
fiat_cell = as_fiat_cell(cell)
367-
integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
368-
integration_cell = fiat_cell.construct_subelement(integration_dim)
369-
quad_rule = make_quadrature(integration_cell, quadrature_degree)
370-
params["quadrature_rule"] = quad_rule
371-
372-
if not isinstance(quad_rule, AbstractQuadratureRule):
373-
raise ValueError("Expected to find a QuadratureRule object, not a %s" %
374-
type(quad_rule))
375-
376-
377-
def get_index_ordering(quadrature_indices, return_variables):
378-
split_argument_indices = tuple(chain(*[var.index_ordering()
379-
for var in return_variables]))
380-
return tuple(quadrature_indices) + split_argument_indices
381-
382-
383-
def lower_integral_type(fiat_cell, integral_type):
384-
"""Lower integral type into the dimension of the integration
385-
subentity and a list of entity numbers for that dimension.
386-
387-
:arg fiat_cell: FIAT reference cell
388-
:arg integral_type: integral type (string)
389-
"""
390-
vert_facet_types = ['exterior_facet_vert', 'interior_facet_vert']
391-
horiz_facet_types = ['exterior_facet_bottom', 'exterior_facet_top', 'interior_facet_horiz']
392-
393-
dim = fiat_cell.get_dimension()
394-
if integral_type == 'cell':
395-
integration_dim = dim
396-
elif integral_type in ['exterior_facet', 'interior_facet']:
397-
if isinstance(fiat_cell, TensorProductCell):
398-
raise ValueError("{} integral cannot be used with a TensorProductCell; need to distinguish between vertical and horizontal contributions.".format(integral_type))
399-
integration_dim = dim - 1
400-
elif integral_type == 'vertex':
401-
integration_dim = 0
402-
elif integral_type in vert_facet_types + horiz_facet_types:
403-
# Extrusion case
404-
if not isinstance(fiat_cell, TensorProductCell):
405-
raise ValueError("{} integral requires a TensorProductCell.".format(integral_type))
406-
basedim, extrdim = dim
407-
assert extrdim == 1
408-
409-
if integral_type in vert_facet_types:
410-
integration_dim = (basedim - 1, 1)
411-
elif integral_type in horiz_facet_types:
412-
integration_dim = (basedim, 0)
413-
else:
414-
raise NotImplementedError("integral type %s not supported" % integral_type)
415-
416-
if integral_type == 'exterior_facet_bottom':
417-
entity_ids = [0]
418-
elif integral_type == 'exterior_facet_top':
419-
entity_ids = [1]
420-
else:
421-
entity_ids = list(range(len(fiat_cell.get_topology()[integration_dim])))
422-
423-
return integration_dim, entity_ids
424-
425-
426-
def pick_mode(mode):
427-
"Return one of the specialized optimisation modules from a mode string."
428-
try:
429-
from firedrake_citations import Citations
430-
cites = {"vanilla": ("Homolya2017", ),
431-
"coffee": ("Luporini2016", "Homolya2017", ),
432-
"spectral": ("Luporini2016", "Homolya2017", "Homolya2017a"),
433-
"tensor": ("Kirby2006", "Homolya2017", )}
434-
for c in cites[mode]:
435-
Citations().register(c)
436-
except ImportError:
437-
pass
438-
if mode == "vanilla":
439-
import tsfc.vanilla as m
440-
elif mode == "coffee":
441-
import tsfc.coffee_mode as m
442-
elif mode == "spectral":
443-
import tsfc.spectral as m
444-
elif mode == "tensor":
445-
import tsfc.tensor as m
446-
else:
447-
raise ValueError("Unknown mode: {}".format(mode))
448-
return m

tsfc/kernel_interface/common.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,28 @@
22
import string
33
import operator
44
from functools import reduce
5+
from itertools import chain
56

67
import numpy
8+
from numpy import asarray
9+
10+
from ufl.utils.sequences import max_degree
711

812
import coffee.base as coffee
913

14+
from FIAT.reference_element import TensorProductCell
15+
16+
from finat.quadrature import AbstractQuadratureRule, make_quadrature
17+
1018
import gem
1119

1220
from gem.utils import cached_property
1321
import gem.impero_utils as impero_utils
1422

15-
from tsfc.driver import lower_integral_type, set_quad_rule, pick_mode, get_index_ordering
1623
from tsfc import fem, ufl_utils
1724
from tsfc.kernel_interface import KernelInterface
1825
from tsfc.finatinterface import as_fiat_cell
26+
from tsfc.logging import logger
1927

2028

2129
class KernelBuilderBase(KernelInterface):
@@ -239,6 +247,42 @@ def create_context(self):
239247
'mode_irs': collections.OrderedDict()}
240248

241249

250+
def set_quad_rule(params, cell, integral_type, functions):
251+
# Check if the integral has a quad degree attached, otherwise use
252+
# the estimated polynomial degree attached by compute_form_data
253+
try:
254+
quadrature_degree = params["quadrature_degree"]
255+
except KeyError:
256+
quadrature_degree = params["estimated_polynomial_degree"]
257+
function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions]
258+
if all((asarray(quadrature_degree) > 10 * asarray(degree)).all()
259+
for degree in function_degrees):
260+
logger.warning("Estimated quadrature degree %s more "
261+
"than tenfold greater than any "
262+
"argument/coefficient degree (max %s)",
263+
quadrature_degree, max_degree(function_degrees))
264+
if params.get("quadrature_rule") == "default":
265+
del params["quadrature_rule"]
266+
try:
267+
quad_rule = params["quadrature_rule"]
268+
except KeyError:
269+
fiat_cell = as_fiat_cell(cell)
270+
integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
271+
integration_cell = fiat_cell.construct_subelement(integration_dim)
272+
quad_rule = make_quadrature(integration_cell, quadrature_degree)
273+
params["quadrature_rule"] = quad_rule
274+
275+
if not isinstance(quad_rule, AbstractQuadratureRule):
276+
raise ValueError("Expected to find a QuadratureRule object, not a %s" %
277+
type(quad_rule))
278+
279+
280+
def get_index_ordering(quadrature_indices, return_variables):
281+
split_argument_indices = tuple(chain(*[var.index_ordering()
282+
for var in return_variables]))
283+
return tuple(quadrature_indices) + split_argument_indices
284+
285+
242286
def get_index_names(quadrature_indices, argument_multiindices, index_cache):
243287
index_names = []
244288

@@ -260,3 +304,71 @@ def name_multiindex(multiindex, name):
260304
for multiindex, name in zip(argument_multiindices, ['j', 'k']):
261305
name_multiindex(multiindex, name)
262306
return index_names
307+
308+
309+
def lower_integral_type(fiat_cell, integral_type):
310+
"""Lower integral type into the dimension of the integration
311+
subentity and a list of entity numbers for that dimension.
312+
313+
:arg fiat_cell: FIAT reference cell
314+
:arg integral_type: integral type (string)
315+
"""
316+
vert_facet_types = ['exterior_facet_vert', 'interior_facet_vert']
317+
horiz_facet_types = ['exterior_facet_bottom', 'exterior_facet_top', 'interior_facet_horiz']
318+
319+
dim = fiat_cell.get_dimension()
320+
if integral_type == 'cell':
321+
integration_dim = dim
322+
elif integral_type in ['exterior_facet', 'interior_facet']:
323+
if isinstance(fiat_cell, TensorProductCell):
324+
raise ValueError("{} integral cannot be used with a TensorProductCell; need to distinguish between vertical and horizontal contributions.".format(integral_type))
325+
integration_dim = dim - 1
326+
elif integral_type == 'vertex':
327+
integration_dim = 0
328+
elif integral_type in vert_facet_types + horiz_facet_types:
329+
# Extrusion case
330+
if not isinstance(fiat_cell, TensorProductCell):
331+
raise ValueError("{} integral requires a TensorProductCell.".format(integral_type))
332+
basedim, extrdim = dim
333+
assert extrdim == 1
334+
335+
if integral_type in vert_facet_types:
336+
integration_dim = (basedim - 1, 1)
337+
elif integral_type in horiz_facet_types:
338+
integration_dim = (basedim, 0)
339+
else:
340+
raise NotImplementedError("integral type %s not supported" % integral_type)
341+
342+
if integral_type == 'exterior_facet_bottom':
343+
entity_ids = [0]
344+
elif integral_type == 'exterior_facet_top':
345+
entity_ids = [1]
346+
else:
347+
entity_ids = list(range(len(fiat_cell.get_topology()[integration_dim])))
348+
349+
return integration_dim, entity_ids
350+
351+
352+
def pick_mode(mode):
353+
"Return one of the specialized optimisation modules from a mode string."
354+
try:
355+
from firedrake_citations import Citations
356+
cites = {"vanilla": ("Homolya2017", ),
357+
"coffee": ("Luporini2016", "Homolya2017", ),
358+
"spectral": ("Luporini2016", "Homolya2017", "Homolya2017a"),
359+
"tensor": ("Kirby2006", "Homolya2017", )}
360+
for c in cites[mode]:
361+
Citations().register(c)
362+
except ImportError:
363+
pass
364+
if mode == "vanilla":
365+
import tsfc.vanilla as m
366+
elif mode == "coffee":
367+
import tsfc.coffee_mode as m
368+
elif mode == "spectral":
369+
import tsfc.spectral as m
370+
elif mode == "tensor":
371+
import tsfc.tensor as m
372+
else:
373+
raise ValueError("Unknown mode: {}".format(mode))
374+
return m

0 commit comments

Comments
 (0)