From 9a030febd8f92a82e2ec1825440292aa33bbc03c Mon Sep 17 00:00:00 2001 From: julianlitz Date: Sat, 31 Jan 2026 00:01:08 +0100 Subject: [PATCH 1/8] Wo war Gondor als die Westfold fiel? --- .../Solvers/tridiagonal_solver.h | 310 +++++++++ tests/CMakeLists.txt | 1 + .../Solvers/tridiagonal_solver.cpp | 637 ++++++++++++++++++ 3 files changed, 948 insertions(+) create mode 100644 include/LinearAlgebra/Solvers/tridiagonal_solver.h create mode 100644 tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp diff --git a/include/LinearAlgebra/Solvers/tridiagonal_solver.h b/include/LinearAlgebra/Solvers/tridiagonal_solver.h new file mode 100644 index 00000000..9c6ea63c --- /dev/null +++ b/include/LinearAlgebra/Solvers/tridiagonal_solver.h @@ -0,0 +1,310 @@ +#pragma once + +#include + +#include "../vector.h" + +template +class BatchedTridiagonalSolver +{ +public: + BatchedTridiagonalSolver(int matrix_dimension, int batch_count, bool is_cyclic = true) + : matrix_dimension_(matrix_dimension) + , batch_count_(batch_count) + , main_diagonal_("BatchedTridiagonalSolver::main_diagonal", matrix_dimension * batch_count) + , sub_diagonal_("BatchedTridiagonalSolver::sub_diagonal", matrix_dimension * batch_count) + , buffer_("BatchedTridiagonalSolver::buffer", is_cyclic ? matrix_dimension * batch_count : 0) + , gamma_("BatchedTridiagonalSolver::gamma", is_cyclic ? batch_count : 0) + , is_cyclic_(is_cyclic) + , is_factorized_(false) + { + Kokkos::deep_copy(main_diagonal_, T(0)); + Kokkos::deep_copy(sub_diagonal_, T(0)); + } + + /* ---------------------------- */ + /* Accessors for matrix entries */ + /* ---------------------------- */ + + KOKKOS_INLINE_FUNCTION + const T& main_diagonal(const int batch_idx, const int index) const + { + return main_diagonal_(batch_idx * matrix_dimension_ + index); + } + KOKKOS_INLINE_FUNCTION + T& main_diagonal(const int batch_idx, const int index) + { + return main_diagonal_(batch_idx * matrix_dimension_ + index); + } + + KOKKOS_INLINE_FUNCTION + const T& sub_diagonal(const int batch_idx, const int index) const + { + return sub_diagonal_(batch_idx * matrix_dimension_ + index); + } + KOKKOS_INLINE_FUNCTION + T& sub_diagonal(const int batch_idx, const int index) + { + return sub_diagonal_(batch_idx * matrix_dimension_ + index); + } + + KOKKOS_INLINE_FUNCTION + const T& cyclic_corner(const int batch_idx) const + { + return sub_diagonal_(batch_idx * matrix_dimension_ + (matrix_dimension_ - 1)); + } + + KOKKOS_INLINE_FUNCTION + T& cyclic_corner(const int batch_idx) + { + return sub_diagonal_(batch_idx * matrix_dimension_ + (matrix_dimension_ - 1)); + } + + /* ---------------------------------------------- */ + /* Setup: Cholesky Decomposition: A = L * D * L^T */ + /* ---------------------------------------------- */ + // This step factorizes the tridiagonal matrix into lower triangular (L) and diagonal (D) matrices. + // For cyclic systems, it also applies the Shermann-Morrison adjustment to account for the cyclic connection. + + struct SetupNonCyclic { + int m_matrix_dimension; + Vector m_main_diagonal; + Vector m_sub_diagonal; + + void operator()(const int batch_idx) const + { + // ----------------------------------- // + // Obtain offset for the current batch // + int offset = batch_idx * m_matrix_dimension; + + // ---------------------- // + // Cholesky Decomposition // + for (int i = 1; i < m_matrix_dimension; i++) { + m_sub_diagonal(offset + i - 1) /= m_main_diagonal(offset + i - 1); + const T factor = m_sub_diagonal(offset + i - 1); + m_main_diagonal(offset + i) -= factor * factor * m_main_diagonal(offset + i - 1); + } + } + }; + + struct SetupCyclic { + int m_matrix_dimension; + Vector m_main_diagonal; + Vector m_sub_diagonal; + Vector m_gamma; + + void operator()(const int batch_idx) const + { + // ----------------------------------- // + // Obtain offset for the current batch // + int offset = batch_idx * m_matrix_dimension; + + // ------------------------------------------------- // + // Shermann-Morrison Adjustment // + // - Modify the first and last main diagonal element // + // - Compute and store gamma for later use // + // ------------------------------------------------- // + T cyclic_corner_element = m_sub_diagonal(offset + m_matrix_dimension - 1); + /* gamma_ = -main_diagonal(0);*/ + m_gamma(batch_idx) = -m_main_diagonal(offset + 0); + /* main_diagonal(0) -= gamma_;*/ + m_main_diagonal(offset + 0) -= m_gamma(batch_idx); + /* main_diagonal(matrix_dimension_ - 1) -= cyclic_corner_element()^2 / gamma_;*/ + m_main_diagonal(offset + m_matrix_dimension - 1) -= + cyclic_corner_element * cyclic_corner_element / m_gamma(batch_idx); + + // ---------------------- // + // Cholesky Decomposition // + for (int i = 1; i < m_matrix_dimension; i++) { + m_sub_diagonal(offset + i - 1) /= m_main_diagonal(offset + i - 1); + const T factor = m_sub_diagonal(offset + i - 1); + m_main_diagonal(offset + i) -= factor * factor * m_main_diagonal(offset + i - 1); + } + } + }; + + void setup() + { + if (!is_cyclic_) { + SetupNonCyclic functor{matrix_dimension_, main_diagonal_, sub_diagonal_}; + Kokkos::parallel_for("SetupNonCyclic", batch_count_, functor); + } + else { + SetupCyclic functor{matrix_dimension_, main_diagonal_, sub_diagonal_, gamma_}; + Kokkos::parallel_for("SetupNonCyclic", batch_count_, functor); + } + Kokkos::fence(); + is_factorized_ = true; + } + + /* ---------------------------------------- */ + /* Solve: Forward and Backward Substitution */ + /* ---------------------------------------- */ + // This step solves the system Ax = b using the factorized form of A. + // For cyclic systems, it also performs the Shermann-Morrison reconstruction to obtain the final solution. + + struct SolveNonCyclic { + int m_matrix_dimension; + Vector m_main_diagonal; + Vector m_sub_diagonal; + Vector m_rhs; + int m_batch_offset; + int m_batch_stride; + + void operator()(const int batch_idx) const + { + // ----------------------------------- // + // Obtain offset for the current batch // + int solve_batch = m_batch_stride * batch_idx + m_batch_offset; + int offset = solve_batch * m_matrix_dimension; + + // -------------------- // + // Forward Substitution // + for (int i = 1; i < m_matrix_dimension; i++) { + m_rhs(offset + i) -= m_sub_diagonal(offset + i - 1) * m_rhs(offset + i - 1); + } + // ---------------- // + // Diagonal Scaling // + for (int i = 0; i < m_matrix_dimension; i++) { + m_rhs(offset + i) /= m_main_diagonal(offset + i); + } + // --------------------- // + // Backward Substitution // + for (int i = m_matrix_dimension - 2; i >= 0; i--) { + m_rhs(offset + i) -= m_sub_diagonal(offset + i) * m_rhs(offset + i + 1); + } + } + }; + + struct SolveCyclic { + int m_matrix_dimension; + Vector m_main_diagonal; + Vector m_sub_diagonal; + Vector m_buffer; + Vector m_gamma; + Vector m_rhs; + int m_batch_offset; + int m_batch_stride; + + void operator()(const int batch_idx) const + { + // ----------------------------------- // + // Obtain offset for the current batch // + int solve_batch = m_batch_stride * batch_idx + m_batch_offset; + int offset = solve_batch * m_matrix_dimension; + + // -------------------- // + // Forward Substitution // + T cyclic_corner_element = m_sub_diagonal(offset + m_matrix_dimension - 1); + m_buffer(offset + 0) = m_gamma(batch_idx); + for (int i = 1; i < m_matrix_dimension; i++) { + m_rhs(offset + i) -= m_sub_diagonal(offset + i - 1) * m_rhs(offset + i - 1); + if (i < m_matrix_dimension - 1) + m_buffer(offset + i) = 0.0 - m_sub_diagonal(offset + i - 1) * m_buffer(offset + i - 1); + else + m_buffer(offset + i) = + cyclic_corner_element - m_sub_diagonal(offset + i - 1) * m_buffer(offset + i - 1); + } + // ---------------- // + // Diagonal Scaling // + for (int i = 0; i < m_matrix_dimension; i++) { + m_rhs(offset + i) /= m_main_diagonal(offset + i); + m_buffer(offset + i) /= m_main_diagonal(offset + i); + } + // --------------------- // + // Backward Substitution // + for (int i = m_matrix_dimension - 2; i >= 0; i--) { + m_rhs(offset + i) -= m_sub_diagonal(offset + i) * m_rhs(offset + i + 1); + m_buffer(offset + i) -= m_sub_diagonal(offset + i) * m_buffer(offset + i + 1); + } + // ------------------------------- // + // Shermann-Morrison Reonstruction // + const T dot_product_x_v = + m_rhs(offset + 0) + cyclic_corner_element / m_gamma(batch_idx) * m_rhs(offset + m_matrix_dimension - 1); + const T dot_product_u_v = m_buffer(offset + 0) + cyclic_corner_element / m_gamma(batch_idx) * + m_buffer(offset + m_matrix_dimension - 1); + const T factor = dot_product_x_v / (1.0 + dot_product_u_v); + + for (int i = 0; i < m_matrix_dimension; i++) { + m_rhs(offset + i) -= factor * m_buffer(offset + i); + } + } + }; + + void solve(Vector rhs, int batch_offset = 0, int batch_stride = 1) + { + if (!is_factorized_) { + throw std::runtime_error("Error: Matrix must be factorized before solving."); + } + + // Compute the effective number of batches to solve + int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride; + + if (!is_cyclic_) { + SolveNonCyclic functor{matrix_dimension_, main_diagonal_, sub_diagonal_, rhs, batch_offset, batch_stride}; + Kokkos::parallel_for("SolveNonCyclic", effective_batch_count, functor); + } + else { + SolveCyclic functor{matrix_dimension_, main_diagonal_, sub_diagonal_, buffer_, gamma_, rhs, + batch_offset, batch_stride}; + Kokkos::parallel_for("SolveCyclic", effective_batch_count, functor); + } + Kokkos::fence(); + } + + /* ---------------------------- */ + /* Solve: Diagonal Scaling Only */ + /* ---------------------------- */ + // This step performs only the diagonal scaling part of the solve process. + // It is useful when the matrix has a non-zero diagonal but zero off-diagonal entries. + // Note that .setup() doesn't modify the main diagonal in this case. + + struct SolveDiagonal { + int m_matrix_dimension; + Vector m_main_diagonal; + Vector m_rhs; + int m_batch_offset; + int m_batch_stride; + + void operator()(const int batch_idx) const + { + // ----------------------------------- // + // Obtain offset for the current batch // + int solve_batch = m_batch_stride * batch_idx + m_batch_offset; + int offset = solve_batch * m_matrix_dimension; + + // ---------------- // + // Diagonal Scaling // + for (int i = 0; i < m_matrix_dimension; i++) { + m_rhs(offset + i) /= m_main_diagonal(offset + i); + } + } + }; + + void solve_diagonal(Vector rhs, int batch_offset = 0, int batch_stride = 1) + { + if (!is_factorized_) { + throw std::runtime_error("Error: Matrix must be factorized before solving."); + } + + // Compute the effective number of batches to solve + int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride; + + SolveDiagonal functor{matrix_dimension_, main_diagonal_, rhs, batch_offset, batch_stride}; + Kokkos::parallel_for("SolveDiagonal", effective_batch_count, functor); + + Kokkos::fence(); + } + +private: + int matrix_dimension_; + int batch_count_; + + Vector main_diagonal_; + Vector sub_diagonal_; + Vector buffer_; + Vector gamma_; + + bool is_cyclic_; + bool is_factorized_; +}; \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index fad28d0a..ce9c360e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -14,6 +14,7 @@ add_executable(gmgpolar_tests LinearAlgebra/csr_solver.cpp LinearAlgebra/tridiagonal_solver.cpp LinearAlgebra/cyclic_tridiagonal_solver.cpp + LinearAlgebra/Solvers/tridiagonal_solver.cpp PolarGrid/polargrid.cpp Interpolation/prolongation.cpp Interpolation/restriction.cpp diff --git a/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp b/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp new file mode 100644 index 00000000..2643dff4 --- /dev/null +++ b/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp @@ -0,0 +1,637 @@ +#include +#include +#include +#include + +#include "../../../include/LinearAlgebra/Solvers/tridiagonal_solver.h" +#include "../../../include/LinearAlgebra/vector.h" + +// Helper function to initialize tridiagonal system +template +void initialize_tridiagonal_system(BatchedTridiagonalSolver& solver, Vector& rhs, int batch_idx, int matrix_dim, + bool is_cyclic) +{ + // Create a simple test system: A*x = b + // Main diagonal = 4, Sub diagonal = -1 + for (int i = 0; i < matrix_dim; i++) { + solver.main_diagonal(batch_idx, i) = 4.0; + if (i < matrix_dim - 1) { + solver.sub_diagonal(batch_idx, i) = -1.0; + } + } + + if (is_cyclic) { + // For cyclic, set the corner element + solver.sub_diagonal(batch_idx, matrix_dim - 1) = -1.0; + } + + // Set RHS to 1.0 for simple test + for (int i = 0; i < matrix_dim; i++) { + rhs(batch_idx * matrix_dim + i) = 1.0; + } +} + +// Helper function to verify solution +// NOTE: This function uses the ORIGINAL matrix values (before setup() modifies them) +// The test matrix has main_diagonal = 4.0 and sub_diagonal = -1.0 +template +bool verify_solution(const Vector& x, int batch_idx, int matrix_dim, bool is_cyclic, T tolerance = 1e-6) +{ + // Reconstruct the original test matrix (main=4, sub=-1) and compute A*x + // NOTE: We use hardcoded values because setup() has already modified the solver's internal state + std::vector Ax(matrix_dim, T(0)); + + for (int i = 0; i < matrix_dim; ++i) { + Ax[i] += T(4.0) * x(batch_idx * matrix_dim + i); + if (i > 0) { + Ax[i] += T(-1.0) * x(batch_idx * matrix_dim + i - 1); + } + if (i < matrix_dim - 1) { + Ax[i] += T(-1.0) * x(batch_idx * matrix_dim + i + 1); + } + } + + if (is_cyclic) { + Ax[0] += T(-1.0) * x(batch_idx * matrix_dim + matrix_dim - 1); + Ax[matrix_dim - 1] += T(-1.0) * x(batch_idx * matrix_dim + 0); + } + + // Check if Ax ≈ b (original RHS = 1.0) + T max_error = T(0); + for (int i = 0; i < matrix_dim; ++i) { + T error = std::abs(Ax[i] - T(1.0)); + if (error > max_error) + max_error = error; + } + + return max_error < tolerance; +} + +// Helper function to compute actual error for more detailed testing +// NOTE: Uses hardcoded original matrix values (main=4.0, sub=-1.0) +template +T compute_solution_error(const Vector& x, int batch_idx, int matrix_dim, bool is_cyclic) +{ + std::vector Ax(matrix_dim, T(0)); + + for (int i = 0; i < matrix_dim; ++i) { + Ax[i] += T(4.0) * x(batch_idx * matrix_dim + i); + if (i > 0) { + Ax[i] += T(-1.0) * x(batch_idx * matrix_dim + i - 1); + } + if (i < matrix_dim - 1) { + Ax[i] += T(-1.0) * x(batch_idx * matrix_dim + i + 1); + } + } + + if (is_cyclic) { + Ax[0] += T(-1.0) * x(batch_idx * matrix_dim + matrix_dim - 1); + Ax[matrix_dim - 1] += T(-1.0) * x(batch_idx * matrix_dim + 0); + } + + T max_error = T(0); + for (int i = 0; i < matrix_dim; ++i) { + T error = std::abs(Ax[i] - T(1.0)); + if (error > max_error) + max_error = error; + } + + return max_error; +} + +// Test fixture for BatchedTridiagonalSolver tests +class BatchedTridiagonalSolverTest : public ::testing::Test +{ +protected: + void SetUp() override + { + // Kokkos initialization is handled in main() + } + + void TearDown() override + { + // Cleanup if needed + } + + // Common test parameters + static constexpr int default_matrix_dim = 10; + static constexpr int default_batch_count = 8; + static constexpr double default_tolerance = 1e-6; +}; + +// ============================================================================ +// Basic Functionality Tests +// ============================================================================ + +TEST_F(BatchedTridiagonalSolverTest, NonCyclicAllBatches) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + // Initialize all batches + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); + } + + // Factorize + solver.setup(); + + // Solve all batches + solver.solve(rhs, 0, 1); + + // Verify all batches + for (int b = 0; b < batch_count; b++) { + EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) + << "Batch " << b << " failed verification"; + } +} + +TEST_F(BatchedTridiagonalSolverTest, CyclicAllBatches) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, true); + Vector rhs("rhs", matrix_dim * batch_count); + + // Initialize all batches + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, true); + } + + // Factorize + solver.setup(); + + // Solve all batches + solver.solve(rhs, 0, 1); + + // Verify all batches + for (int b = 0; b < batch_count; b++) { + EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, true, default_tolerance)) + << "Batch " << b << " failed verification"; + } +} + +// ============================================================================ +// Stride and Offset Tests +// ============================================================================ + +TEST_F(BatchedTridiagonalSolverTest, NonCyclicEvenBatchesStride2Offset0) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + // Initialize all batches + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); + } + + solver.setup(); + + // Solve only even batches (0, 2, 4, 6) + solver.solve(rhs, 0, 2); + + // Verify only even batches + std::vector even_batches = {0, 2, 4, 6}; + for (int b : even_batches) { + EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) + << "Even batch " << b << " failed verification"; + } +} + +TEST_F(BatchedTridiagonalSolverTest, NonCyclicOddBatchesStride2Offset1) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + // Initialize all batches + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); + } + + solver.setup(); + + // Solve only odd batches (1, 3, 5, 7) + solver.solve(rhs, 1, 2); + + // Verify only odd batches + std::vector odd_batches = {1, 3, 5, 7}; + for (int b : odd_batches) { + EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) + << "Odd batch " << b << " failed verification"; + } +} + +TEST_F(BatchedTridiagonalSolverTest, CyclicStride3Offset1) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, true); + Vector rhs("rhs", matrix_dim * batch_count); + + // Initialize all batches + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, true); + } + + solver.setup(); + + // Solve batches with stride 3, offset 1 (1, 4, 7) + solver.solve(rhs, 1, 3); + + // Verify + std::vector batch_indices = {1, 4, 7}; + for (int b : batch_indices) { + EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, true, default_tolerance)) + << "Batch " << b << " failed verification"; + } +} + +// ============================================================================ +// Edge Cases +// ============================================================================ + +TEST_F(BatchedTridiagonalSolverTest, SingleBatchWithOffset) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + const int target_batch = 5; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + // Initialize all batches + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); + } + + solver.setup(); + + // Solve only batch 5 + solver.solve(rhs, target_batch, batch_count); + + // Verify only batch 5 + EXPECT_TRUE(verify_solution(rhs, target_batch, matrix_dim, false, default_tolerance)) + << "Single batch " << target_batch << " failed verification"; +} + +TEST_F(BatchedTridiagonalSolverTest, SmallMatrixSize) +{ + const int matrix_dim = 3; + const int batch_count = 4; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); + } + + solver.setup(); + solver.solve(rhs, 0, 1); + + for (int b = 0; b < batch_count; b++) { + EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) + << "Small matrix batch " << b << " failed verification"; + } +} + +TEST_F(BatchedTridiagonalSolverTest, SingleBatchSingleElement) +{ + const int matrix_dim = 1; + const int batch_count = 1; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + // For 1x1 matrix, just set main diagonal + solver.main_diagonal(0, 0) = 4.0; + rhs(0) = 1.0; + + solver.setup(); + solver.solve(rhs, 0, 1); + + // Solution should be 1.0 / 4.0 = 0.25 + EXPECT_NEAR(rhs(0), 0.25, default_tolerance); +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +TEST_F(BatchedTridiagonalSolverTest, SolveBeforeSetupThrows) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + // Try to solve without calling setup() first + EXPECT_THROW({ solver.solve(rhs, 0, 1); }, std::runtime_error); +} + +TEST_F(BatchedTridiagonalSolverTest, DiagonalSolveBeforeSetupThrows) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + // Try to solve_diagonal without calling setup() first + EXPECT_THROW({ solver.solve_diagonal(rhs, 0, 1); }, std::runtime_error); +} + +// ============================================================================ +// Diagonal Solve Tests +// ============================================================================ + +TEST_F(BatchedTridiagonalSolverTest, DiagonalSolveAllBatches) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + // Initialize with diagonal-only system + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + solver.main_diagonal(b, i) = 2.0; + rhs(b * matrix_dim + i) = 4.0; + } + } + + solver.setup(); + solver.solve_diagonal(rhs, 0, 1); + + // Each solution element should be 4.0 / 2.0 = 2.0 + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + EXPECT_NEAR(rhs(b * matrix_dim + i), 2.0, default_tolerance) + << "Diagonal solve failed at batch " << b << ", index " << i; + } + } +} + +TEST_F(BatchedTridiagonalSolverTest, DiagonalSolveWithStride) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + // Initialize with diagonal-only system + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + solver.main_diagonal(b, i) = 3.0; + rhs(b * matrix_dim + i) = 6.0; + } + } + + solver.setup(); + solver.solve_diagonal(rhs, 1, 2); // Solve odd batches only + + // Check odd batches: should be 6.0 / 3.0 = 2.0 + std::vector odd_batches = {1, 3, 5, 7}; + for (int b : odd_batches) { + for (int i = 0; i < matrix_dim; i++) { + EXPECT_NEAR(rhs(b * matrix_dim + i), 2.0, default_tolerance) + << "Diagonal solve with stride failed at batch " << b << ", index " << i; + } + } +} + +// ============================================================================ +// Numerical Accuracy Tests +// ============================================================================ + +TEST_F(BatchedTridiagonalSolverTest, AccuracyNonCyclic) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); + } + + solver.setup(); + solver.solve(rhs, 0, 1); + + // Check that error is well below tolerance + for (int b = 0; b < batch_count; b++) { + double error = compute_solution_error(rhs, b, matrix_dim, false); + EXPECT_LT(error, 1e-10) << "Error too large for batch " << b; + } +} + +TEST_F(BatchedTridiagonalSolverTest, AccuracyCyclic) +{ + const int matrix_dim = default_matrix_dim; + const int batch_count = default_batch_count; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, true); + Vector rhs("rhs", matrix_dim * batch_count); + + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, true); + } + + solver.setup(); + solver.solve(rhs, 0, 1); + + // Check that error is well below tolerance + for (int b = 0; b < batch_count; b++) { + double error = compute_solution_error(rhs, b, matrix_dim, true); + EXPECT_LT(error, 1e-10) << "Error too large for cyclic batch " << b; + } +} + +// ============================================================================ +// Performance/Stress Tests +// ============================================================================ + +TEST_F(BatchedTridiagonalSolverTest, LargeBatchCount) +{ + const int matrix_dim = default_matrix_dim; + const int large_batch = 1000; + + BatchedTridiagonalSolver solver(matrix_dim, large_batch, false); + Vector rhs("rhs", matrix_dim * large_batch); + + // Initialize + for (int b = 0; b < large_batch; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); + } + + // Factorize and solve + solver.setup(); + solver.solve(rhs, 0, 1); + + // Verify a few random batches + std::vector test_batches = {0, 250, 500, 750, 999}; + for (int b : test_batches) { + EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) + << "Large batch test failed at batch " << b; + } +} + +TEST_F(BatchedTridiagonalSolverTest, LargeMatrixDimension) +{ + const int matrix_dim = 100; + const int batch_count = 4; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + Vector rhs("rhs", matrix_dim * batch_count); + + for (int b = 0; b < batch_count; b++) { + initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); + } + + solver.setup(); + solver.solve(rhs, 0, 1); + + for (int b = 0; b < batch_count; b++) { + EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) + << "Large matrix dimension test failed at batch " << b; + } +} + +// ============================================================================ +// Accessor Tests +// ============================================================================ + +TEST_F(BatchedTridiagonalSolverTest, MainDiagonalAccessors) +{ + const int matrix_dim = 5; + const int batch_count = 2; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + + // Test write access + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + solver.main_diagonal(b, i) = static_cast(b * matrix_dim + i); + } + } + + // Test read access BEFORE setup() + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + EXPECT_DOUBLE_EQ(solver.main_diagonal(b, i), static_cast(b * matrix_dim + i)) + << "Main diagonal accessor failed at batch " << b << ", index " << i; + } + } +} + +TEST_F(BatchedTridiagonalSolverTest, SubDiagonalAccessors) +{ + const int matrix_dim = 5; + const int batch_count = 2; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + + // Test write access + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + solver.sub_diagonal(b, i) = static_cast(100 + b * matrix_dim + i); + } + } + + // Test read access BEFORE setup() + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + EXPECT_DOUBLE_EQ(solver.sub_diagonal(b, i), static_cast(100 + b * matrix_dim + i)) + << "Sub diagonal accessor failed at batch " << b << ", index " << i; + } + } +} + +TEST_F(BatchedTridiagonalSolverTest, CyclicCornerAccessors) +{ + const int matrix_dim = 5; + const int batch_count = 2; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, true); + + // Test write access via cyclic_corner + for (int b = 0; b < batch_count; b++) { + solver.cyclic_corner(b) = static_cast(200 + b); + } + + // Test read access via cyclic_corner BEFORE setup() + for (int b = 0; b < batch_count; b++) { + EXPECT_DOUBLE_EQ(solver.cyclic_corner(b), static_cast(200 + b)) + << "Cyclic corner accessor failed at batch " << b; + } + + // Verify that cyclic_corner actually accesses sub_diagonal at the right location + for (int b = 0; b < batch_count; b++) { + EXPECT_DOUBLE_EQ(solver.cyclic_corner(b), solver.sub_diagonal(b, matrix_dim - 1)) + << "Cyclic corner should access sub_diagonal at index matrix_dim-1"; + } +} + +TEST_F(BatchedTridiagonalSolverTest, SetupModifiesInternalState) +{ + // This test verifies that setup() modifies the internal diagonal values + const int matrix_dim = 5; + const int batch_count = 2; + + BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); + + // Set up initial values + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + solver.main_diagonal(b, i) = 4.0; + if (i < matrix_dim - 1) { + solver.sub_diagonal(b, i) = -1.0; + } + } + } + + // Store original values for comparison + std::vector original_main(matrix_dim * batch_count); + std::vector original_sub(matrix_dim * batch_count); + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + original_main[b * matrix_dim + i] = solver.main_diagonal(b, i); + original_sub[b * matrix_dim + i] = solver.sub_diagonal(b, i); + } + } + + // Call setup() - this performs Cholesky factorization + solver.setup(); + + // Verify that values have been modified + bool main_changed = false; + bool sub_changed = false; + for (int b = 0; b < batch_count; b++) { + for (int i = 0; i < matrix_dim; i++) { + if (std::abs(solver.main_diagonal(b, i) - original_main[b * matrix_dim + i]) > 1e-10) { + main_changed = true; + } + if (std::abs(solver.sub_diagonal(b, i) - original_sub[b * matrix_dim + i]) > 1e-10) { + sub_changed = true; + } + } + } + + EXPECT_TRUE(main_changed) << "setup() should modify main diagonal values (Cholesky factorization)"; + EXPECT_TRUE(sub_changed) << "setup() should modify sub diagonal values (Cholesky factorization)"; +} From 6c30cd1ea6f17dc50dc5a0b1ee69251fd95cee79 Mon Sep 17 00:00:00 2001 From: Julian Litz <91479202+julianlitz@users.noreply.github.com> Date: Sat, 31 Jan 2026 00:35:42 +0100 Subject: [PATCH 2/8] Update comment --- include/LinearAlgebra/Solvers/tridiagonal_solver.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/LinearAlgebra/Solvers/tridiagonal_solver.h b/include/LinearAlgebra/Solvers/tridiagonal_solver.h index 9c6ea63c..1fe163a1 100644 --- a/include/LinearAlgebra/Solvers/tridiagonal_solver.h +++ b/include/LinearAlgebra/Solvers/tridiagonal_solver.h @@ -131,7 +131,7 @@ class BatchedTridiagonalSolver } else { SetupCyclic functor{matrix_dimension_, main_diagonal_, sub_diagonal_, gamma_}; - Kokkos::parallel_for("SetupNonCyclic", batch_count_, functor); + Kokkos::parallel_for("SetupCyclic", batch_count_, functor); } Kokkos::fence(); is_factorized_ = true; From 52a2b326ac37431f15cfea602257bd9054c3307a Mon Sep 17 00:00:00 2001 From: Julian Litz <91479202+julianlitz@users.noreply.github.com> Date: Sat, 31 Jan 2026 11:32:12 +0100 Subject: [PATCH 3/8] Update tridiagonal_solver.h --- .../LinearAlgebra/Solvers/tridiagonal_solver.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/include/LinearAlgebra/Solvers/tridiagonal_solver.h b/include/LinearAlgebra/Solvers/tridiagonal_solver.h index 1fe163a1..c0237dc8 100644 --- a/include/LinearAlgebra/Solvers/tridiagonal_solver.h +++ b/include/LinearAlgebra/Solvers/tridiagonal_solver.h @@ -151,12 +151,12 @@ class BatchedTridiagonalSolver int m_batch_offset; int m_batch_stride; - void operator()(const int batch_idx) const + void operator()(const int k) const { // ----------------------------------- // // Obtain offset for the current batch // - int solve_batch = m_batch_stride * batch_idx + m_batch_offset; - int offset = solve_batch * m_matrix_dimension; + int batch_idx = m_batch_stride * k + m_batch_offset; + int offset = batch_idx * m_matrix_dimension; // -------------------- // // Forward Substitution // @@ -186,12 +186,12 @@ class BatchedTridiagonalSolver int m_batch_offset; int m_batch_stride; - void operator()(const int batch_idx) const + void operator()(const int k) const { // ----------------------------------- // // Obtain offset for the current batch // - int solve_batch = m_batch_stride * batch_idx + m_batch_offset; - int offset = solve_batch * m_matrix_dimension; + int batch_idx = m_batch_stride * k + m_batch_offset; + int offset = batch_idx * m_matrix_dimension; // -------------------- // // Forward Substitution // @@ -266,12 +266,12 @@ class BatchedTridiagonalSolver int m_batch_offset; int m_batch_stride; - void operator()(const int batch_idx) const + void operator()(const int k) const { // ----------------------------------- // // Obtain offset for the current batch // - int solve_batch = m_batch_stride * batch_idx + m_batch_offset; - int offset = solve_batch * m_matrix_dimension; + int batch_idx = m_batch_stride * k + m_batch_offset; + int offset = batch_idx * m_matrix_dimension; // ---------------- // // Diagonal Scaling // From 85571f0376e68d11c3748109572ef10564f75977 Mon Sep 17 00:00:00 2001 From: Julian Litz <91479202+julianlitz@users.noreply.github.com> Date: Sat, 31 Jan 2026 21:32:35 +0100 Subject: [PATCH 4/8] Better test cases --- .../Solvers/tridiagonal_solver.cpp | 853 ++++++------------ 1 file changed, 261 insertions(+), 592 deletions(-) diff --git a/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp b/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp index 2643dff4..a708983d 100644 --- a/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp +++ b/tests/LinearAlgebra/Solvers/tridiagonal_solver.cpp @@ -6,632 +6,301 @@ #include "../../../include/LinearAlgebra/Solvers/tridiagonal_solver.h" #include "../../../include/LinearAlgebra/vector.h" -// Helper function to initialize tridiagonal system -template -void initialize_tridiagonal_system(BatchedTridiagonalSolver& solver, Vector& rhs, int batch_idx, int matrix_dim, - bool is_cyclic) +// clang-format off +TEST(BatchedTridiagonalSolvers, non_cyclic_tridiagonal_n_4) { - // Create a simple test system: A*x = b - // Main diagonal = 4, Sub diagonal = -1 - for (int i = 0; i < matrix_dim; i++) { - solver.main_diagonal(batch_idx, i) = 4.0; - if (i < matrix_dim - 1) { - solver.sub_diagonal(batch_idx, i) = -1.0; - } - } - - if (is_cyclic) { - // For cyclic, set the corner element - solver.sub_diagonal(batch_idx, matrix_dim - 1) = -1.0; - } - - // Set RHS to 1.0 for simple test - for (int i = 0; i < matrix_dim; i++) { - rhs(batch_idx * matrix_dim + i) = 1.0; - } -} - -// Helper function to verify solution -// NOTE: This function uses the ORIGINAL matrix values (before setup() modifies them) -// The test matrix has main_diagonal = 4.0 and sub_diagonal = -1.0 -template -bool verify_solution(const Vector& x, int batch_idx, int matrix_dim, bool is_cyclic, T tolerance = 1e-6) -{ - // Reconstruct the original test matrix (main=4, sub=-1) and compute A*x - // NOTE: We use hardcoded values because setup() has already modified the solver's internal state - std::vector Ax(matrix_dim, T(0)); - - for (int i = 0; i < matrix_dim; ++i) { - Ax[i] += T(4.0) * x(batch_idx * matrix_dim + i); - if (i > 0) { - Ax[i] += T(-1.0) * x(batch_idx * matrix_dim + i - 1); - } - if (i < matrix_dim - 1) { - Ax[i] += T(-1.0) * x(batch_idx * matrix_dim + i + 1); - } - } - - if (is_cyclic) { - Ax[0] += T(-1.0) * x(batch_idx * matrix_dim + matrix_dim - 1); - Ax[matrix_dim - 1] += T(-1.0) * x(batch_idx * matrix_dim + 0); - } - - // Check if Ax ≈ b (original RHS = 1.0) - T max_error = T(0); - for (int i = 0; i < matrix_dim; ++i) { - T error = std::abs(Ax[i] - T(1.0)); - if (error > max_error) - max_error = error; - } - - return max_error < tolerance; -} - -// Helper function to compute actual error for more detailed testing -// NOTE: Uses hardcoded original matrix values (main=4.0, sub=-1.0) -template -T compute_solution_error(const Vector& x, int batch_idx, int matrix_dim, bool is_cyclic) -{ - std::vector Ax(matrix_dim, T(0)); - - for (int i = 0; i < matrix_dim; ++i) { - Ax[i] += T(4.0) * x(batch_idx * matrix_dim + i); - if (i > 0) { - Ax[i] += T(-1.0) * x(batch_idx * matrix_dim + i - 1); - } - if (i < matrix_dim - 1) { - Ax[i] += T(-1.0) * x(batch_idx * matrix_dim + i + 1); - } - } - - if (is_cyclic) { - Ax[0] += T(-1.0) * x(batch_idx * matrix_dim + matrix_dim - 1); - Ax[matrix_dim - 1] += T(-1.0) * x(batch_idx * matrix_dim + 0); - } - - T max_error = T(0); - for (int i = 0; i < matrix_dim; ++i) { - T error = std::abs(Ax[i] - T(1.0)); - if (error > max_error) - max_error = error; - } - - return max_error; -} - -// Test fixture for BatchedTridiagonalSolver tests -class BatchedTridiagonalSolverTest : public ::testing::Test -{ -protected: - void SetUp() override - { - // Kokkos initialization is handled in main() - } - - void TearDown() override - { - // Cleanup if needed - } - - // Common test parameters - static constexpr int default_matrix_dim = 10; - static constexpr int default_batch_count = 8; - static constexpr double default_tolerance = 1e-6; -}; - -// ============================================================================ -// Basic Functionality Tests -// ============================================================================ - -TEST_F(BatchedTridiagonalSolverTest, NonCyclicAllBatches) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); - - // Initialize all batches - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); - } - - // Factorize - solver.setup(); - - // Solve all batches - solver.solve(rhs, 0, 1); - - // Verify all batches - for (int b = 0; b < batch_count; b++) { - EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) - << "Batch " << b << " failed verification"; - } -} - -TEST_F(BatchedTridiagonalSolverTest, CyclicAllBatches) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, true); - Vector rhs("rhs", matrix_dim * batch_count); - - // Initialize all batches - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, true); - } - - // Factorize - solver.setup(); - - // Solve all batches - solver.solve(rhs, 0, 1); - - // Verify all batches - for (int b = 0; b < batch_count; b++) { - EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, true, default_tolerance)) - << "Batch " << b << " failed verification"; - } -} - -// ============================================================================ -// Stride and Offset Tests -// ============================================================================ - -TEST_F(BatchedTridiagonalSolverTest, NonCyclicEvenBatchesStride2Offset0) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); - - // Initialize all batches - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); - } - - solver.setup(); - - // Solve only even batches (0, 2, 4, 6) - solver.solve(rhs, 0, 2); - - // Verify only even batches - std::vector even_batches = {0, 2, 4, 6}; - for (int b : even_batches) { - EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) - << "Even batch " << b << " failed verification"; - } -} - -TEST_F(BatchedTridiagonalSolverTest, NonCyclicOddBatchesStride2Offset1) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); - - // Initialize all batches - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); - } - - solver.setup(); - - // Solve only odd batches (1, 3, 5, 7) - solver.solve(rhs, 1, 2); - - // Verify only odd batches - std::vector odd_batches = {1, 3, 5, 7}; - for (int b : odd_batches) { - EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) - << "Odd batch " << b << " failed verification"; - } -} - -TEST_F(BatchedTridiagonalSolverTest, CyclicStride3Offset1) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, true); - Vector rhs("rhs", matrix_dim * batch_count); - - // Initialize all batches - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, true); - } - - solver.setup(); - - // Solve batches with stride 3, offset 1 (1, 4, 7) - solver.solve(rhs, 1, 3); - - // Verify - std::vector batch_indices = {1, 4, 7}; - for (int b : batch_indices) { - EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, true, default_tolerance)) - << "Batch " << b << " failed verification"; - } -} - -// ============================================================================ -// Edge Cases -// ============================================================================ - -TEST_F(BatchedTridiagonalSolverTest, SingleBatchWithOffset) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; - const int target_batch = 5; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); - - // Initialize all batches - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); - } + int batch_count = 4; + int matrix_dimension = 4; + bool is_cyclic = false; + + BatchedTridiagonalSolver solver(matrix_dimension, batch_count, is_cyclic); + + // System 1: {{2, 1, 0,0},{1,4,2,0},{0,2,6,3},{0,0,3,8}} * {{a},{b},{c},{d}} = {{1},{2},{3},{4}} + // a = 70/209, b = 69/209, c = 36/209, d = 91/209 + // System 2: {{3,1,0,0},{1,5,2,0},{0,2,7,3},{0,0,3,9}} * {{a},{b},{c},{d}} = {{2},{3},{4},{5}} + // a = 29/54, b = 7/18, c = 7/27, d = 38/81 + // System 3: {{4,1,0,0},{1,6,2,0},{0,2,8,3},{0,0,3,10}} * {{a},{b},{c},{d}} = {{3},{4},{5},{6}} + // a = 938/1473, b = 667/1473, c = 476/1473, d = 247/491 + // System 4: {{5,1,0,0},{1,7,2,0},{0,2,9,3},{0,0,3,11}} * {{a},{b},{c},{d}} = {{4},{5},{6},{7}} + // a = 248/355, b = 36/71, c = 267/710, d = 379/710 + + solver.main_diagonal(0,0) = 2.0; solver.sub_diagonal(0,0) = 1.0; + solver.main_diagonal(0,1) = 4.0; solver.sub_diagonal(0,1) = 2.0; + solver.main_diagonal(0,2) = 6.0; solver.sub_diagonal(0,2) = 3.0; + solver.main_diagonal(0,3) = 8.0; + + solver.main_diagonal(1,0) = 3.0; solver.sub_diagonal(1,0) = 1.0; + solver.main_diagonal(1,1) = 5.0; solver.sub_diagonal(1,1) = 2.0; + solver.main_diagonal(1,2) = 7.0; solver.sub_diagonal(1,2) = 3.0; + solver.main_diagonal(1,3) = 9.0; + + solver.main_diagonal(2,0) = 4.0; solver.sub_diagonal(2,0) = 1.0; + solver.main_diagonal(2,1) = 6.0; solver.sub_diagonal(2,1) = 2.0; + solver.main_diagonal(2,2) = 8.0; solver.sub_diagonal(2,2) = 3.0; + solver.main_diagonal(2,3) = 10.0; + + solver.main_diagonal(3,0) = 5.0; solver.sub_diagonal(3,0) = 1.0; + solver.main_diagonal(3,1) = 7.0; solver.sub_diagonal(3,1) = 2.0; + solver.main_diagonal(3,2) = 9.0; solver.sub_diagonal(3,2) = 3.0; + solver.main_diagonal(3,3) = 11.0; + + Vector rhs("rhs", matrix_dimension * batch_count); + + // Initialize RHS for each system + rhs(0) = 1.0; rhs(1) = 2.0; rhs(2) = 3.0; rhs(3) = 4.0; + rhs(4) = 2.0; rhs(5) = 3.0; rhs(6) = 4.0; rhs(7) = 5.0; + rhs(8) = 3.0; rhs(9) = 4.0; rhs(10) = 5.0; rhs(11) = 6.0; + rhs(12) = 4.0; rhs(13) = 5.0; rhs(14) = 6.0; rhs(15) = 7.0; solver.setup(); - // Solve only batch 5 - solver.solve(rhs, target_batch, batch_count); - - // Verify only batch 5 - EXPECT_TRUE(verify_solution(rhs, target_batch, matrix_dim, false, default_tolerance)) - << "Single batch " << target_batch << " failed verification"; + int offset, stride; + // Solve each even system separately + offset = 0; stride = 2; + solver.solve(rhs, offset, stride); + // Solve each odd system separately + offset = 1; stride = 2; + solver.solve(rhs, offset, stride); + + // Verify solutions + double tol = 1e-12; + + EXPECT_NEAR(rhs(0), 70.0/209.0, tol); + EXPECT_NEAR(rhs(1), 69.0/209.0, tol); + EXPECT_NEAR(rhs(2), 36.0/209.0, tol); + EXPECT_NEAR(rhs(3), 91.0/209.0, tol); + + EXPECT_NEAR(rhs(4), 29.0/54.0, tol); + EXPECT_NEAR(rhs(5), 7.0/18.0, tol); + EXPECT_NEAR(rhs(6), 7.0/27.0, tol); + EXPECT_NEAR(rhs(7), 38.0/81.0, tol); + + EXPECT_NEAR(rhs(8), 938.0/1473.0, tol); + EXPECT_NEAR(rhs(9), 667.0/1473.0, tol); + EXPECT_NEAR(rhs(10), 476.0/1473.0, tol); + EXPECT_NEAR(rhs(11), 247.0/491.0, tol); + + EXPECT_NEAR(rhs(12), 248.0/355.0, tol); + EXPECT_NEAR(rhs(13), 36.0/71.0, tol); + EXPECT_NEAR(rhs(14), 267.0/710.0, tol); + EXPECT_NEAR(rhs(15), 379.0/710.0, tol); } -TEST_F(BatchedTridiagonalSolverTest, SmallMatrixSize) +TEST(BatchedTridiagonalSolvers, cyclic_tridiagonal_n_4) { - const int matrix_dim = 3; - const int batch_count = 4; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); - - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); - } + int batch_count = 4; + int matrix_dimension = 4; + bool is_cyclic = true; + + BatchedTridiagonalSolver solver(matrix_dimension, batch_count, is_cyclic); + + // System 1: {{2, 1, 0,-1},{1,4,2,0},{0,2,6,3},{-1,0,3,8}} * {{a},{b},{c},{d}} = {{1},{2},{3},{4}} + // a = 42/67, b = 18/67, c = 10/67, d = 35/67 + // System 2: {{3,1,0,-2},{1,5,2,0},{0,2,7,3},{-2,0,3,9}} * {{a},{b},{c},{d}} = {{2},{3},{4},{5}} + // a = 287/274, b = 89/274, c = 45/274, d = 201/274 + // System 3: {{4,1,0,-3},{1,6,2,0},{0,2,8,3},{-3,0,3,10}} * {{a},{b},{c},{d}} = {{3},{4},{5},{6}} + // a = 1532/1113, b = 8/21, c = 188/1113, d = 51/53 + // System 4: {{5,1,0,-4},{1,7,2,0},{0,2,9,3},{-4,0,3,11}} * {{a},{b},{c},{d}} = {{4},{5},{6},{7}} + // a = 271/162, b = 23/54, c = 14/81, d = 97/81 + + solver.main_diagonal(0,0) = 2.0; solver.sub_diagonal(0,0) = 1.0; + solver.main_diagonal(0,1) = 4.0; solver.sub_diagonal(0,1) = 2.0; + solver.main_diagonal(0,2) = 6.0; solver.sub_diagonal(0,2) = 3.0; + solver.main_diagonal(0,3) = 8.0; solver.cyclic_corner(0) = -1.0; + + solver.main_diagonal(1,0) = 3.0; solver.sub_diagonal(1,0) = 1.0; + solver.main_diagonal(1,1) = 5.0; solver.sub_diagonal(1,1) = 2.0; + solver.main_diagonal(1,2) = 7.0; solver.sub_diagonal(1,2) = 3.0; + solver.main_diagonal(1,3) = 9.0; solver.cyclic_corner(1) = -2.0; + + solver.main_diagonal(2,0) = 4.0; solver.sub_diagonal(2,0) = 1.0; + solver.main_diagonal(2,1) = 6.0; solver.sub_diagonal(2,1) = 2.0; + solver.main_diagonal(2,2) = 8.0; solver.sub_diagonal(2,2) = 3.0; + solver.main_diagonal(2,3) = 10.0; solver.cyclic_corner(2) = -3.0; + + solver.main_diagonal(3,0) = 5.0; solver.sub_diagonal(3,0) = 1.0; + solver.main_diagonal(3,1) = 7.0; solver.sub_diagonal(3,1) = 2.0; + solver.main_diagonal(3,2) = 9.0; solver.sub_diagonal(3,2) = 3.0; + solver.main_diagonal(3,3) = 11.0; solver.cyclic_corner(3) = -4.0; + + Vector rhs("rhs", matrix_dimension * batch_count); + + // Initialize RHS for each system + rhs(0) = 1.0; rhs(1) = 2.0; rhs(2) = 3.0; rhs(3) = 4.0; + rhs(4) = 2.0; rhs(5) = 3.0; rhs(6) = 4.0; rhs(7) = 5.0; + rhs(8) = 3.0; rhs(9) = 4.0; rhs(10) = 5.0; rhs(11) = 6.0; + rhs(12) = 4.0; rhs(13) = 5.0; rhs(14) = 6.0; rhs(15) = 7.0; solver.setup(); - solver.solve(rhs, 0, 1); - for (int b = 0; b < batch_count; b++) { - EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) - << "Small matrix batch " << b << " failed verification"; - } + int offset, stride; + // Solve each even system separately + offset = 0; stride = 2; + solver.solve(rhs, offset, stride); + // Solve each odd system separately + offset = 1; stride = 2; + solver.solve(rhs, offset, stride); + + // Verify solutions + double tol = 1e-12; + + EXPECT_NEAR(rhs(0), 42.0/67.0, tol); + EXPECT_NEAR(rhs(1), 18.0/67.0, tol); + EXPECT_NEAR(rhs(2), 10.0/67.0, tol); + EXPECT_NEAR(rhs(3), 35.0/67.0, tol); + + EXPECT_NEAR(rhs(4), 287.0/274.0, tol); + EXPECT_NEAR(rhs(5), 89.0/274.0, tol); + EXPECT_NEAR(rhs(6), 45.0/274.0, tol); + EXPECT_NEAR(rhs(7), 201.0/274.0, tol); + + EXPECT_NEAR(rhs(8), 1532.0/1113.0, tol); + EXPECT_NEAR(rhs(9), 8.0/21.0, tol); + EXPECT_NEAR(rhs(10), 188.0/1113.0, tol); + EXPECT_NEAR(rhs(11), 51.0/53.0, tol); + + EXPECT_NEAR(rhs(12), 271.0/162.0, tol); + EXPECT_NEAR(rhs(13), 23.0/54.0, tol); + EXPECT_NEAR(rhs(14), 14.0/81.0, tol); + EXPECT_NEAR(rhs(15), 97.0/81.0, tol); } -TEST_F(BatchedTridiagonalSolverTest, SingleBatchSingleElement) +TEST(BatchedTridiagonalSolvers, non_cyclic_diagonal_n_4) { - const int matrix_dim = 1; - const int batch_count = 1; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); - - // For 1x1 matrix, just set main diagonal - solver.main_diagonal(0, 0) = 4.0; - rhs(0) = 1.0; - - solver.setup(); - solver.solve(rhs, 0, 1); - - // Solution should be 1.0 / 4.0 = 0.25 - EXPECT_NEAR(rhs(0), 0.25, default_tolerance); -} - -// ============================================================================ -// Error Handling Tests -// ============================================================================ - -TEST_F(BatchedTridiagonalSolverTest, SolveBeforeSetupThrows) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); - - // Try to solve without calling setup() first - EXPECT_THROW({ solver.solve(rhs, 0, 1); }, std::runtime_error); -} + int batch_count = 4; + int matrix_dimension = 4; + bool is_cyclic = false; -TEST_F(BatchedTridiagonalSolverTest, DiagonalSolveBeforeSetupThrows) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; + BatchedTridiagonalSolver solver(matrix_dimension, batch_count, is_cyclic); - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); + solver.main_diagonal(0,0) = 2.0; + solver.main_diagonal(0,1) = 4.0; + solver.main_diagonal(0,2) = 6.0; + solver.main_diagonal(0,3) = 8.0; - // Try to solve_diagonal without calling setup() first - EXPECT_THROW({ solver.solve_diagonal(rhs, 0, 1); }, std::runtime_error); -} + solver.main_diagonal(1,0) = 3.0; + solver.main_diagonal(1,1) = 5.0; + solver.main_diagonal(1,2) = 7.0; + solver.main_diagonal(1,3) = 9.0; -// ============================================================================ -// Diagonal Solve Tests -// ============================================================================ + solver.main_diagonal(2,0) = 4.0; + solver.main_diagonal(2,1) = 6.0; + solver.main_diagonal(2,2) = 8.0; + solver.main_diagonal(2,3) = 10.0; -TEST_F(BatchedTridiagonalSolverTest, DiagonalSolveAllBatches) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; + solver.main_diagonal(3,0) = 5.0; + solver.main_diagonal(3,1) = 7.0; + solver.main_diagonal(3,2) = 9.0; + solver.main_diagonal(3,3) = 11.0; - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); + Vector rhs("rhs", matrix_dimension * batch_count); - // Initialize with diagonal-only system - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - solver.main_diagonal(b, i) = 2.0; - rhs(b * matrix_dim + i) = 4.0; - } - } + // Initialize RHS for each system + rhs(0) = 1.0; rhs(1) = 2.0; rhs(2) = 3.0; rhs(3) = 4.0; + rhs(4) = 2.0; rhs(5) = 3.0; rhs(6) = 4.0; rhs(7) = 5.0; + rhs(8) = 3.0; rhs(9) = 4.0; rhs(10) = 5.0; rhs(11) = 6.0; + rhs(12) = 4.0; rhs(13) = 5.0; rhs(14) = 6.0; rhs(15) = 7.0; solver.setup(); - solver.solve_diagonal(rhs, 0, 1); - - // Each solution element should be 4.0 / 2.0 = 2.0 - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - EXPECT_NEAR(rhs(b * matrix_dim + i), 2.0, default_tolerance) - << "Diagonal solve failed at batch " << b << ", index " << i; - } - } -} - -TEST_F(BatchedTridiagonalSolverTest, DiagonalSolveWithStride) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); - - // Initialize with diagonal-only system - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - solver.main_diagonal(b, i) = 3.0; - rhs(b * matrix_dim + i) = 6.0; - } - } - solver.setup(); - solver.solve_diagonal(rhs, 1, 2); // Solve odd batches only - - // Check odd batches: should be 6.0 / 3.0 = 2.0 - std::vector odd_batches = {1, 3, 5, 7}; - for (int b : odd_batches) { - for (int i = 0; i < matrix_dim; i++) { - EXPECT_NEAR(rhs(b * matrix_dim + i), 2.0, default_tolerance) - << "Diagonal solve with stride failed at batch " << b << ", index " << i; - } - } + int offset, stride; + // Solve each even system separately + offset = 0; stride = 2; + solver.solve_diagonal(rhs, offset, stride); + // Solve each odd system separately + offset = 1; stride = 2; + solver.solve_diagonal(rhs, offset, stride); + + // Verify solutions + double tol = 1e-12; + + EXPECT_NEAR(rhs(0), 1.0/2.0, tol); + EXPECT_NEAR(rhs(1), 2.0/4.0, tol); + EXPECT_NEAR(rhs(2), 3.0/6.0, tol); + EXPECT_NEAR(rhs(3), 4.0/8.0, tol); + + EXPECT_NEAR(rhs(4), 2.0/3.0, tol); + EXPECT_NEAR(rhs(5), 3.0/5.0, tol); + EXPECT_NEAR(rhs(6), 4.0/7.0, tol); + EXPECT_NEAR(rhs(7), 5.0/9.0, tol); + + EXPECT_NEAR(rhs(8), 3.0/4.0, tol); + EXPECT_NEAR(rhs(9), 4.0/6.0, tol); + EXPECT_NEAR(rhs(10), 5.0/8.0, tol); + EXPECT_NEAR(rhs(11), 6.0/10.0, tol); + + EXPECT_NEAR(rhs(12), 4.0/5.0, tol); + EXPECT_NEAR(rhs(13), 5.0/7.0, tol); + EXPECT_NEAR(rhs(14), 6.0/9.0, tol); + EXPECT_NEAR(rhs(15), 7.0/11.0, tol); } -// ============================================================================ -// Numerical Accuracy Tests -// ============================================================================ - -TEST_F(BatchedTridiagonalSolverTest, AccuracyNonCyclic) +TEST(BatchedTridiagonalSolvers, cyclic_diagonal_n_4) { - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; + int batch_count = 4; + int matrix_dimension = 4; + bool is_cyclic = true; - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); + BatchedTridiagonalSolver solver(matrix_dimension, batch_count, is_cyclic); - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); - } + solver.main_diagonal(0,0) = 2.0; + solver.main_diagonal(0,1) = 4.0; + solver.main_diagonal(0,2) = 6.0; + solver.main_diagonal(0,3) = 8.0; - solver.setup(); - solver.solve(rhs, 0, 1); + solver.main_diagonal(1,0) = 3.0; + solver.main_diagonal(1,1) = 5.0; + solver.main_diagonal(1,2) = 7.0; + solver.main_diagonal(1,3) = 9.0; - // Check that error is well below tolerance - for (int b = 0; b < batch_count; b++) { - double error = compute_solution_error(rhs, b, matrix_dim, false); - EXPECT_LT(error, 1e-10) << "Error too large for batch " << b; - } -} + solver.main_diagonal(2,0) = 4.0; + solver.main_diagonal(2,1) = 6.0; + solver.main_diagonal(2,2) = 8.0; + solver.main_diagonal(2,3) = 10.0; -TEST_F(BatchedTridiagonalSolverTest, AccuracyCyclic) -{ - const int matrix_dim = default_matrix_dim; - const int batch_count = default_batch_count; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, true); - Vector rhs("rhs", matrix_dim * batch_count); + solver.main_diagonal(3,0) = 5.0; + solver.main_diagonal(3,1) = 7.0; + solver.main_diagonal(3,2) = 9.0; + solver.main_diagonal(3,3) = 11.0; - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, true); - } + Vector rhs("rhs", matrix_dimension * batch_count); - solver.setup(); - solver.solve(rhs, 0, 1); + // Initialize RHS for each system + rhs(0) = 1.0; rhs(1) = 2.0; rhs(2) = 3.0; rhs(3) = 4.0; + rhs(4) = 2.0; rhs(5) = 3.0; rhs(6) = 4.0; rhs(7) = 5.0; + rhs(8) = 3.0; rhs(9) = 4.0; rhs(10) = 5.0; rhs(11) = 6.0; + rhs(12) = 4.0; rhs(13) = 5.0; rhs(14) = 6.0; rhs(15) = 7.0; - // Check that error is well below tolerance - for (int b = 0; b < batch_count; b++) { - double error = compute_solution_error(rhs, b, matrix_dim, true); - EXPECT_LT(error, 1e-10) << "Error too large for cyclic batch " << b; - } -} - -// ============================================================================ -// Performance/Stress Tests -// ============================================================================ - -TEST_F(BatchedTridiagonalSolverTest, LargeBatchCount) -{ - const int matrix_dim = default_matrix_dim; - const int large_batch = 1000; - - BatchedTridiagonalSolver solver(matrix_dim, large_batch, false); - Vector rhs("rhs", matrix_dim * large_batch); - - // Initialize - for (int b = 0; b < large_batch; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); - } - - // Factorize and solve - solver.setup(); - solver.solve(rhs, 0, 1); - - // Verify a few random batches - std::vector test_batches = {0, 250, 500, 750, 999}; - for (int b : test_batches) { - EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) - << "Large batch test failed at batch " << b; - } -} - -TEST_F(BatchedTridiagonalSolverTest, LargeMatrixDimension) -{ - const int matrix_dim = 100; - const int batch_count = 4; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - Vector rhs("rhs", matrix_dim * batch_count); - - for (int b = 0; b < batch_count; b++) { - initialize_tridiagonal_system(solver, rhs, b, matrix_dim, false); - } - - solver.setup(); - solver.solve(rhs, 0, 1); - - for (int b = 0; b < batch_count; b++) { - EXPECT_TRUE(verify_solution(rhs, b, matrix_dim, false, default_tolerance)) - << "Large matrix dimension test failed at batch " << b; - } -} - -// ============================================================================ -// Accessor Tests -// ============================================================================ - -TEST_F(BatchedTridiagonalSolverTest, MainDiagonalAccessors) -{ - const int matrix_dim = 5; - const int batch_count = 2; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - - // Test write access - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - solver.main_diagonal(b, i) = static_cast(b * matrix_dim + i); - } - } - - // Test read access BEFORE setup() - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - EXPECT_DOUBLE_EQ(solver.main_diagonal(b, i), static_cast(b * matrix_dim + i)) - << "Main diagonal accessor failed at batch " << b << ", index " << i; - } - } -} - -TEST_F(BatchedTridiagonalSolverTest, SubDiagonalAccessors) -{ - const int matrix_dim = 5; - const int batch_count = 2; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - - // Test write access - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - solver.sub_diagonal(b, i) = static_cast(100 + b * matrix_dim + i); - } - } - - // Test read access BEFORE setup() - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - EXPECT_DOUBLE_EQ(solver.sub_diagonal(b, i), static_cast(100 + b * matrix_dim + i)) - << "Sub diagonal accessor failed at batch " << b << ", index " << i; - } - } -} - -TEST_F(BatchedTridiagonalSolverTest, CyclicCornerAccessors) -{ - const int matrix_dim = 5; - const int batch_count = 2; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, true); - - // Test write access via cyclic_corner - for (int b = 0; b < batch_count; b++) { - solver.cyclic_corner(b) = static_cast(200 + b); - } - - // Test read access via cyclic_corner BEFORE setup() - for (int b = 0; b < batch_count; b++) { - EXPECT_DOUBLE_EQ(solver.cyclic_corner(b), static_cast(200 + b)) - << "Cyclic corner accessor failed at batch " << b; - } - - // Verify that cyclic_corner actually accesses sub_diagonal at the right location - for (int b = 0; b < batch_count; b++) { - EXPECT_DOUBLE_EQ(solver.cyclic_corner(b), solver.sub_diagonal(b, matrix_dim - 1)) - << "Cyclic corner should access sub_diagonal at index matrix_dim-1"; - } -} - -TEST_F(BatchedTridiagonalSolverTest, SetupModifiesInternalState) -{ - // This test verifies that setup() modifies the internal diagonal values - const int matrix_dim = 5; - const int batch_count = 2; - - BatchedTridiagonalSolver solver(matrix_dim, batch_count, false); - - // Set up initial values - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - solver.main_diagonal(b, i) = 4.0; - if (i < matrix_dim - 1) { - solver.sub_diagonal(b, i) = -1.0; - } - } - } - - // Store original values for comparison - std::vector original_main(matrix_dim * batch_count); - std::vector original_sub(matrix_dim * batch_count); - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - original_main[b * matrix_dim + i] = solver.main_diagonal(b, i); - original_sub[b * matrix_dim + i] = solver.sub_diagonal(b, i); - } - } - - // Call setup() - this performs Cholesky factorization solver.setup(); - // Verify that values have been modified - bool main_changed = false; - bool sub_changed = false; - for (int b = 0; b < batch_count; b++) { - for (int i = 0; i < matrix_dim; i++) { - if (std::abs(solver.main_diagonal(b, i) - original_main[b * matrix_dim + i]) > 1e-10) { - main_changed = true; - } - if (std::abs(solver.sub_diagonal(b, i) - original_sub[b * matrix_dim + i]) > 1e-10) { - sub_changed = true; - } - } - } - - EXPECT_TRUE(main_changed) << "setup() should modify main diagonal values (Cholesky factorization)"; - EXPECT_TRUE(sub_changed) << "setup() should modify sub diagonal values (Cholesky factorization)"; + int offset, stride; + // Solve each even system separately + offset = 0; stride = 2; + solver.solve_diagonal(rhs, offset, stride); + // Solve each odd system separately + offset = 1; stride = 2; + solver.solve_diagonal(rhs, offset, stride); + + // Verify solutions + double tol = 1e-12; + + EXPECT_NEAR(rhs(0), 1.0/2.0, tol); + EXPECT_NEAR(rhs(1), 2.0/4.0, tol); + EXPECT_NEAR(rhs(2), 3.0/6.0, tol); + EXPECT_NEAR(rhs(3), 4.0/8.0, tol); + + EXPECT_NEAR(rhs(4), 2.0/3.0, tol); + EXPECT_NEAR(rhs(5), 3.0/5.0, tol); + EXPECT_NEAR(rhs(6), 4.0/7.0, tol); + EXPECT_NEAR(rhs(7), 5.0/9.0, tol); + + EXPECT_NEAR(rhs(8), 3.0/4.0, tol); + EXPECT_NEAR(rhs(9), 4.0/6.0, tol); + EXPECT_NEAR(rhs(10), 5.0/8.0, tol); + EXPECT_NEAR(rhs(11), 6.0/10.0, tol); + + EXPECT_NEAR(rhs(12), 4.0/5.0, tol); + EXPECT_NEAR(rhs(13), 5.0/7.0, tol); + EXPECT_NEAR(rhs(14), 6.0/9.0, tol); + EXPECT_NEAR(rhs(15), 7.0/11.0, tol); } From 4b9b659721655c8e021ad9e29795e2d337d5b7c1 Mon Sep 17 00:00:00 2001 From: Julian Litz <91479202+julianlitz@users.noreply.github.com> Date: Sat, 31 Jan 2026 21:33:46 +0100 Subject: [PATCH 5/8] Diagonal solver support cyclic cases --- .../Solvers/tridiagonal_solver.h | 39 ++++++++++++++++--- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/include/LinearAlgebra/Solvers/tridiagonal_solver.h b/include/LinearAlgebra/Solvers/tridiagonal_solver.h index c0237dc8..0887ffd6 100644 --- a/include/LinearAlgebra/Solvers/tridiagonal_solver.h +++ b/include/LinearAlgebra/Solvers/tridiagonal_solver.h @@ -259,7 +259,7 @@ class BatchedTridiagonalSolver // It is useful when the matrix has a non-zero diagonal but zero off-diagonal entries. // Note that .setup() doesn't modify the main diagonal in this case. - struct SolveDiagonal { + struct SolveDiagonalNonCyclic { int m_matrix_dimension; Vector m_main_diagonal; Vector m_rhs; @@ -281,6 +281,30 @@ class BatchedTridiagonalSolver } }; + struct SolveDiagonalCyclic { + int m_matrix_dimension; + Vector m_main_diagonal; + Vector m_gamma; + Vector m_rhs; + int m_batch_offset; + int m_batch_stride; + + void operator()(const int k) const + { + // ----------------------------------- // + // Obtain offset for the current batch // + int batch_idx = m_batch_stride * k + m_batch_offset; + int offset = batch_idx * m_matrix_dimension; + + // ---------------- // + // Diagonal Scaling // + m_rhs(offset + 0) /= m_main_diagonal(offset + 0) + m_gamma(batch_idx); + for (int i = 1; i < m_matrix_dimension; i++) { + m_rhs(offset + i) /= m_main_diagonal(offset + i); + } + } + }; + void solve_diagonal(Vector rhs, int batch_offset = 0, int batch_stride = 1) { if (!is_factorized_) { @@ -290,9 +314,14 @@ class BatchedTridiagonalSolver // Compute the effective number of batches to solve int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride; - SolveDiagonal functor{matrix_dimension_, main_diagonal_, rhs, batch_offset, batch_stride}; - Kokkos::parallel_for("SolveDiagonal", effective_batch_count, functor); - + if (!is_cyclic_) { + SolveDiagonalNonCyclic functor{matrix_dimension_, main_diagonal_, rhs, batch_offset, batch_stride}; + Kokkos::parallel_for("SolveDiagonalNonCyclic", effective_batch_count, functor); + } + else { + SolveDiagonalCyclic functor{matrix_dimension_, main_diagonal_, gamma_, rhs, batch_offset, batch_stride}; + Kokkos::parallel_for("SolveDiagonalCyclic", effective_batch_count, functor); + } Kokkos::fence(); } @@ -307,4 +336,4 @@ class BatchedTridiagonalSolver bool is_cyclic_; bool is_factorized_; -}; \ No newline at end of file +}; From 4c172a15c0367ae9d7bb6e2ffa2b61e6bdab7fb7 Mon Sep 17 00:00:00 2001 From: Julian Litz <91479202+julianlitz@users.noreply.github.com> Date: Sun, 1 Feb 2026 10:42:43 +0100 Subject: [PATCH 6/8] Update tridiagonal_solver.h --- include/LinearAlgebra/Solvers/tridiagonal_solver.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/LinearAlgebra/Solvers/tridiagonal_solver.h b/include/LinearAlgebra/Solvers/tridiagonal_solver.h index 0887ffd6..be22dab9 100644 --- a/include/LinearAlgebra/Solvers/tridiagonal_solver.h +++ b/include/LinearAlgebra/Solvers/tridiagonal_solver.h @@ -257,7 +257,7 @@ class BatchedTridiagonalSolver /* ---------------------------- */ // This step performs only the diagonal scaling part of the solve process. // It is useful when the matrix has a non-zero diagonal but zero off-diagonal entries. - // Note that .setup() doesn't modify the main diagonal in this case. + // Note that .setup() modifies main_diagonal(0) in the cyclic case. struct SolveDiagonalNonCyclic { int m_matrix_dimension; From 5696ede06295ba671f49448e27091f9bf6fa3727 Mon Sep 17 00:00:00 2001 From: julianlitz Date: Mon, 2 Feb 2026 16:53:11 +0100 Subject: [PATCH 7/8] Prefer lamda over struct --- .../Solvers/tridiagonal_solver.h | 362 ++++++++---------- 1 file changed, 156 insertions(+), 206 deletions(-) diff --git a/include/LinearAlgebra/Solvers/tridiagonal_solver.h b/include/LinearAlgebra/Solvers/tridiagonal_solver.h index be22dab9..cc747d5c 100644 --- a/include/LinearAlgebra/Solvers/tridiagonal_solver.h +++ b/include/LinearAlgebra/Solvers/tridiagonal_solver.h @@ -3,6 +3,7 @@ #include #include "../vector.h" +#include "../vector_operations.h" template class BatchedTridiagonalSolver @@ -18,8 +19,8 @@ class BatchedTridiagonalSolver , is_cyclic_(is_cyclic) , is_factorized_(false) { - Kokkos::deep_copy(main_diagonal_, T(0)); - Kokkos::deep_copy(sub_diagonal_, T(0)); + assign(main_diagonal_, T(0)); + assign(sub_diagonal_, T(0)); } /* ---------------------------- */ @@ -66,72 +67,56 @@ class BatchedTridiagonalSolver // This step factorizes the tridiagonal matrix into lower triangular (L) and diagonal (D) matrices. // For cyclic systems, it also applies the Shermann-Morrison adjustment to account for the cyclic connection. - struct SetupNonCyclic { - int m_matrix_dimension; - Vector m_main_diagonal; - Vector m_sub_diagonal; - - void operator()(const int batch_idx) const - { - // ----------------------------------- // - // Obtain offset for the current batch // - int offset = batch_idx * m_matrix_dimension; - - // ---------------------- // - // Cholesky Decomposition // - for (int i = 1; i < m_matrix_dimension; i++) { - m_sub_diagonal(offset + i - 1) /= m_main_diagonal(offset + i - 1); - const T factor = m_sub_diagonal(offset + i - 1); - m_main_diagonal(offset + i) -= factor * factor * m_main_diagonal(offset + i - 1); - } - } - }; - - struct SetupCyclic { - int m_matrix_dimension; - Vector m_main_diagonal; - Vector m_sub_diagonal; - Vector m_gamma; - - void operator()(const int batch_idx) const - { - // ----------------------------------- // - // Obtain offset for the current batch // - int offset = batch_idx * m_matrix_dimension; - - // ------------------------------------------------- // - // Shermann-Morrison Adjustment // - // - Modify the first and last main diagonal element // - // - Compute and store gamma for later use // - // ------------------------------------------------- // - T cyclic_corner_element = m_sub_diagonal(offset + m_matrix_dimension - 1); - /* gamma_ = -main_diagonal(0);*/ - m_gamma(batch_idx) = -m_main_diagonal(offset + 0); - /* main_diagonal(0) -= gamma_;*/ - m_main_diagonal(offset + 0) -= m_gamma(batch_idx); - /* main_diagonal(matrix_dimension_ - 1) -= cyclic_corner_element()^2 / gamma_;*/ - m_main_diagonal(offset + m_matrix_dimension - 1) -= - cyclic_corner_element * cyclic_corner_element / m_gamma(batch_idx); - - // ---------------------- // - // Cholesky Decomposition // - for (int i = 1; i < m_matrix_dimension; i++) { - m_sub_diagonal(offset + i - 1) /= m_main_diagonal(offset + i - 1); - const T factor = m_sub_diagonal(offset + i - 1); - m_main_diagonal(offset + i) -= factor * factor * m_main_diagonal(offset + i - 1); - } - } - }; - void setup() { + // Create local copies for lambda capture + int matrix_dimension = matrix_dimension_; + Vector main_diagonal = main_diagonal_; + Vector sub_diagonal = sub_diagonal_; + Vector gamma = gamma_; + if (!is_cyclic_) { - SetupNonCyclic functor{matrix_dimension_, main_diagonal_, sub_diagonal_}; - Kokkos::parallel_for("SetupNonCyclic", batch_count_, functor); + Kokkos::parallel_for( + "SetupNonCyclic", batch_count_, KOKKOS_LAMBDA(const int batch_idx) { + // ----------------------------------- // + // Obtain offset for the current batch // + int offset = batch_idx * matrix_dimension; + + // ---------------------- // + // Cholesky Decomposition // + for (int i = 1; i < matrix_dimension; i++) { + sub_diagonal(offset + i - 1) /= main_diagonal(offset + i - 1); + const T factor = sub_diagonal(offset + i - 1); + main_diagonal(offset + i) -= factor * factor * main_diagonal(offset + i - 1); + } + }); } else { - SetupCyclic functor{matrix_dimension_, main_diagonal_, sub_diagonal_, gamma_}; - Kokkos::parallel_for("SetupCyclic", batch_count_, functor); + Kokkos::parallel_for( + "SetupCyclic", batch_count_, KOKKOS_LAMBDA(const int batch_idx) { + // ----------------------------------- // + // Obtain offset for the current batch // + int offset = batch_idx * matrix_dimension; + + // ------------------------------------------------- // + // Shermann-Morrison Adjustment // + // - Modify the first and last main diagonal element // + // - Compute and store gamma for later use // + // ------------------------------------------------- // + T cyclic_corner_element = sub_diagonal(offset + matrix_dimension - 1); + gamma(batch_idx) = -main_diagonal(offset + 0); + main_diagonal(offset + 0) -= gamma(batch_idx); + main_diagonal(offset + matrix_dimension - 1) -= + cyclic_corner_element * cyclic_corner_element / gamma(batch_idx); + + // ---------------------- // + // Cholesky Decomposition // + for (int i = 1; i < matrix_dimension; i++) { + sub_diagonal(offset + i - 1) /= main_diagonal(offset + i - 1); + const T factor = sub_diagonal(offset + i - 1); + main_diagonal(offset + i) -= factor * factor * main_diagonal(offset + i - 1); + } + }); } Kokkos::fence(); is_factorized_ = true; @@ -143,94 +128,6 @@ class BatchedTridiagonalSolver // This step solves the system Ax = b using the factorized form of A. // For cyclic systems, it also performs the Shermann-Morrison reconstruction to obtain the final solution. - struct SolveNonCyclic { - int m_matrix_dimension; - Vector m_main_diagonal; - Vector m_sub_diagonal; - Vector m_rhs; - int m_batch_offset; - int m_batch_stride; - - void operator()(const int k) const - { - // ----------------------------------- // - // Obtain offset for the current batch // - int batch_idx = m_batch_stride * k + m_batch_offset; - int offset = batch_idx * m_matrix_dimension; - - // -------------------- // - // Forward Substitution // - for (int i = 1; i < m_matrix_dimension; i++) { - m_rhs(offset + i) -= m_sub_diagonal(offset + i - 1) * m_rhs(offset + i - 1); - } - // ---------------- // - // Diagonal Scaling // - for (int i = 0; i < m_matrix_dimension; i++) { - m_rhs(offset + i) /= m_main_diagonal(offset + i); - } - // --------------------- // - // Backward Substitution // - for (int i = m_matrix_dimension - 2; i >= 0; i--) { - m_rhs(offset + i) -= m_sub_diagonal(offset + i) * m_rhs(offset + i + 1); - } - } - }; - - struct SolveCyclic { - int m_matrix_dimension; - Vector m_main_diagonal; - Vector m_sub_diagonal; - Vector m_buffer; - Vector m_gamma; - Vector m_rhs; - int m_batch_offset; - int m_batch_stride; - - void operator()(const int k) const - { - // ----------------------------------- // - // Obtain offset for the current batch // - int batch_idx = m_batch_stride * k + m_batch_offset; - int offset = batch_idx * m_matrix_dimension; - - // -------------------- // - // Forward Substitution // - T cyclic_corner_element = m_sub_diagonal(offset + m_matrix_dimension - 1); - m_buffer(offset + 0) = m_gamma(batch_idx); - for (int i = 1; i < m_matrix_dimension; i++) { - m_rhs(offset + i) -= m_sub_diagonal(offset + i - 1) * m_rhs(offset + i - 1); - if (i < m_matrix_dimension - 1) - m_buffer(offset + i) = 0.0 - m_sub_diagonal(offset + i - 1) * m_buffer(offset + i - 1); - else - m_buffer(offset + i) = - cyclic_corner_element - m_sub_diagonal(offset + i - 1) * m_buffer(offset + i - 1); - } - // ---------------- // - // Diagonal Scaling // - for (int i = 0; i < m_matrix_dimension; i++) { - m_rhs(offset + i) /= m_main_diagonal(offset + i); - m_buffer(offset + i) /= m_main_diagonal(offset + i); - } - // --------------------- // - // Backward Substitution // - for (int i = m_matrix_dimension - 2; i >= 0; i--) { - m_rhs(offset + i) -= m_sub_diagonal(offset + i) * m_rhs(offset + i + 1); - m_buffer(offset + i) -= m_sub_diagonal(offset + i) * m_buffer(offset + i + 1); - } - // ------------------------------- // - // Shermann-Morrison Reonstruction // - const T dot_product_x_v = - m_rhs(offset + 0) + cyclic_corner_element / m_gamma(batch_idx) * m_rhs(offset + m_matrix_dimension - 1); - const T dot_product_u_v = m_buffer(offset + 0) + cyclic_corner_element / m_gamma(batch_idx) * - m_buffer(offset + m_matrix_dimension - 1); - const T factor = dot_product_x_v / (1.0 + dot_product_u_v); - - for (int i = 0; i < m_matrix_dimension; i++) { - m_rhs(offset + i) -= factor * m_buffer(offset + i); - } - } - }; - void solve(Vector rhs, int batch_offset = 0, int batch_stride = 1) { if (!is_factorized_) { @@ -240,14 +137,85 @@ class BatchedTridiagonalSolver // Compute the effective number of batches to solve int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride; + // Create local copies for lambda capture + int matrix_dimension = matrix_dimension_; + Vector main_diagonal = main_diagonal_; + Vector sub_diagonal = sub_diagonal_; + Vector buffer = buffer_; + Vector gamma = gamma_; + if (!is_cyclic_) { - SolveNonCyclic functor{matrix_dimension_, main_diagonal_, sub_diagonal_, rhs, batch_offset, batch_stride}; - Kokkos::parallel_for("SolveNonCyclic", effective_batch_count, functor); + Kokkos::parallel_for( + "SolveNonCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { + // ----------------------------------- // + // Obtain offset for the current batch // + int batch_idx = batch_stride * k + batch_offset; + int offset = batch_idx * matrix_dimension; + + // -------------------- // + // Forward Substitution // + for (int i = 1; i < matrix_dimension; i++) { + rhs(offset + i) -= sub_diagonal(offset + i - 1) * rhs(offset + i - 1); + } + // ---------------- // + // Diagonal Scaling // + for (int i = 0; i < matrix_dimension; i++) { + rhs(offset + i) /= main_diagonal(offset + i); + } + // --------------------- // + // Backward Substitution // + for (int i = matrix_dimension - 2; i >= 0; i--) { + rhs(offset + i) -= sub_diagonal(offset + i) * rhs(offset + i + 1); + } + }); } else { - SolveCyclic functor{matrix_dimension_, main_diagonal_, sub_diagonal_, buffer_, gamma_, rhs, - batch_offset, batch_stride}; - Kokkos::parallel_for("SolveCyclic", effective_batch_count, functor); + Kokkos::parallel_for( + "SolveCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { + // ----------------------------------- // + // Obtain offset for the current batch // + int batch_idx = batch_stride * k + batch_offset; + int offset = batch_idx * matrix_dimension; + + // -------------------- // + // Forward Substitution // + T cyclic_corner_element = sub_diagonal(offset + matrix_dimension - 1); + buffer(offset + 0) = gamma(batch_idx); + for (int i = 1; i < matrix_dimension; i++) { + rhs(offset + i) -= sub_diagonal(offset + i - 1) * rhs(offset + i - 1); + if (i < matrix_dimension - 1) + buffer(offset + i) = 0.0 - sub_diagonal(offset + i - 1) * buffer(offset + i - 1); + else + buffer(offset + i) = + cyclic_corner_element - sub_diagonal(offset + i - 1) * buffer(offset + i - 1); + } + + // ---------------- // + // Diagonal Scaling // + for (int i = 0; i < matrix_dimension; i++) { + rhs(offset + i) /= main_diagonal(offset + i); + buffer(offset + i) /= main_diagonal(offset + i); + } + + // --------------------- // + // Backward Substitution // + for (int i = matrix_dimension - 2; i >= 0; i--) { + rhs(offset + i) -= sub_diagonal(offset + i) * rhs(offset + i + 1); + buffer(offset + i) -= sub_diagonal(offset + i) * buffer(offset + i + 1); + } + + // ------------------------------- // + // Shermann-Morrison Reonstruction // + const T dot_product_x_v = + rhs(offset + 0) + cyclic_corner_element / gamma(batch_idx) * rhs(offset + matrix_dimension - 1); + const T dot_product_u_v = buffer(offset + 0) + cyclic_corner_element / gamma(batch_idx) * + buffer(offset + matrix_dimension - 1); + const T factor = dot_product_x_v / (1.0 + dot_product_u_v); + + for (int i = 0; i < matrix_dimension; i++) { + rhs(offset + i) -= factor * buffer(offset + i); + } + }); } Kokkos::fence(); } @@ -259,52 +227,6 @@ class BatchedTridiagonalSolver // It is useful when the matrix has a non-zero diagonal but zero off-diagonal entries. // Note that .setup() modifies main_diagonal(0) in the cyclic case. - struct SolveDiagonalNonCyclic { - int m_matrix_dimension; - Vector m_main_diagonal; - Vector m_rhs; - int m_batch_offset; - int m_batch_stride; - - void operator()(const int k) const - { - // ----------------------------------- // - // Obtain offset for the current batch // - int batch_idx = m_batch_stride * k + m_batch_offset; - int offset = batch_idx * m_matrix_dimension; - - // ---------------- // - // Diagonal Scaling // - for (int i = 0; i < m_matrix_dimension; i++) { - m_rhs(offset + i) /= m_main_diagonal(offset + i); - } - } - }; - - struct SolveDiagonalCyclic { - int m_matrix_dimension; - Vector m_main_diagonal; - Vector m_gamma; - Vector m_rhs; - int m_batch_offset; - int m_batch_stride; - - void operator()(const int k) const - { - // ----------------------------------- // - // Obtain offset for the current batch // - int batch_idx = m_batch_stride * k + m_batch_offset; - int offset = batch_idx * m_matrix_dimension; - - // ---------------- // - // Diagonal Scaling // - m_rhs(offset + 0) /= m_main_diagonal(offset + 0) + m_gamma(batch_idx); - for (int i = 1; i < m_matrix_dimension; i++) { - m_rhs(offset + i) /= m_main_diagonal(offset + i); - } - } - }; - void solve_diagonal(Vector rhs, int batch_offset = 0, int batch_stride = 1) { if (!is_factorized_) { @@ -314,13 +236,41 @@ class BatchedTridiagonalSolver // Compute the effective number of batches to solve int effective_batch_count = (batch_count_ - batch_offset + batch_stride - 1) / batch_stride; + // Create local copies for lambda capture + int matrix_dimension = matrix_dimension_; + Vector main_diagonal = main_diagonal_; + Vector gamma = gamma_; + if (!is_cyclic_) { - SolveDiagonalNonCyclic functor{matrix_dimension_, main_diagonal_, rhs, batch_offset, batch_stride}; - Kokkos::parallel_for("SolveDiagonalNonCyclic", effective_batch_count, functor); + Kokkos::parallel_for( + "SolveDiagonalNonCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { + // ----------------------------------- // + // Obtain offset for the current batch // + int batch_idx = batch_stride * k + batch_offset; + int offset = batch_idx * matrix_dimension; + + // ---------------- // + // Diagonal Scaling // + for (int i = 0; i < matrix_dimension; i++) { + rhs(offset + i) /= main_diagonal(offset + i); + } + }); } else { - SolveDiagonalCyclic functor{matrix_dimension_, main_diagonal_, gamma_, rhs, batch_offset, batch_stride}; - Kokkos::parallel_for("SolveDiagonalCyclic", effective_batch_count, functor); + Kokkos::parallel_for( + "SolveDiagonalCyclic", effective_batch_count, KOKKOS_LAMBDA(const int k) { + // ----------------------------------- // + // Obtain offset for the current batch // + int batch_idx = batch_stride * k + batch_offset; + int offset = batch_idx * matrix_dimension; + + // ---------------- // + // Diagonal Scaling // + rhs(offset + 0) /= main_diagonal(offset + 0) + gamma(batch_idx); + for (int i = 1; i < matrix_dimension; i++) { + rhs(offset + i) /= main_diagonal(offset + i); + } + }); } Kokkos::fence(); } From 0602d43bcb4b0f0792012feea1470d0069ba068e Mon Sep 17 00:00:00 2001 From: julianlitz Date: Mon, 2 Feb 2026 16:54:39 +0100 Subject: [PATCH 8/8] add line breaks --- include/LinearAlgebra/Solvers/tridiagonal_solver.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/LinearAlgebra/Solvers/tridiagonal_solver.h b/include/LinearAlgebra/Solvers/tridiagonal_solver.h index cc747d5c..091cf89b 100644 --- a/include/LinearAlgebra/Solvers/tridiagonal_solver.h +++ b/include/LinearAlgebra/Solvers/tridiagonal_solver.h @@ -157,11 +157,13 @@ class BatchedTridiagonalSolver for (int i = 1; i < matrix_dimension; i++) { rhs(offset + i) -= sub_diagonal(offset + i - 1) * rhs(offset + i - 1); } + // ---------------- // // Diagonal Scaling // for (int i = 0; i < matrix_dimension; i++) { rhs(offset + i) /= main_diagonal(offset + i); } + // --------------------- // // Backward Substitution // for (int i = matrix_dimension - 2; i >= 0; i--) {