-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_MeshQualityEnhanced.py
More file actions
190 lines (157 loc) · 8.38 KB
/
test_MeshQualityEnhanced.py
File metadata and controls
190 lines (157 loc) · 8.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# SPDX-FileContributor: Martin Lemay, Paloma Martinez
# SPDX-License-Identifier: Apache 2.0
# ruff: noqa: E402 # disable Module level import not at top of file
import os
from matplotlib.figure import Figure
from dataclasses import dataclass
import numpy as np
import numpy.typing as npt
import pandas as pd
import pytest
from typing import Iterator, Optional
from geos.mesh.utils.genericHelpers import createMultiCellMesh
from geos.mesh.stats.meshQualityMetricHelpers import getAllCellTypesExtended
from geos.processing.pre_processing.MeshQualityEnhanced import MeshQualityEnhanced
from geos.mesh.model.QualityMetricSummary import QualityMetricSummary
from vtkmodules.vtkFiltersVerdict import vtkMeshQuality
from vtkmodules.vtkCommonDataModel import ( vtkUnstructuredGrid, vtkCellData, vtkFieldData, vtkCellTypeUtilities, VTK_TRIANGLE,
VTK_QUAD, VTK_TETRA, VTK_PYRAMID, VTK_WEDGE, VTK_HEXAHEDRON )
from vtkmodules.vtkIOXML import vtkXMLUnstructuredGridReader
# input data
# The *_all tuples below are parallel tables: entry k of every tuple describes test case k.
# They are zipped with strict=True in __generate_test_data, so they must all have the same length.

#: Name of each test mesh; also used (extension stripped) as the pytest id.
meshName_all: tuple[ str, ...] = (
    "polydata",
    "tetra_mesh",
)
#: VTK cell type whose quality metrics are configured on the filter for each mesh.
cellTypes_all: tuple[ int, ...] = ( VTK_TRIANGLE, VTK_TETRA )
#: vtkMeshQuality metric ids requested for each mesh (order matters: it indexes metricsSummary_all).
qualityMetrics_all: tuple[ tuple[ int, ...], ...] = (
    ( int( vtkMeshQuality.QualityMeasureTypes.ASPECT_RATIO ), int( vtkMeshQuality.QualityMeasureTypes.SCALED_JACOBIAN ),
      int( vtkMeshQuality.QualityMeasureTypes.MAX_ANGLE ) ),
    ( int( vtkMeshQuality.QualityMeasureTypes.SCALED_JACOBIAN ),
      int( vtkMeshQuality.QualityMeasureTypes.EQUIANGLE_SKEW ),
      int( vtkMeshQuality.QualityMeasureTypes.SQUISH_INDEX ) ),
)
# yapf: disable
#: Expected number of cells per cell type, in the order returned by getAllCellTypesExtended().
#: NOTE(review): presumably (triangle, quad, tetra, pyramid, wedge, hexahedron, all polygons,
#: all polyhedra) — confirm against getAllCellTypesExtended().
cellTypeCounts_all: tuple[ tuple[ int, ...], ...] = (
    ( 26324, 0, 0, 0, 0, 0, 26324, 0, ),
    ( 0, 0, 8, 0, 0, 0, 0, 8,)
)
#: Expected statistics, rounded to 2 decimals, for each requested metric (same order as qualityMetrics_all).
#: NOTE(review): each 5-tuple looks like (mean, std, min, max, count) — confirm against QualityMetricSummary.
metricsSummary_all: tuple[ tuple[ tuple[ float, ...], ...], ...] = (
    ( ( 1.07, 0.11, 1.0, 1.94, 26324.0 ), ( 0.91, 0.1, 0.53, 1.0, 26324.0 ), ( 64.59, 6.73, 60.00, 110.67, 26324.0 ) ),
    ( ( -0.28, 0.09, -0.49, -0.22, 8.0 ), ( 0.7, 0.1, 0.47, 0.79, 8.0 ), ( 0.8, 0.12, 0.58, 0.95, 8.0 ) ),
)
# yapf: enable
@dataclass( frozen=True )
class TestCase:
    """Test case.

    Immutable bundle of one input mesh and the results expected from the
    MeshQualityEnhanced filter on it. Fields are positional, in this order.
    """
    # Tell pytest not to collect this class as a test class despite its "Test" prefix.
    __test__ = False
    #: Input mesh given to the MeshQualityEnhanced filter.
    mesh: vtkUnstructuredGrid
    #: VTK cell type id selecting which Set*Metrics setter receives qualityMetrics.
    cellType: int
    #: vtkMeshQuality.QualityMeasureTypes ids to compute.
    qualityMetrics: tuple[ int, ...]
    #: Expected cell counts per cell type, in getAllCellTypesExtended() order.
    cellTypeCounts: tuple[ int, ...]
    #: Expected statistics rounded to 2 decimals, one tuple per entry of qualityMetrics.
    metricsSummary: tuple[ tuple[ float, ...], ...]
def __get_tetra_dataset() -> vtkUnstructuredGrid:
    """Load the tetrahedra dataset from ``data/tetra_mesh.csv`` and deform it.

    The CSV holds one point per row; every ``nbPtsCell`` consecutive rows form
    one tetrahedron. Points with x == 0.5 are moved to x = 0.2 and points with
    z == 0.5 are moved to z = 0.7 so the quality metrics are not trivial.

    Returns:
        vtkUnstructuredGrid: Mesh made only of VTK_TETRA cells.

    Raises:
        ValueError: If the number of points is not a multiple of ``nbPtsCell``.
    """
    dataRoot: str = os.path.join( os.path.dirname( os.path.abspath( __file__ ) ), "data" )
    filename: str = "tetra_mesh.csv"
    nbPtsCell: int = 4
    ptsCoord: npt.NDArray[ np.float64 ] = np.loadtxt( os.path.join( dataRoot, filename ), dtype=float, delimiter=',' )

    # Intentional deformation of the mesh. Exact equality is safe because 0.5 is
    # exactly representable in binary floating point.
    ptsCoord[ ptsCoord[ :, 0 ] == 0.5, 0 ] = 0.2
    ptsCoord[ ptsCoord[ :, 2 ] == 0.5, 2 ] = 0.7

    # A trailing partial block would yield more coordinate blocks than cell
    # types, so reject malformed input instead of building an inconsistent mesh.
    nbPts: int = ptsCoord.shape[ 0 ]
    nbCells, remainder = divmod( nbPts, nbPtsCell )
    if remainder != 0:
        raise ValueError( f"{filename} has {nbPts} points, which is not a multiple of {nbPtsCell}." )

    cellPtsCoords: list[ npt.NDArray[ np.float64 ] ] = [
        ptsCoord[ i:i + nbPtsCell ] for i in range( 0, nbPts, nbPtsCell )
    ]
    cellTypes: list[ int ] = nbCells * [ VTK_TETRA ]
    mesh: vtkUnstructuredGrid = createMultiCellMesh( cellTypes, cellPtsCoords )
    return mesh
def __get_dataset( meshName: str ) -> vtkUnstructuredGrid:
    """Get the dataset from an external VTK XML unstructured grid file.

    Args:
        meshName (str): The name of the mesh. Only ``"polydata"`` is supported;
            it maps to ``data/triangulatedSurface.vtu`` next to this file.

    Returns:
        vtkUnstructuredGrid: The dataset.

    Raises:
        ValueError: If ``meshName`` is not a supported mesh name.
        FileNotFoundError: If the mesh file does not exist.
    """
    vtkFilenames: dict[ str, str ] = { "polydata": "triangulatedSurface.vtu" }
    if meshName not in vtkFilenames:
        # Fail here rather than returning None, which would only break later inside the filter.
        raise ValueError( f"Unknown mesh name '{meshName}'; expected one of {sorted( vtkFilenames )}." )

    datapath: str = os.path.join( os.path.dirname( os.path.realpath( __file__ ) ), "data", vtkFilenames[ meshName ] )
    # The VTK reader does not raise on a missing file (it outputs an empty grid),
    # so check explicitly to get a clear error.
    if not os.path.isfile( datapath ):
        raise FileNotFoundError( f"Mesh file not found: {datapath}" )

    reader: vtkXMLUnstructuredGridReader = vtkXMLUnstructuredGridReader()
    reader.SetFileName( datapath )
    reader.Update()
    return reader.GetOutput()
def __generate_test_data() -> Iterator[ TestCase ]:
    """Build one TestCase per entry of the parallel ``*_all`` input tables.

    The "tetra_mesh" entry is generated from CSV data; every other entry is read
    from a VTK file.

    Yields:
        Iterator[ TestCase ]: Test cases, in the order of ``meshName_all``.
    """
    rows = zip( meshName_all,
                cellTypes_all,
                qualityMetrics_all,
                cellTypeCounts_all,
                metricsSummary_all,
                strict=True )
    for name, cellKind, metrics, expectedCounts, expectedSummary in rows:
        mesh: vtkUnstructuredGrid
        if name == "tetra_mesh":
            mesh = __get_tetra_dataset()
        else:
            mesh = __get_dataset( name )
        yield TestCase( mesh, cellKind, metrics, expectedCounts, expectedSummary )
# pytest ids: each mesh name with any file extension removed.
ids: list[ str ] = [ root for root, _extension in map( os.path.splitext, meshName_all ) ]
@pytest.mark.parametrize( "test_case", __generate_test_data(), ids=ids )
def test_MeshQualityEnhanced( test_case: TestCase ) -> None:
    """Test of MeshQualityEnhanced filter.

    Runs the filter on the test case mesh with the requested metrics, then checks
    the computed metrics per cell type, the number of output cell and field
    arrays, the cell type counts, the rounded metric statistics, and the number
    of axes of the summary figure.

    Args:
        test_case (TestCase): Test case
    """
    mesh = test_case.mesh
    meshQualityEnhancedFilter: MeshQualityEnhanced = MeshQualityEnhanced( mesh )
    if test_case.cellType == VTK_TRIANGLE:
        meshQualityEnhancedFilter.SetTriangleMetrics( test_case.qualityMetrics )
    elif test_case.cellType == VTK_QUAD:
        meshQualityEnhancedFilter.SetQuadMetrics( test_case.qualityMetrics )
    elif test_case.cellType == VTK_TETRA:
        meshQualityEnhancedFilter.SetTetraMetrics( test_case.qualityMetrics )
    elif test_case.cellType == VTK_PYRAMID:
        meshQualityEnhancedFilter.SetPyramidMetrics( test_case.qualityMetrics )
    elif test_case.cellType == VTK_WEDGE:
        meshQualityEnhancedFilter.SetWedgeMetrics( test_case.qualityMetrics )
    elif test_case.cellType == VTK_HEXAHEDRON:
        meshQualityEnhancedFilter.SetHexaMetrics( test_case.qualityMetrics )
    else:
        # Otherwise the filter would silently run with its default metrics and
        # the expectations below would be compared against the wrong data.
        pytest.fail( f"Unsupported cell type {test_case.cellType} in test case." )
    meshQualityEnhancedFilter.applyFilter()

    # test method getComputedMetricsFromCellType: every cell type present in the
    # mesh must report its computed metrics.
    for i, cellType in enumerate( getAllCellTypesExtended() ):
        metrics: Optional[ set[ int ] ] = meshQualityEnhancedFilter.getComputedMetricsFromCellType( cellType )
        if test_case.cellTypeCounts[ i ] > 0:
            assert metrics is not None, f"Metrics from {vtkCellTypeUtilities.GetClassNameFromTypeId(cellType)} cells is undefined."

    # test attributes: one new cell array is expected per requested metric.
    outputMesh: vtkUnstructuredGrid = meshQualityEnhancedFilter.getOutput()
    cellData: vtkCellData = outputMesh.GetCellData()
    assert cellData is not None, "Cell data is undefined."
    nbMetrics: int = len( test_case.qualityMetrics )
    nbCellArrayExp: int = mesh.GetCellData().GetNumberOfArrays() + nbMetrics
    assert cellData.GetNumberOfArrays() == nbCellArrayExp, f"Number of cell arrays is expected to be {nbCellArrayExp}."

    # test field data
    fieldData: vtkFieldData = outputMesh.GetFieldData()
    assert fieldData is not None, "Field data is undefined."
    hasCellType = np.array( test_case.cellTypeCounts ) > 0
    # Entries 0-1 are counted as polygon types and entries 2-5 as polyhedron
    # types; when any is present, one extra group is counted on top of them.
    nbPolygon: int = int( np.count_nonzero( hasCellType[ :2 ] ) )
    nbPolygon = 0 if nbPolygon == 0 else nbPolygon + 1
    nbPolyhedra: int = int( np.count_nonzero( hasCellType[ 2:6 ] ) )
    nbPolyhedra = 0 if nbPolyhedra == 0 else nbPolyhedra + 1
    # One array per cell type, plus 4 arrays per metric and per group.
    nbFieldArrayExp: int = ( mesh.GetFieldData().GetNumberOfArrays() + hasCellType.size +
                             4 * nbMetrics * ( nbPolygon + nbPolyhedra ) )
    assert fieldData.GetNumberOfArrays(
    ) == nbFieldArrayExp, f"Number of field data arrays is expected to be {nbFieldArrayExp}."

    stats: QualityMetricSummary = meshQualityEnhancedFilter.GetQualityMetricSummary()
    for i, cellType in enumerate( getAllCellTypesExtended() ):
        # test Counts
        count = stats.getCellTypeCountsOfCellType( cellType )
        expectedCount: int = test_case.cellTypeCounts[ i ]
        assert count == expectedCount, f"Number of {vtkCellTypeUtilities.GetClassNameFromTypeId(cellType)} cells is expected to be {expectedCount}"
        if count == 0:
            continue
        # test metric summary (values are compared after rounding to 2 decimals)
        for j, metricIndex in enumerate( test_case.qualityMetrics ):
            subStats: pd.Series = stats.getStatsFromMetricAndCellType( metricIndex, cellType )
            assert np.round( subStats, 2 ).tolist() == list(
                test_case.metricsSummary[ j ] ), f"Stats at metric index {j} are wrong."

    fig: Figure = stats.plotSummaryFigure()
    assert len( fig.get_axes() ) == 6, "Number of Axes is expected to be 6."