@@ -121,13 +121,11 @@ void checkRelativeError( real64 const v1, real64 const v2, real64 const relTol,
121121 EXPECT_PRED_FORMAT4 ( checkRelativeErrorFormat, v1, v2, relTol, absTol );
122122}
123123
124- template < typename ROW_INDEX, typename COL_INDEX, typename VALUE >
125- void compareMatrixRow ( ROW_INDEX const rowNumber, VALUE const relTol, VALUE const absTol,
124+ template < typename COL_INDEX, typename VALUE >
125+ void compareMatrixRow ( VALUE const relTol, VALUE const absTol,
126126 localIndex const length1, COL_INDEX const * const indices1, VALUE const * const values1,
127127 localIndex const length2, COL_INDEX const * const indices2, VALUE const * const values2 )
128128{
129- SCOPED_TRACE ( " Row " + std::to_string ( rowNumber ));
130-
131129 EXPECT_EQ ( length1, length2 );
132130
133131 for ( localIndex j1 = 0 , j2 = 0 ; j1 < length1 && j2 < length2; ++j1, ++j2 )
@@ -152,17 +150,31 @@ void compareMatrixRow( ROW_INDEX const rowNumber, VALUE const relTol, VALUE cons
152150 }
153151}
154152
155- template < typename ROW_INDEX, typename COL_INDEX, typename VALUE >
156- void compareMatrixRow ( ROW_INDEX const rowNumber, VALUE const relTol, VALUE const absTol,
157- arraySlice1d< COL_INDEX const > indices1, arraySlice1d< VALUE const > values1,
158- arraySlice1d< COL_INDEX const > indices2, arraySlice1d< VALUE const > values2 )
153+ template < typename T, typename COL_INDEX >
154+ void compareLocalMatrices ( CRSMatrixView< T const , COL_INDEX const > const & matrix1,
155+ CRSMatrixView< T const , COL_INDEX const > const & matrix2,
156+ real64 const relTol = DEFAULT_REL_TOL,
157+ real64 const absTol = DEFAULT_ABS_TOL,
158+ globalIndex const rowOffset = 0 )
159159{
160- ASSERT_EQ ( indices1.size (), values1.size () );
161- ASSERT_EQ ( indices2.size (), values2.size () );
160+ ASSERT_EQ ( matrix1.numRows (), matrix2.numRows () );
161+ ASSERT_EQ ( matrix1.numColumns (), matrix2.numColumns () );
162+
163+ matrix1.move ( hostMemorySpace, false );
164+ matrix2.move ( hostMemorySpace, false );
162165
163- compareMatrixRow ( rowNumber, relTol, absTol,
164- indices1.size (), indices1.dataIfContiguous (), values1.dataIfContiguous (),
165- indices2.size (), indices2.dataIfContiguous (), values2.dataIfContiguous () );
166+ // check the accuracy across local rows
167+ for ( localIndex i = 0 ; i < matrix1.numRows (); ++i )
168+ {
169+ SCOPED_TRACE ( GEOS_FMT ( " Row {}" , i + rowOffset ) );
170+ compareMatrixRow ( relTol, absTol,
171+ matrix1.numNonZeros ( i ),
172+ matrix1.getColumns ( i ).dataIfContiguous (),
173+ matrix1.getEntries ( i ).dataIfContiguous (),
174+ matrix2.numNonZeros ( i ),
175+ matrix2.getColumns ( i ).dataIfContiguous (),
176+ matrix2.getEntries ( i ).dataIfContiguous () );
177+ }
166178}
167179
168180template < typename MATRIX >
@@ -177,45 +189,10 @@ void compareMatrices( MATRIX const & matrix1,
177189 ASSERT_EQ ( matrix1.numLocalRows (), matrix2.numLocalRows () );
178190 ASSERT_EQ ( matrix1.numLocalCols (), matrix2.numLocalCols () );
179191
180- array1d< globalIndex > indices1, indices2;
181- array1d< real64 > values1, values2;
182-
183- // check the accuracy across local rows
184- for ( globalIndex i = matrix1.ilower (); i < matrix1.iupper (); ++i )
185- {
186- indices1.resize ( matrix1.rowLength ( i ) );
187- values1.resize ( matrix1.rowLength ( i ) );
188- matrix1.getRowCopy ( i, indices1, values1 );
189-
190- indices2.resize ( matrix2.rowLength ( i ) );
191- values2.resize ( matrix2.rowLength ( i ) );
192- matrix2.getRowCopy ( i, indices2, values2 );
192+ CRSMatrix< real64, globalIndex > const mat1 = matrix1.extract ();
193+ CRSMatrix< real64, globalIndex > const mat2 = matrix2.extract ();
193194
194- compareMatrixRow ( i, relTol, absTol,
195- indices1.size (), indices1.data (), values1.data (),
196- indices2.size (), indices2.data (), values2.data () );
197- }
198- }
199-
200- template < typename T, typename COL_INDEX >
201- void compareLocalMatrices ( CRSMatrixView< T const , COL_INDEX const > const & matrix1,
202- CRSMatrixView< T const , COL_INDEX const > const & matrix2,
203- real64 const relTol = DEFAULT_REL_TOL,
204- real64 const absTol = DEFAULT_ABS_TOL )
205- {
206- ASSERT_EQ ( matrix1.numRows (), matrix2.numRows () );
207- ASSERT_EQ ( matrix1.numColumns (), matrix2.numColumns () );
208-
209- matrix1.move ( hostMemorySpace, false );
210- matrix2.move ( hostMemorySpace, false );
211-
212- // check the accuracy across local rows
213- for ( localIndex i = 0 ; i < matrix1.numRows (); ++i )
214- {
215- compareMatrixRow ( i, relTol, absTol,
216- matrix1.getColumns ( i ), matrix1.getEntries ( i ),
217- matrix2.getColumns ( i ), matrix2.getEntries ( i ) );
218- }
195+ compareLocalMatrices ( mat1.toViewConst (), mat2.toViewConst (), relTol, absTol, matrix1.ilower () );
219196}
220197
221198} // namespace testing
0 commit comments