diff --git a/include/LinearAlgebra/Solvers/tridiagonal_solver.h b/include/LinearAlgebra/Solvers/tridiagonal_solver.h new file mode 100644 index 00000000..091cf89b --- /dev/null +++ b/include/LinearAlgebra/Solvers/tridiagonal_solver.h @@ -0,0 +1,291 @@ +#pragma once + +#include + +#include "../vector.h" +#include "../vector_operations.h" + +template +class BatchedTridiagonalSolver +{ +public: + BatchedTridiagonalSolver(int matrix_dimension, int batch_count, bool is_cyclic = true) + : matrix_dimension_(matrix_dimension) + , batch_count_(batch_count) + , main_diagonal_("BatchedTridiagonalSolver::main_diagonal", matrix_dimension * batch_count) + , sub_diagonal_("BatchedTridiagonalSolver::sub_diagonal", matrix_dimension * batch_count) + , buffer_("BatchedTridiagonalSolver::buffer", is_cyclic ? matrix_dimension * batch_count : 0) + , gamma_("BatchedTridiagonalSolver::gamma", is_cyclic ? batch_count : 0) + , is_cyclic_(is_cyclic) + , is_factorized_(false) + { + assign(main_diagonal_, T(0)); + assign(sub_diagonal_, T(0)); + } + + /* ---------------------------- */ + /* Accessors for matrix entries */ + /* ---------------------------- */ + + KOKKOS_INLINE_FUNCTION + const T& main_diagonal(const int batch_idx, const int index) const + { + return main_diagonal_(batch_idx * matrix_dimension_ + index); + } + KOKKOS_INLINE_FUNCTION + T& main_diagonal(const int batch_idx, const int index) + { + return main_diagonal_(batch_idx * matrix_dimension_ + index); + } + + KOKKOS_INLINE_FUNCTION + const T& sub_diagonal(const int batch_idx, const int index) const + { + return sub_diagonal_(batch_idx * matrix_dimension_ + index); + } + KOKKOS_INLINE_FUNCTION + T& sub_diagonal(const int batch_idx, const int index) + { + return sub_diagonal_(batch_idx * matrix_dimension_ + index); + } + + KOKKOS_INLINE_FUNCTION + const T& cyclic_corner(const int batch_idx) const + { + return sub_diagonal_(batch_idx * matrix_dimension_ + (matrix_dimension_ - 1)); + } + + KOKKOS_INLINE_FUNCTION + T& cyclic_corner(const int batch_idx) + { + return sub_diagonal_(batch_idx * matrix_dimension_ + (matrix_dimension_ - 1)); + } + + /* ---------------------------------------------- */ + /* Setup: Cholesky Decomposition: A = L * D * L^T */ + /* ---------------------------------------------- */ + // This step factorizes the tridiagonal matrix into lower triangular (L) and diagonal (D) matrices. + // For cyclic systems, it also applies the Shermann-Morrison adjustment to account for the cyclic connection. + + void setup() + { + // Create local copies for lambda capture + int matrix_dimension = matrix_dimension_; + Vector main_diagonal = main_diagonal_; + Vector sub_diagonal = sub_diagonal_; + Vector gamma = gamma_; + + if (!is_cyclic_) { + Kokkos::parallel_for( + "SetupNonCyclic", batch_count_, KOKKOS_LAMBDA(const int batch_idx) { + // ----------------------------------- // + // Obtain offset for the current batch // + int offset = batch_idx * matrix_dimension; + + // ---------------------- // + // Cholesky Decomposition // + for (int i = 1; i < matrix_dimension; i++) { + sub_diagonal(offset + i - 1) /= main_diagonal(offset + i - 1); + const T factor = sub_diagonal(offset + i - 1); + main_diagonal(offset + i) -= factor * factor * main_diagonal(offset + i - 1); + } + }); + } + else { + Kokkos::parallel_for( + "SetupCyclic", batch_count_, KOKKOS_LAMBDA(const int batch_idx) { + // ----------------------------------- // + // Obtain offset for the current batch // + int offset = batch_idx * matrix_dimension; + + // ------------------------------------------------- // + // Shermann-Morrison Adjustment // + // - Modify the first and last main diagonal element // + // - Compute and store gamma for later use // + // ------------------------------------------------- // + T cyclic_corner_element = sub_diagonal(offset + matrix_dimension - 1); + gamma(batch_idx) = -main_diagonal(offset + 0); + main_diagonal(offset + 0) -= gamma(batch_idx); + main_diagonal(offset + matrix_dimension - 1) -= + cyclic_corner_element * cyclic_corner_element / gamma(batch_idx); + + // ---------------------- // + // Cholesky Decomposition // + for (int i = 1; i < matrix_dimension; i++) { + sub_diagonal(offset + i - 1) /= main_diagonal(offset + i - 1); + const T factor = sub_diagonal(offset + i - 1); + main_diagonal(offset + i) -= factor * factor * main_diagonal(offset + i - 1); + } + }); + } + Kokkos::fence(); + is_factorized_ = true; + } + + /* ---------------------------------------- */ + /* Solve: Forward and Backward Substitution */ + /* ---------------------------------------- */ + // This step solves the system Ax = b using the factorized form of A. + // For cyclic systems, it also performs the Shermann-Morrison reconstruction to obtain the final solution. + + void solve(Vector rhs, int batch_offset = 0, int batch_stride = 1) + { + if (!is_factorized_) { + throw std::runtime_error("Error: Matrix must be factorized before solving."); + } + + // Compute the effective number of batches to solve + int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride; + + // Create local copies for lambda capture + int matrix_dimension = matrix_dimension_; + Vector main_diagonal = main_diagonal_; + Vector sub_diagonal = sub_diagonal_; + Vector buffer = buffer_; + Vector gamma = gamma_; + + if (!is_cyclic_) { + Kokkos::parallel_for( + "SolveNonCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { + // ----------------------------------- // + // Obtain offset for the current batch // + int batch_idx = batch_stride * k + batch_offset; + int offset = batch_idx * matrix_dimension; + + // -------------------- // + // Forward Substitution // + for (int i = 1; i < matrix_dimension; i++) { + rhs(offset + i) -= sub_diagonal(offset + i - 1) * rhs(offset + i - 1); + } + + // ---------------- // + // Diagonal Scaling // + for (int i = 0; i < matrix_dimension; i++) { + rhs(offset + i) /= main_diagonal(offset + i); + } + + // --------------------- // + // Backward Substitution // + for (int i = matrix_dimension - 2; i >= 0; i--) { + rhs(offset + i) -= sub_diagonal(offset + i) * rhs(offset + i + 1); + } + }); + } + else { + Kokkos::parallel_for( + "SolveCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { + // ----------------------------------- // + // Obtain offset for the current batch // + int batch_idx = batch_stride * k + batch_offset; + int offset = batch_idx * matrix_dimension; + + // -------------------- // + // Forward Substitution // + T cyclic_corner_element = sub_diagonal(offset + matrix_dimension - 1); + buffer(offset + 0) = gamma(batch_idx); + for (int i = 1; i < matrix_dimension; i++) { + rhs(offset + i) -= sub_diagonal(offset + i - 1) * rhs(offset + i - 1); + if (i < matrix_dimension - 1) + buffer(offset + i) = 0.0 - sub_diagonal(offset + i - 1) * buffer(offset + i - 1); + else + buffer(offset + i) = + cyclic_corner_element - sub_diagonal(offset + i - 1) * buffer(offset + i - 1); + } + + // ---------------- // + // Diagonal Scaling // + for (int i = 0; i < matrix_dimension; i++) { + rhs(offset + i) /= main_diagonal(offset + i); + buffer(offset + i) /= main_diagonal(offset + i); + } + + // --------------------- // + // Backward Substitution // + for (int i = matrix_dimension - 2; i >= 0; i--) { + rhs(offset + i) -= sub_diagonal(offset + i) * rhs(offset + i + 1); + buffer(offset + i) -= sub_diagonal(offset + i) * buffer(offset + i + 1); + } + + // ------------------------------- // + // Shermann-Morrison Reonstruction // + const T dot_product_x_v = + rhs(offset + 0) + cyclic_corner_element / gamma(batch_idx) * rhs(offset + matrix_dimension - 1); + const T dot_product_u_v = buffer(offset + 0) + cyclic_corner_element / gamma(batch_idx) * + buffer(offset + matrix_dimension - 1); + const T factor = dot_product_x_v / (1.0 + dot_product_u_v); + + for (int i = 0; i < matrix_dimension; i++) { + rhs(offset + i) -= factor * buffer(offset + i); + } + }); + } + Kokkos::fence(); + } + + /* ---------------------------- */ + /* Solve: Diagonal Scaling Only */ + /* ---------------------------- */ + // This step performs only the diagonal scaling part of the solve process. + // It is useful when the matrix has a non-zero diagonal but zero off-diagonal entries. + // Note that .setup() modifies main_diagonal(0) in the cyclic case. + + void solve_diagonal(Vector rhs, int batch_offset = 0, int batch_stride = 1) + { + if (!is_factorized_) { + throw std::runtime_error("Error: Matrix must be factorized before solving."); + } + + // Compute the effective number of batches to solve + int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride; + + // Create local copies for lambda capture + int matrix_dimension = matrix_dimension_; + Vector main_diagonal = main_diagonal_; + Vector gamma = gamma_; + + if (!is_cyclic_) { + Kokkos::parallel_for( + "SolveDiagonalNonCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { + // ----------------------------------- // + // Obtain offset for the current batch // + int batch_idx = batch_stride * k + batch_offset; + int offset = batch_idx * matrix_dimension; + + // ---------------- // + // Diagonal Scaling // + for (int i = 0; i < matrix_dimension; i++) { + rhs(offset + i) /= main_diagonal(offset + i); + } + }); + } + else { + Kokkos::parallel_for( + "SolveDiagonalCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { + // ----------------------------------- // + // Obtain offset for the current batch // + int batch_idx = batch_stride * k + batch_offset; + int offset = batch_idx * matrix_dimension; + + // ---------------- // + // Diagonal Scaling // + rhs(offset + 0) /= main_diagonal(offset + 0) + gamma(batch_idx); + for (int i = 1; i < matrix_dimension; i++) { + rhs(offset + i) /= main_diagonal(offset + i); + } + }); + } + Kokkos::fence(); + } + +private: + int matrix_dimension_; + int batch_count_; + + Vector main_diagonal_; + Vector sub_diagonal_; + Vector buffer_; + Vector gamma_; + + bool is_cyclic_; + bool is_factorized_; +}; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index fad28d0a..ce9c360e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -14,6 +14,7 @@ add_executable(gmgpolar_tests LinearAlgebra/csr_solver.cpp LinearAlgebra/tridiagonal_solver.cpp LinearAlgebra/cyclic_tridiagonal_solver.cpp + LinearAlgebra/Solvers/tridiagonal_solver.cpp PolarGrid/polargrid.cpp Interpolation/prolongation.cpp Interpolation/restriction.cpp diff --git a/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp b/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp new file mode 100644 index 00000000..a708983d --- /dev/null +++ b/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp @@ -0,0 +1,306 @@ +#include +#include +#include +#include + +#include "../../../include/LinearAlgebra/Solvers/tridiagonal_solver.h" +#include "../../../include/LinearAlgebra/vector.h" + +// clang-format off +TEST(BatchedTridiagonalSolvers, non_cyclic_tridiagonal_n_4) +{ + int batch_count = 4; + int matrix_dimension = 4; + bool is_cyclic = false; + + BatchedTridiagonalSolver solver(matrix_dimension, batch_count, is_cyclic); + + // System 1: {{2, 1, 0,0},{1,4,2,0},{0,2,6,3},{0,0,3,8}} * {{a},{b},{c},{d}} = {{1},{2},{3},{4}} + // a = 70/209, b = 69/209, c = 36/209, d = 91/209 + // System 2: {{3,1,0,0},{1,5,2,0},{0,2,7,3},{0,0,3,9}} * {{a},{b},{c},{d}} = {{2},{3},{4},{5}} + // a = 29/54, b = 7/18, c = 7/27, d = 38/81 + // System 3: {{4,1,0,0},{1,6,2,0},{0,2,8,3},{0,0,3,10}} * {{a},{b},{c},{d}} = {{3},{4},{5},{6}} + // a = 938/1473, b = 667/1473, c = 476/1473, d = 247/491 + // System 4: {{5,1,0,0},{1,7,2,0},{0,2,9,3},{0,0,3,11}} * {{a},{b},{c},{d}} = {{4},{5},{6},{7}} + // a = 248/355, b = 36/71, c = 267/710, d = 379/710 + + solver.main_diagonal(0,0) = 2.0; solver.sub_diagonal(0,0) = 1.0; + solver.main_diagonal(0,1) = 4.0; solver.sub_diagonal(0,1) = 2.0; + solver.main_diagonal(0,2) = 6.0; solver.sub_diagonal(0,2) = 3.0; + solver.main_diagonal(0,3) = 8.0; + + solver.main_diagonal(1,0) = 3.0; solver.sub_diagonal(1,0) = 1.0; + solver.main_diagonal(1,1) = 5.0; solver.sub_diagonal(1,1) = 2.0; + solver.main_diagonal(1,2) = 7.0; solver.sub_diagonal(1,2) = 3.0; + solver.main_diagonal(1,3) = 9.0; + + solver.main_diagonal(2,0) = 4.0; solver.sub_diagonal(2,0) = 1.0; + solver.main_diagonal(2,1) = 6.0; solver.sub_diagonal(2,1) = 2.0; + solver.main_diagonal(2,2) = 8.0; solver.sub_diagonal(2,2) = 3.0; + solver.main_diagonal(2,3) = 10.0; + + solver.main_diagonal(3,0) = 5.0; solver.sub_diagonal(3,0) = 1.0; + solver.main_diagonal(3,1) = 7.0; solver.sub_diagonal(3,1) = 2.0; + solver.main_diagonal(3,2) = 9.0; solver.sub_diagonal(3,2) = 3.0; + solver.main_diagonal(3,3) = 11.0; + + Vector rhs("rhs", matrix_dimension * batch_count); + + // Initialize RHS for each system + rhs(0) = 1.0; rhs(1) = 2.0; rhs(2) = 3.0; rhs(3) = 4.0; + rhs(4) = 2.0; rhs(5) = 3.0; rhs(6) = 4.0; rhs(7) = 5.0; + rhs(8) = 3.0; rhs(9) = 4.0; rhs(10) = 5.0; rhs(11) = 6.0; + rhs(12) = 4.0; rhs(13) = 5.0; rhs(14) = 6.0; rhs(15) = 7.0; + + solver.setup(); + + int offset, stride; + // Solve each even system separately + offset = 0; stride = 2; + solver.solve(rhs, offset, stride); + // Solve each odd system separately + offset = 1; stride = 2; + solver.solve(rhs, offset, stride); + + // Verify solutions + double tol = 1e-12; + + EXPECT_NEAR(rhs(0), 70.0/209.0, tol); + EXPECT_NEAR(rhs(1), 69.0/209.0, tol); + EXPECT_NEAR(rhs(2), 36.0/209.0, tol); + EXPECT_NEAR(rhs(3), 91.0/209.0, tol); + + EXPECT_NEAR(rhs(4), 29.0/54.0, tol); + EXPECT_NEAR(rhs(5), 7.0/18.0, tol); + EXPECT_NEAR(rhs(6), 7.0/27.0, tol); + EXPECT_NEAR(rhs(7), 38.0/81.0, tol); + + EXPECT_NEAR(rhs(8), 938.0/1473.0, tol); + EXPECT_NEAR(rhs(9), 667.0/1473.0, tol); + EXPECT_NEAR(rhs(10), 476.0/1473.0, tol); + EXPECT_NEAR(rhs(11), 247.0/491.0, tol); + + EXPECT_NEAR(rhs(12), 248.0/355.0, tol); + EXPECT_NEAR(rhs(13), 36.0/71.0, tol); + EXPECT_NEAR(rhs(14), 267.0/710.0, tol); + EXPECT_NEAR(rhs(15), 379.0/710.0, tol); +} + +TEST(BatchedTridiagonalSolvers, cyclic_tridiagonal_n_4) +{ + int batch_count = 4; + int matrix_dimension = 4; + bool is_cyclic = true; + + BatchedTridiagonalSolver solver(matrix_dimension, batch_count, is_cyclic); + + // System 1: {{2, 1, 0,-1},{1,4,2,0},{0,2,6,3},{-1,0,3,8}} * {{a},{b},{c},{d}} = {{1},{2},{3},{4}} + // a = 42/67, b = 18/67, c = 10/67, d = 35/67 + // System 2: {{3,1,0,-2},{1,5,2,0},{0,2,7,3},{-2,0,3,9}} * {{a},{b},{c},{d}} = {{2},{3},{4},{5}} + // a = 287/274, b = 89/274, c = 45/274, d = 201/274 + // System 3: {{4,1,0,-3},{1,6,2,0},{0,2,8,3},{-3,0,3,10}} * {{a},{b},{c},{d}} = {{3},{4},{5},{6}} + // a = 1532/1113, b = 8/21, c = 188/1113, d = 51/53 + // System 4: {{5,1,0,-4},{1,7,2,0},{0,2,9,3},{-4,0,3,11}} * {{a},{b},{c},{d}} = {{4},{5},{6},{7}} + // a = 271/162, b = 23/54, c = 14/81, d = 97/81 + + solver.main_diagonal(0,0) = 2.0; solver.sub_diagonal(0,0) = 1.0; + solver.main_diagonal(0,1) = 4.0; solver.sub_diagonal(0,1) = 2.0; + solver.main_diagonal(0,2) = 6.0; solver.sub_diagonal(0,2) = 3.0; + solver.main_diagonal(0,3) = 8.0; solver.cyclic_corner(0) = -1.0; + + solver.main_diagonal(1,0) = 3.0; solver.sub_diagonal(1,0) = 1.0; + solver.main_diagonal(1,1) = 5.0; solver.sub_diagonal(1,1) = 2.0; + solver.main_diagonal(1,2) = 7.0; solver.sub_diagonal(1,2) = 3.0; + solver.main_diagonal(1,3) = 9.0; solver.cyclic_corner(1) = -2.0; + + solver.main_diagonal(2,0) = 4.0; solver.sub_diagonal(2,0) = 1.0; + solver.main_diagonal(2,1) = 6.0; solver.sub_diagonal(2,1) = 2.0; + solver.main_diagonal(2,2) = 8.0; solver.sub_diagonal(2,2) = 3.0; + solver.main_diagonal(2,3) = 10.0; solver.cyclic_corner(2) = -3.0; + + solver.main_diagonal(3,0) = 5.0; solver.sub_diagonal(3,0) = 1.0; + solver.main_diagonal(3,1) = 7.0; solver.sub_diagonal(3,1) = 2.0; + solver.main_diagonal(3,2) = 9.0; solver.sub_diagonal(3,2) = 3.0; + solver.main_diagonal(3,3) = 11.0; solver.cyclic_corner(3) = -4.0; + + Vector rhs("rhs", matrix_dimension * batch_count); + + // Initialize RHS for each system + rhs(0) = 1.0; rhs(1) = 2.0; rhs(2) = 3.0; rhs(3) = 4.0; + rhs(4) = 2.0; rhs(5) = 3.0; rhs(6) = 4.0; rhs(7) = 5.0; + rhs(8) = 3.0; rhs(9) = 4.0; rhs(10) = 5.0; rhs(11) = 6.0; + rhs(12) = 4.0; rhs(13) = 5.0; rhs(14) = 6.0; rhs(15) = 7.0; + + solver.setup(); + + int offset, stride; + // Solve each even system separately + offset = 0; stride = 2; + solver.solve(rhs, offset, stride); + // Solve each odd system separately + offset = 1; stride = 2; + solver.solve(rhs, offset, stride); + + // Verify solutions + double tol = 1e-12; + + EXPECT_NEAR(rhs(0), 42.0/67.0, tol); + EXPECT_NEAR(rhs(1), 18.0/67.0, tol); + EXPECT_NEAR(rhs(2), 10.0/67.0, tol); + EXPECT_NEAR(rhs(3), 35.0/67.0, tol); + + EXPECT_NEAR(rhs(4), 287.0/274.0, tol); + EXPECT_NEAR(rhs(5), 89.0/274.0, tol); + EXPECT_NEAR(rhs(6), 45.0/274.0, tol); + EXPECT_NEAR(rhs(7), 201.0/274.0, tol); + + EXPECT_NEAR(rhs(8), 1532.0/1113.0, tol); + EXPECT_NEAR(rhs(9), 8.0/21.0, tol); + EXPECT_NEAR(rhs(10), 188.0/1113.0, tol); + EXPECT_NEAR(rhs(11), 51.0/53.0, tol); + + EXPECT_NEAR(rhs(12), 271.0/162.0, tol); + EXPECT_NEAR(rhs(13), 23.0/54.0, tol); + EXPECT_NEAR(rhs(14), 14.0/81.0, tol); + EXPECT_NEAR(rhs(15), 97.0/81.0, tol); +} + +TEST(BatchedTridiagonalSolvers, non_cyclic_diagonal_n_4) +{ + int batch_count = 4; + int matrix_dimension = 4; + bool is_cyclic = false; + + BatchedTridiagonalSolver solver(matrix_dimension, batch_count, is_cyclic); + + solver.main_diagonal(0,0) = 2.0; + solver.main_diagonal(0,1) = 4.0; + solver.main_diagonal(0,2) = 6.0; + solver.main_diagonal(0,3) = 8.0; + + solver.main_diagonal(1,0) = 3.0; + solver.main_diagonal(1,1) = 5.0; + solver.main_diagonal(1,2) = 7.0; + solver.main_diagonal(1,3) = 9.0; + + solver.main_diagonal(2,0) = 4.0; + solver.main_diagonal(2,1) = 6.0; + solver.main_diagonal(2,2) = 8.0; + solver.main_diagonal(2,3) = 10.0; + + solver.main_diagonal(3,0) = 5.0; + solver.main_diagonal(3,1) = 7.0; + solver.main_diagonal(3,2) = 9.0; + solver.main_diagonal(3,3) = 11.0; + + Vector rhs("rhs", matrix_dimension * batch_count); + + // Initialize RHS for each system + rhs(0) = 1.0; rhs(1) = 2.0; rhs(2) = 3.0; rhs(3) = 4.0; + rhs(4) = 2.0; rhs(5) = 3.0; rhs(6) = 4.0; rhs(7) = 5.0; + rhs(8) = 3.0; rhs(9) = 4.0; rhs(10) = 5.0; rhs(11) = 6.0; + rhs(12) = 4.0; rhs(13) = 5.0; rhs(14) = 6.0; rhs(15) = 7.0; + + solver.setup(); + + int offset, stride; + // Solve each even system separately + offset = 0; stride = 2; + solver.solve_diagonal(rhs, offset, stride); + // Solve each odd system separately + offset = 1; stride = 2; + solver.solve_diagonal(rhs, offset, stride); + + // Verify solutions + double tol = 1e-12; + + EXPECT_NEAR(rhs(0), 1.0/2.0, tol); + EXPECT_NEAR(rhs(1), 2.0/4.0, tol); + EXPECT_NEAR(rhs(2), 3.0/6.0, tol); + EXPECT_NEAR(rhs(3), 4.0/8.0, tol); + + EXPECT_NEAR(rhs(4), 2.0/3.0, tol); + EXPECT_NEAR(rhs(5), 3.0/5.0, tol); + EXPECT_NEAR(rhs(6), 4.0/7.0, tol); + EXPECT_NEAR(rhs(7), 5.0/9.0, tol); + + EXPECT_NEAR(rhs(8), 3.0/4.0, tol); + EXPECT_NEAR(rhs(9), 4.0/6.0, tol); + EXPECT_NEAR(rhs(10), 5.0/8.0, tol); + EXPECT_NEAR(rhs(11), 6.0/10.0, tol); + + EXPECT_NEAR(rhs(12), 4.0/5.0, tol); + EXPECT_NEAR(rhs(13), 5.0/7.0, tol); + EXPECT_NEAR(rhs(14), 6.0/9.0, tol); + EXPECT_NEAR(rhs(15), 7.0/11.0, tol); +} + +TEST(BatchedTridiagonalSolvers, cyclic_diagonal_n_4) +{ + int batch_count = 4; + int matrix_dimension = 4; + bool is_cyclic = true; + + BatchedTridiagonalSolver solver(matrix_dimension, batch_count, is_cyclic); + + solver.main_diagonal(0,0) = 2.0; + solver.main_diagonal(0,1) = 4.0; + solver.main_diagonal(0,2) = 6.0; + solver.main_diagonal(0,3) = 8.0; + + solver.main_diagonal(1,0) = 3.0; + solver.main_diagonal(1,1) = 5.0; + solver.main_diagonal(1,2) = 7.0; + solver.main_diagonal(1,3) = 9.0; + + solver.main_diagonal(2,0) = 4.0; + solver.main_diagonal(2,1) = 6.0; + solver.main_diagonal(2,2) = 8.0; + solver.main_diagonal(2,3) = 10.0; + + solver.main_diagonal(3,0) = 5.0; + solver.main_diagonal(3,1) = 7.0; + solver.main_diagonal(3,2) = 9.0; + solver.main_diagonal(3,3) = 11.0; + + Vector rhs("rhs", matrix_dimension * batch_count); + + // Initialize RHS for each system + rhs(0) = 1.0; rhs(1) = 2.0; rhs(2) = 3.0; rhs(3) = 4.0; + rhs(4) = 2.0; rhs(5) = 3.0; rhs(6) = 4.0; rhs(7) = 5.0; + rhs(8) = 3.0; rhs(9) = 4.0; rhs(10) = 5.0; rhs(11) = 6.0; + rhs(12) = 4.0; rhs(13) = 5.0; rhs(14) = 6.0; rhs(15) = 7.0; + + solver.setup(); + + int offset, stride; + // Solve each even system separately + offset = 0; stride = 2; + solver.solve_diagonal(rhs, offset, stride); + // Solve each odd system separately + offset = 1; stride = 2; + solver.solve_diagonal(rhs, offset, stride); + + // Verify solutions + double tol = 1e-12; + + EXPECT_NEAR(rhs(0), 1.0/2.0, tol); + EXPECT_NEAR(rhs(1), 2.0/4.0, tol); + EXPECT_NEAR(rhs(2), 3.0/6.0, tol); + EXPECT_NEAR(rhs(3), 4.0/8.0, tol); + + EXPECT_NEAR(rhs(4), 2.0/3.0, tol); + EXPECT_NEAR(rhs(5), 3.0/5.0, tol); + EXPECT_NEAR(rhs(6), 4.0/7.0, tol); + EXPECT_NEAR(rhs(7), 5.0/9.0, tol); + + EXPECT_NEAR(rhs(8), 3.0/4.0, tol); + EXPECT_NEAR(rhs(9), 4.0/6.0, tol); + EXPECT_NEAR(rhs(10), 5.0/8.0, tol); + EXPECT_NEAR(rhs(11), 6.0/10.0, tol); + + EXPECT_NEAR(rhs(12), 4.0/5.0, tol); + EXPECT_NEAR(rhs(13), 5.0/7.0, tol); + EXPECT_NEAR(rhs(14), 6.0/9.0, tol); + EXPECT_NEAR(rhs(15), 7.0/11.0, tol); +}