Skip to content

Commit 16dc8a7

Browse files
authored
Merge pull request #177 from scientificcomputing/dokken/quadrature-interpolation-matrix
Support interpolation matrix into quadrature space
2 parents af21076 + df11038 commit 16dc8a7

2 files changed

Lines changed: 100 additions & 39 deletions

File tree

src/scifem/interpolation.py

Lines changed: 70 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -55,37 +55,57 @@ def prepare_interpolation_data(
5555
array_evaluated = compiled_expr.eval(mesh, np.arange(num_cells, dtype=np.int32))
5656
assert np.prod(Q.value_shape) == np.prod(expr.ufl_shape)
5757

58-
im = Q.element.basix_element.interpolation_matrix
59-
6058
# Get data as (num_cells*num_points,1, expr_shape, num_test_basis_functions*test_block_size)
6159
expr_size = int(np.prod(expr.ufl_shape))
6260
array_evaluated = array_evaluated.reshape(
6361
num_cells * q_points.shape[0], 1, expr_size, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
6462
)
65-
jacobian = dolfinx.fem.Expression(ufl.Jacobian(mesh), q_points)
66-
detJ = dolfinx.fem.Expression(ufl.JacobianDeterminant(mesh), q_points)
67-
K = dolfinx.fem.Expression(ufl.JacobianInverse(mesh), q_points)
68-
jacs = jacobian.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape(
69-
num_cells * num_points, mesh.geometry.dim, mesh.topology.dim
70-
)
71-
detJs = detJ.eval(mesh, np.arange(num_cells, dtype=np.int32)).flatten()
72-
Ks = K.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape(
73-
num_cells * num_points, mesh.geometry.dim, mesh.topology.dim
74-
)
7563

76-
Q_vs = Q.element.basix_element.value_size
64+
# Check if we are dealing with a quadrature element or not.
65+
# They do not have a complete DOLFINx API, which makes them tricky to use.
66+
try:
67+
basix_el = Q.element.basix_element
68+
Q_vs = basix_el.value_size
69+
pull_back = basix_el.pull_back
70+
im = basix_el.interpolation_matrix
71+
except RuntimeError:
72+
Q_vs = 1 # If we do not have a basix element, assume value size is 1
73+
assert isinstance(Q.ufl_element().pullback, ufl.pullback.IdentityPullback)
74+
pull_back = lambda x: None
75+
assert Q.element.interpolation_ident
76+
im = None
77+
7778
new_array = np.zeros(
7879
(num_cells * num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs),
7980
dtype=np.float64,
8081
)
81-
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
82-
for q in range(Q.dofmap.bs):
83-
new_array[:, q * Q_vs : (q + 1) * Q_vs, i] = Q.element.basix_element.pull_back(
84-
array_evaluated[:, :, q * Q_vs : (q + 1) * Q_vs, i], jacs, detJs, Ks
85-
).reshape(num_cells * num_points, Q_vs)
86-
new_array = new_array.reshape(
87-
num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
88-
)
82+
83+
# Check if pullback is identity, then we can skip this step
84+
if not isinstance(Q.ufl_element().pullback, ufl.pullback.IdentityPullback):
85+
jacobian = dolfinx.fem.Expression(ufl.Jacobian(mesh), q_points)
86+
detJ = dolfinx.fem.Expression(ufl.JacobianDeterminant(mesh), q_points)
87+
K = dolfinx.fem.Expression(ufl.JacobianInverse(mesh), q_points)
88+
jacs = jacobian.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape(
89+
num_cells * num_points, mesh.geometry.dim, mesh.topology.dim
90+
)
91+
detJs = detJ.eval(mesh, np.arange(num_cells, dtype=np.int32)).flatten()
92+
Ks = K.eval(mesh, np.arange(num_cells, dtype=np.int32)).reshape(
93+
num_cells * num_points, mesh.geometry.dim, mesh.topology.dim
94+
)
95+
96+
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
97+
for q in range(Q.dofmap.bs):
98+
new_array[:, q * Q_vs : (q + 1) * Q_vs, i] = pull_back(
99+
array_evaluated[:, :, q * Q_vs : (q + 1) * Q_vs, i], jacs, detJs, Ks
100+
).reshape(num_cells * num_points, Q_vs)
101+
new_array = new_array.reshape(
102+
num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
103+
)
104+
else:
105+
new_array = array_evaluated.reshape(
106+
num_cells, num_points, Q.dofmap.bs * Q_vs, V.dofmap.bs * V.dofmap.dof_layout.num_dofs
107+
)
108+
89109
interpolated_matrix = np.zeros(
90110
(
91111
num_cells,
@@ -94,19 +114,36 @@ def prepare_interpolation_data(
94114
),
95115
dtype=np.float64,
96116
)
117+
# Check if interpolation matrix of dual operator is identity, then we can use a vectorized
118+
# version of this step
119+
if Q.element.interpolation_ident:
120+
# Smart vectorized version with identity mapping
121+
if Q.dofmap.bs == 1:
122+
interpolated_matrix = new_array.transpose(0, 2, 1, 3).reshape(
123+
new_array.shape[0], new_array.shape[1] * new_array.shape[2], new_array.shape[3]
124+
)
125+
else:
126+
i_scalar = new_array.transpose(0, 2, 1, 3)
127+
interpolated_matrix = np.zeros(
128+
(new_array.shape[0], new_array.shape[1] * new_array.shape[2], new_array.shape[3])
129+
)
130+
for q in range(Q.dofmap.bs):
131+
interpolated_matrix[:, q :: Q.dofmap.bs, :] = i_scalar[:, q, :, :]
97132

98-
for c in range(num_cells):
99-
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
100-
tmp_array = np.zeros((int(num_points), Q.dofmap.bs * Q_vs), dtype=np.float64)
101-
for p in range(num_points):
102-
tmp_array[p] = new_array[c, p, :, i]
103-
if Q.dofmap.bs == 1:
104-
interpolated_matrix[c, :, i] = (im @ tmp_array.T.flatten()).flatten()
105-
else:
106-
for q in range(Q.dofmap.bs):
107-
interpolated_matrix[c, q :: Q.dofmap.bs, i] = (
108-
im @ tmp_array.T[q].flatten()
109-
).flatten()
133+
else:
134+
# Tedious non-identity version
135+
for c in range(num_cells):
136+
for i in range(V.dofmap.bs * V.dofmap.dof_layout.num_dofs):
137+
tmp_array = np.zeros((int(num_points), Q.dofmap.bs * Q_vs), dtype=np.float64)
138+
for p in range(num_points):
139+
tmp_array[p] = new_array[c, p, :, i]
140+
if Q.dofmap.bs == 1:
141+
interpolated_matrix[c, :, i] = (im @ tmp_array.T.flatten()).flatten()
142+
else:
143+
for q in range(Q.dofmap.bs):
144+
interpolated_matrix[c, q :: Q.dofmap.bs, i] = (
145+
im @ tmp_array.T[q].flatten()
146+
).flatten()
110147
# Apply dof transformation to each column (using Piopla maps)
111148
mesh.topology.create_entity_permutations()
112149
if Q.element.needs_dof_transformations:

tests/test_interpolation.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66
import ufl
77
import numpy as np
8+
import basix.ufl
89

910

1011
@pytest.mark.skipif(
@@ -20,8 +21,10 @@
2021
],
2122
)
2223
@pytest.mark.parametrize("use_petsc", [True, False])
23-
@pytest.mark.parametrize("degree", [1, 3, 5])
24-
def test_interpolation_matrix(use_petsc, cell_type, degree):
24+
@pytest.mark.parametrize("degree", [2, 4])
25+
@pytest.mark.parametrize("out_family", ["Lagrange", "DG", "Quadrature"])
26+
@pytest.mark.parametrize("value_shape", [(), (2,), (2, 3)])
27+
def test_interpolation_matrix(use_petsc, cell_type, degree, out_family, value_shape):
2528
if use_petsc:
2629
pytest.importorskip("petsc4py")
2730

@@ -33,14 +36,27 @@ def test_interpolation_matrix(use_petsc, cell_type, degree):
3336
else:
3437
raise ValueError("Unsupported cell type")
3538

36-
V = dolfinx.fem.functionspace(mesh, ("DG", degree))
37-
Q = dolfinx.fem.functionspace(mesh, ("Lagrange", degree))
39+
V = dolfinx.fem.functionspace(mesh, ("DG", degree, value_shape))
40+
if out_family == "Quadrature":
41+
el = basix.ufl.quadrature_element(mesh.basix_cell(), degree=degree, value_shape=value_shape)
42+
else:
43+
el = (out_family, degree, value_shape)
44+
Q = dolfinx.fem.functionspace(mesh, el)
45+
46+
def f(x):
47+
scalar_val = x[0] ** degree + x[1] if tdim == 2 else x[0] + x[1] + x[2] ** degree
48+
vs = int(np.prod(value_shape))
49+
f_rep = np.tile(scalar_val, vs).reshape(vs, -1)
50+
for i in range(vs):
51+
f_rep[i] += np.pi * (i + 1)
52+
return f_rep
3853

3954
u = dolfinx.fem.Function(V)
40-
u.interpolate(lambda x: x[0] ** degree + x[1] if tdim == 2 else x[0] + x[1] + x[2] ** degree)
55+
u.interpolate(f)
4156

4257
q = dolfinx.fem.Function(Q)
4358
expr = ufl.TrialFunction(V)
59+
4460
if use_petsc:
4561
A = scifem.interpolation.petsc_interpolation_matrix(expr, Q)
4662
A.mult(u.x.petsc_vec, q.x.petsc_vec)
@@ -59,7 +75,15 @@ def test_interpolation_matrix(use_petsc, cell_type, degree):
5975
q.x.scatter_forward()
6076

6177
q_ref = dolfinx.fem.Function(Q)
62-
q_ref.interpolate(u)
78+
if out_family == "Quadrature":
79+
try:
80+
ip = Q.element.interpolation_points()
81+
except TypeError:
82+
ip = Q.element.interpolation_points
83+
u_expr = dolfinx.fem.Expression(u, ip)
84+
q_ref.interpolate(u_expr)
85+
else:
86+
q_ref.interpolate(u)
6387

6488
np.testing.assert_allclose(q.x.array, q_ref.x.array, rtol=1e-12, atol=1e-13)
6589

0 commit comments

Comments
 (0)