@@ -122,13 +122,11 @@ inline void checkRelativeError( real64 const v1, real64 const v2, real64 const r
122122 EXPECT_PRED_FORMAT4 ( checkRelativeErrorFormat, v1, v2, relTol, absTol );
123123}
124124
125- template < typename ROW_INDEX, typename COL_INDEX, typename VALUE >
126- void compareMatrixRow ( ROW_INDEX const rowNumber, VALUE const relTol, VALUE const absTol,
125+ template < typename COL_INDEX, typename VALUE >
126+ void compareMatrixRow ( VALUE const relTol, VALUE const absTol,
127127 localIndex const length1, COL_INDEX const * const indices1, VALUE const * const values1,
128128 localIndex const length2, COL_INDEX const * const indices2, VALUE const * const values2 )
129129{
130- SCOPED_TRACE ( " Row " + std::to_string ( rowNumber ));
131-
132130 EXPECT_EQ ( length1, length2 );
133131
134132 for ( localIndex j1 = 0 , j2 = 0 ; j1 < length1 && j2 < length2; ++j1, ++j2 )
@@ -153,17 +151,31 @@ void compareMatrixRow( ROW_INDEX const rowNumber, VALUE const relTol, VALUE cons
153151 }
154152}
155153
156- template < typename ROW_INDEX, typename COL_INDEX, typename VALUE >
157- void compareMatrixRow ( ROW_INDEX const rowNumber, VALUE const relTol, VALUE const absTol,
158- arraySlice1d< COL_INDEX const > indices1, arraySlice1d< VALUE const > values1,
159- arraySlice1d< COL_INDEX const > indices2, arraySlice1d< VALUE const > values2 )
154+ template < typename T, typename COL_INDEX >
155+ void compareLocalMatrices ( CRSMatrixView< T const , COL_INDEX const > const & matrix1,
156+ CRSMatrixView< T const , COL_INDEX const > const & matrix2,
157+ real64 const relTol = DEFAULT_REL_TOL,
158+ real64 const absTol = DEFAULT_ABS_TOL,
159+ globalIndex const rowOffset = 0 )
160160{
161- ASSERT_EQ ( indices1.size (), values1.size () );
162- ASSERT_EQ ( indices2.size (), values2.size () );
161+ ASSERT_EQ ( matrix1.numRows (), matrix2.numRows () );
162+ ASSERT_EQ ( matrix1.numColumns (), matrix2.numColumns () );
163+
164+ matrix1.move ( hostMemorySpace, false );
165+ matrix2.move ( hostMemorySpace, false );
163166
164- compareMatrixRow ( rowNumber, relTol, absTol,
165- indices1.size (), indices1.dataIfContiguous (), values1.dataIfContiguous (),
166- indices2.size (), indices2.dataIfContiguous (), values2.dataIfContiguous () );
167+ // check the accuracy across local rows
168+ for ( localIndex i = 0 ; i < matrix1.numRows (); ++i )
169+ {
170+ SCOPED_TRACE ( GEOS_FMT ( " Row {}" , i + rowOffset ) );
171+ compareMatrixRow ( relTol, absTol,
172+ matrix1.numNonZeros ( i ),
173+ matrix1.getColumns ( i ).dataIfContiguous (),
174+ matrix1.getEntries ( i ).dataIfContiguous (),
175+ matrix2.numNonZeros ( i ),
176+ matrix2.getColumns ( i ).dataIfContiguous (),
177+ matrix2.getEntries ( i ).dataIfContiguous () );
178+ }
167179}
168180
169181template < typename MATRIX >
@@ -178,45 +190,10 @@ void compareMatrices( MATRIX const & matrix1,
178190 ASSERT_EQ ( matrix1.numLocalRows (), matrix2.numLocalRows () );
179191 ASSERT_EQ ( matrix1.numLocalCols (), matrix2.numLocalCols () );
180192
181- array1d< globalIndex > indices1, indices2;
182- array1d< real64 > values1, values2;
183-
184- // check the accuracy across local rows
185- for ( globalIndex i = matrix1.ilower (); i < matrix1.iupper (); ++i )
186- {
187- indices1.resize ( matrix1.rowLength ( i ) );
188- values1.resize ( matrix1.rowLength ( i ) );
189- matrix1.getRowCopy ( i, indices1, values1 );
190-
191- indices2.resize ( matrix2.rowLength ( i ) );
192- values2.resize ( matrix2.rowLength ( i ) );
193- matrix2.getRowCopy ( i, indices2, values2 );
193+ CRSMatrix< real64, globalIndex > const mat1 = matrix1.extract ();
194+ CRSMatrix< real64, globalIndex > const mat2 = matrix2.extract ();
194195
195- compareMatrixRow ( i, relTol, absTol,
196- indices1.size (), indices1.data (), values1.data (),
197- indices2.size (), indices2.data (), values2.data () );
198- }
199- }
200-
201- template < typename T, typename COL_INDEX >
202- void compareLocalMatrices ( CRSMatrixView< T const , COL_INDEX const > const & matrix1,
203- CRSMatrixView< T const , COL_INDEX const > const & matrix2,
204- real64 const relTol = DEFAULT_REL_TOL,
205- real64 const absTol = DEFAULT_ABS_TOL )
206- {
207- ASSERT_EQ ( matrix1.numRows (), matrix2.numRows () );
208- ASSERT_EQ ( matrix1.numColumns (), matrix2.numColumns () );
209-
210- matrix1.move ( hostMemorySpace, false );
211- matrix2.move ( hostMemorySpace, false );
212-
213- // check the accuracy across local rows
214- for ( localIndex i = 0 ; i < matrix1.numRows (); ++i )
215- {
216- compareMatrixRow ( i, relTol, absTol,
217- matrix1.getColumns ( i ), matrix1.getEntries ( i ),
218- matrix2.getColumns ( i ), matrix2.getEntries ( i ) );
219- }
196+ compareLocalMatrices ( mat1.toViewConst (), mat2.toViewConst (), relTol, absTol, matrix1.ilower () );
220197}
221198
222199} // namespace testing
0 commit comments