diff --git a/CHANGELOG.md b/CHANGELOG.md index f688ce34..98cd0f40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -67,3 +67,5 @@ It is seamless from the user perspective and fixed many bugs. 15. Added examples/sysGmres.cpp to demonstrate how to use SystemSolver with GMRES. 16. Updated MatrixHandler::addConst to return integer error codes instead of void. + +17. Added a preconditioner interface class so users can define thier own preconditioners. diff --git a/examples/experimental/r_KLU_rf_FGMRES_reuse_factorization.cpp b/examples/experimental/r_KLU_rf_FGMRES_reuse_factorization.cpp index 35baf7dc..3baafd63 100644 --- a/examples/experimental/r_KLU_rf_FGMRES_reuse_factorization.cpp +++ b/examples/experimental/r_KLU_rf_FGMRES_reuse_factorization.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -185,7 +186,8 @@ int main(int argc, char* argv[]) << status << std::endl; vec_rhs->copyDataFrom(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE); status = Rf->solve(vec_rhs, vec_x); - FGMRES->setupPreconditioner("LU", Rf); + ReSolve::PreconditionerLU precond_lu(Rf); + FGMRES->setPreconditioner(&precond_lu); } // if (i%2!=0) vec_x->setToZero(ReSolve::memory::DEVICE); real_type norm_x = vector_handler->dot(vec_x, vec_x, ReSolve::memory::DEVICE); diff --git a/examples/gpuRefactor.cpp b/examples/gpuRefactor.cpp index 90f5ee06..dd7d26c2 100644 --- a/examples/gpuRefactor.cpp +++ b/examples/gpuRefactor.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -306,7 +307,8 @@ int gpuRefactor(int argc, char* argv[]) { // Setup iterative refinement FGMRES.resetMatrix(A); - FGMRES.setupPreconditioner("LU", &Rf); + ReSolve::PreconditionerLU precond_lu(&Rf); + FGMRES.setPreconditioner(&precond_lu); // If refactorization produced finite solution do iterative refinement if (std::isfinite(helper.getNormRelativeResidual())) diff --git a/examples/kluFactor.cpp b/examples/kluFactor.cpp index 78ad54d5..090ec127 100644 --- a/examples/kluFactor.cpp +++ b/examples/kluFactor.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -191,7 +192,8 @@ int main(int argc, char* argv[]) { // Setup iterative refinement FGMRES.setup(A); - FGMRES.setupPreconditioner("LU", &KLU); + ReSolve::PreconditionerLU precond_lu(&KLU); + FGMRES.setPreconditioner(&precond_lu); // If refactorization produced finite solution do iterative refinement if (std::isfinite(helper.getNormRelativeResidual())) diff --git a/examples/kluRefactor.cpp b/examples/kluRefactor.cpp index ddf188d4..992fc70f 100644 --- a/examples/kluRefactor.cpp +++ b/examples/kluRefactor.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -197,7 +198,8 @@ int main(int argc, char* argv[]) { // Setup iterative refinement FGMRES.setup(A); - FGMRES.setupPreconditioner("LU", KLU); + ReSolve::PreconditionerLU precond_lu(KLU); + FGMRES.setPreconditioner(&precond_lu); // If refactorization produced finite solution do iterative refinement if (std::isfinite(helper.getNormRelativeResidual())) diff --git a/examples/randGmres.cpp b/examples/randGmres.cpp index 629cc264..18845a07 100644 --- a/examples/randGmres.cpp +++ b/examples/randGmres.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -177,7 +178,8 @@ int runGmresExample(int argc, char* argv[]) FGMRES.setup(A); FGMRES.resetMatrix(A); - FGMRES.setupPreconditioner("LU", &Precond); + ReSolve::PreconditionerLU precond_lu(&Precond); + FGMRES.setPreconditioner(&precond_lu); FGMRES.setFlexible(1); FGMRES.solve(vec_rhs, vec_x); diff --git a/resolve/CMakeLists.txt b/resolve/CMakeLists.txt index ad45cf41..07dbf7c2 100644 --- a/resolve/CMakeLists.txt +++ b/resolve/CMakeLists.txt @@ -20,6 +20,8 @@ set(ReSolve_SRC LinSolverIterativeRandFGMRES.cpp LinSolverDirectSerialILU0.cpp SystemSolver.cpp + Preconditioner.cpp + PreconditionerLU.cpp ) set(ReSolve_KLU_SRC LinSolverDirectKLU.cpp) @@ -49,6 +51,8 @@ set(ReSolve_HEADER_INSTALL SystemSolver.hpp GramSchmidt.hpp MemoryUtils.hpp + Preconditioner.hpp + PreconditionerLU.hpp ) set(ReSolve_KLU_HEADER_INSTALL LinSolverDirectKLU.hpp) diff --git a/resolve/LinSolverDirect.cpp b/resolve/LinSolverDirect.cpp index d82e1046..21d5183f 100644 --- a/resolve/LinSolverDirect.cpp +++ b/resolve/LinSolverDirect.cpp @@ -61,6 +61,23 @@ namespace ReSolve return 0; } + /** + * @brief Resets the matrix for the solver. + * + * @param[in] A - matrix to be reset + * + * @return int - error code, 0 if successful + */ + int LinSolverDirect::reset(matrix::Sparse* A) + { + if (A == nullptr) + { + return 1; + } + A_ = A; + return 0; + } + /** * @brief Placeholder function for symbolic factorization. */ diff --git a/resolve/LinSolverDirect.hpp b/resolve/LinSolverDirect.hpp index aafc03f8..a7316d9e 100644 --- a/resolve/LinSolverDirect.hpp +++ b/resolve/LinSolverDirect.hpp @@ -29,6 +29,7 @@ namespace ReSolve virtual int analyze(); // the same as symbolic factorization virtual int factorize(); virtual int refactorize(); + virtual int reset(matrix::Sparse* A); virtual int solve(vector_type* rhs, vector_type* x) = 0; virtual int solve(vector_type* x) = 0; diff --git a/resolve/LinSolverDirectCpuILU0.hpp b/resolve/LinSolverDirectCpuILU0.hpp index 5c8561bf..d53367aa 100644 --- a/resolve/LinSolverDirectCpuILU0.hpp +++ b/resolve/LinSolverDirectCpuILU0.hpp @@ -56,7 +56,7 @@ namespace ReSolve index_type* Q = nullptr, vector_type* rhs = nullptr) override; // if values of A change, but the nnz pattern does not, redo the analysis only (reuse buffers though) - int reset(matrix::Sparse* A); + int reset(matrix::Sparse* A) override; int analyze() override; int factorize() override; diff --git a/resolve/LinSolverDirectCuSolverRf.hpp b/resolve/LinSolverDirectCuSolverRf.hpp index da51c1d8..bf9951ca 100644 --- a/resolve/LinSolverDirectCuSolverRf.hpp +++ b/resolve/LinSolverDirectCuSolverRf.hpp @@ -37,7 +37,7 @@ namespace ReSolve matrix::Sparse* U, index_type* P, index_type* Q, - vector_type* rhs = nullptr); + vector_type* rhs = nullptr) override; int refactorize() override; int solve(vector_type* rhs, vector_type* x) override; diff --git a/resolve/LinSolverDirectCuSparseILU0.hpp b/resolve/LinSolverDirectCuSparseILU0.hpp index f4ae5ae6..a56b2321 100644 --- a/resolve/LinSolverDirectCuSparseILU0.hpp +++ b/resolve/LinSolverDirectCuSparseILU0.hpp @@ -37,7 +37,7 @@ namespace ReSolve index_type* Q = nullptr, vector_type* rhs = nullptr) override; // if values of A change, but the nnz pattern does not, redo the analysis only (reuse buffers though) - int reset(matrix::Sparse* A); + int reset(matrix::Sparse* A) override; int solve(vector_type* rhs, vector_type* x) override; int solve(vector_type* rhs) override; diff --git a/resolve/LinSolverDirectRocSparseILU0.hpp b/resolve/LinSolverDirectRocSparseILU0.hpp index 1a18a14d..ccb376cd 100644 --- a/resolve/LinSolverDirectRocSparseILU0.hpp +++ b/resolve/LinSolverDirectRocSparseILU0.hpp @@ -39,7 +39,7 @@ namespace ReSolve index_type* Q = nullptr, vector_type* rhs = nullptr) override; // if values of A change, but the nnz pattern does not, redo the analysis only (reuse buffers though) - int reset(matrix::Sparse* A); + int reset(matrix::Sparse* A) override; int solve(vector_type* rhs, vector_type* x) override; int solve(vector_type* rhs) override; // the solution is returned IN RHS (rhs is overwritten) diff --git a/resolve/LinSolverDirectSerialILU0.hpp b/resolve/LinSolverDirectSerialILU0.hpp index 64b96632..f69aa2d9 100644 --- a/resolve/LinSolverDirectSerialILU0.hpp +++ b/resolve/LinSolverDirectSerialILU0.hpp @@ -34,7 +34,7 @@ namespace ReSolve index_type* P = nullptr, index_type* Q = nullptr, vector_type* rhs = nullptr) override; - int reset(matrix::Sparse* A); + int reset(matrix::Sparse* A) override; int solve(vector_type* rhs, vector_type* x) override; int solve(vector_type* rhs) override; // the solutuon is returned IN RHS (rhs is overwritten) diff --git a/resolve/LinSolverIterative.cpp b/resolve/LinSolverIterative.cpp index c24af8f7..efc281b4 100644 --- a/resolve/LinSolverIterative.cpp +++ b/resolve/LinSolverIterative.cpp @@ -57,6 +57,16 @@ namespace ReSolve return maxit_; } + int LinSolverIterative::setPreconditioner(Preconditioner* preconditioner) + { + if (preconditioner == nullptr) + { + return 1; + } + preconditioner_ = preconditioner; + return 0; + } + int LinSolverIterative::setOrthogonalization(GramSchmidt* /* gs */) { out::error() << "Solver does not implement setting orthogonalization.\n"; diff --git a/resolve/LinSolverIterative.hpp b/resolve/LinSolverIterative.hpp index 3cd1ef2e..f881df97 100644 --- a/resolve/LinSolverIterative.hpp +++ b/resolve/LinSolverIterative.hpp @@ -15,6 +15,7 @@ namespace ReSolve { class GramSchmidt; class LinSolverDirect; + class Preconditioner; class LinSolverIterative : public LinSolver { @@ -22,8 +23,7 @@ namespace ReSolve LinSolverIterative(); virtual ~LinSolverIterative(); virtual int setup(matrix::Sparse* A); - virtual int resetMatrix(matrix::Sparse* A) = 0; - virtual int setupPreconditioner(std::string type, LinSolverDirect* LU_solver) = 0; + virtual int resetMatrix(matrix::Sparse* A) = 0; virtual int solve(vector_type* rhs, vector_type* init_guess) = 0; @@ -31,6 +31,7 @@ namespace ReSolve virtual real_type getInitResidualNorm() const; virtual index_type getNumIter() const; + virtual int setPreconditioner(Preconditioner* preconditioner); virtual int setOrthogonalization(GramSchmidt* gs); real_type getTol() const; @@ -40,6 +41,8 @@ namespace ReSolve void setMaxit(index_type new_maxit); protected: + Preconditioner* preconditioner_{nullptr}; + real_type initial_residual_norm_; real_type final_residual_norm_; index_type total_iters_; diff --git a/resolve/LinSolverIterativeFGMRES.cpp b/resolve/LinSolverIterativeFGMRES.cpp index 643f23a4..871aa7c8 100644 --- a/resolve/LinSolverIterativeFGMRES.cpp +++ b/resolve/LinSolverIterativeFGMRES.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -317,20 +318,6 @@ namespace ReSolve return 0; } - int LinSolverIterativeFGMRES::setupPreconditioner(std::string type, LinSolverDirect* LU_solver) - { - if (type != "LU") - { - out::warning() << "Only LU-type solve can be used as a preconditioner at this time." << std::endl; - return 1; - } - else - { - LU_solver_ = LU_solver; - return 0; - } - } - int LinSolverIterativeFGMRES::resetMatrix(matrix::Sparse* new_matrix) { A_ = new_matrix; @@ -598,7 +585,7 @@ namespace ReSolve void LinSolverIterativeFGMRES::precV(vector_type* rhs, vector_type* x) { - LU_solver_->solve(rhs, x); + preconditioner_->apply(rhs, x); } void LinSolverIterativeFGMRES::setMemorySpace() diff --git a/resolve/LinSolverIterativeFGMRES.hpp b/resolve/LinSolverIterativeFGMRES.hpp index 38887579..de9ebc8d 100644 --- a/resolve/LinSolverIterativeFGMRES.hpp +++ b/resolve/LinSolverIterativeFGMRES.hpp @@ -16,6 +16,7 @@ namespace ReSolve // Forward declarations class SketchingHandler; class GramSchmidt; + class Preconditioner; namespace matrix { @@ -55,7 +56,6 @@ namespace ReSolve int solve(vector_type* rhs, vector_type* x) override; int setup(matrix::Sparse* A) override; int resetMatrix(matrix::Sparse* new_A) override; - int setupPreconditioner(std::string name, LinSolverDirect* LU_solver) override; int setOrthogonalization(GramSchmidt* gs) override; int setRestart(index_type restart); @@ -103,10 +103,9 @@ namespace ReSolve real_type* h_s_{nullptr}; real_type* h_rs_{nullptr}; - GramSchmidt* GS_{nullptr}; - LinSolverDirect* LU_solver_{nullptr}; - index_type n_{0}; - bool is_solver_set_{false}; + GramSchmidt* GS_{nullptr}; + index_type n_{0}; + bool is_solver_set_{false}; MemoryHandler mem_; ///< Device memory manager object }; diff --git a/resolve/LinSolverIterativeRandFGMRES.cpp b/resolve/LinSolverIterativeRandFGMRES.cpp index db26b007..e570cdaf 100644 --- a/resolve/LinSolverIterativeRandFGMRES.cpp +++ b/resolve/LinSolverIterativeRandFGMRES.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -406,20 +407,6 @@ namespace ReSolve return 0; } - int LinSolverIterativeRandFGMRES::setupPreconditioner(std::string type, LinSolverDirect* LU_solver) - { - if (type != "LU") - { - out::warning() << "Only cusolverRf tri solve can be used as a preconditioner at this time." << std::endl; - return 1; - } - else - { - LU_solver_ = LU_solver; - return 0; - } - } - index_type LinSolverIterativeRandFGMRES::getKrand() { return k_rand_; @@ -788,7 +775,7 @@ namespace ReSolve void LinSolverIterativeRandFGMRES::precV(vector_type* rhs, vector_type* x) { - LU_solver_->solve(rhs, x); + preconditioner_->apply(rhs, x); } /** diff --git a/resolve/LinSolverIterativeRandFGMRES.hpp b/resolve/LinSolverIterativeRandFGMRES.hpp index c16f7711..e2da24b8 100644 --- a/resolve/LinSolverIterativeRandFGMRES.hpp +++ b/resolve/LinSolverIterativeRandFGMRES.hpp @@ -16,6 +16,7 @@ namespace ReSolve // Forward declarations class SketchingHandler; class GramSchmidt; + class Preconditioner; namespace matrix { @@ -67,7 +68,6 @@ namespace ReSolve int solve(vector_type* rhs, vector_type* x) override; int setup(matrix::Sparse* A) override; int resetMatrix(matrix::Sparse* new_A) override; - int setupPreconditioner(std::string name, LinSolverDirect* LU_solver) override; int setOrthogonalization(GramSchmidt* gs) override; int setRestart(index_type restart); @@ -123,10 +123,9 @@ namespace ReSolve real_type* h_rs_{nullptr}; vector_type* vec_aux_{nullptr}; - GramSchmidt* GS_{nullptr}; - LinSolverDirect* LU_solver_{nullptr}; - index_type n_{0}; - real_type one_over_k_{1.0}; + GramSchmidt* GS_{nullptr}; + index_type n_{0}; + real_type one_over_k_{1.0}; index_type k_rand_{0}; ///< size of sketch space. We need to know it so we can allocate S! MemoryHandler mem_; ///< Device memory manager object diff --git a/resolve/Preconditioner.cpp b/resolve/Preconditioner.cpp new file mode 100644 index 00000000..4d85556d --- /dev/null +++ b/resolve/Preconditioner.cpp @@ -0,0 +1,25 @@ +/** + * @file Preconditioner.cpp + * @author Kakeru Ueda (k.ueda.2290@m.isct.ac.jp) + * @brief Implementation of preconditioner base class. + * + */ + +#include "Preconditioner.hpp" + +namespace ReSolve +{ + Preconditioner::Preconditioner() + { + } + + Preconditioner::~Preconditioner() + { + } + + int Preconditioner::reset(matrix_type* /* A */) + { + return 1; + } + +} // namespace ReSolve diff --git a/resolve/Preconditioner.hpp b/resolve/Preconditioner.hpp new file mode 100644 index 00000000..4ed21cda --- /dev/null +++ b/resolve/Preconditioner.hpp @@ -0,0 +1,39 @@ +/** + * @file Preconditioner.hpp + * @author Kakeru Ueda (k.ueda.2290@m.isct.ac.jp) + * @brief Declaration of preconditioner base class. + * + */ +#pragma once + +namespace ReSolve +{ + namespace matrix + { + class Sparse; + } // namespace matrix + + namespace vector + { + class Vector; + } // namespace vector + + /** + * @class Preconditioner + * + * @brief Interface for preconditioner. + */ + class Preconditioner + { + public: + using vector_type = vector::Vector; + using matrix_type = matrix::Sparse; + + Preconditioner(); + virtual ~Preconditioner(); + + virtual int setup(matrix_type* A) = 0; + virtual int apply(vector_type* rhs, vector_type* x) = 0; + virtual int reset(matrix_type* /* A */); + }; +} // namespace ReSolve diff --git a/resolve/PreconditionerLU.cpp b/resolve/PreconditionerLU.cpp new file mode 100644 index 00000000..fd6ec3c7 --- /dev/null +++ b/resolve/PreconditionerLU.cpp @@ -0,0 +1,87 @@ +/** + * @file PreconditionerLU.cpp + * @author Kakeru Ueda (k.ueda.2290@m.isct.ac.jp) + * @brief Declaration of preconditioner ILU0 class. + * + */ + +#include "PreconditionerLU.hpp" + +#include + +namespace ReSolve +{ + /** + * @brief Constructor for PreconditionerLU. + * + * @param[in] solver - Pointer to the LinSolverDirect object. + */ + PreconditionerLU::PreconditionerLU(LinSolverDirect* solver) + { + solver_ = solver; + } + + /** + * @brief Destructor for PreconditionerLU + */ + PreconditionerLU::~PreconditionerLU() + { + } + + /** + * @brief Sets up the preconditioner with the given matrix + * + * @param[in] A - System matrix to set up the preconditioner with + * + * @return int 0 if successful, 1 if it fails + */ + int PreconditionerLU::setup(matrix_type* A) + { + if (A == nullptr) + { + return 1; + } + solver_->setup(A); + + return 0; + } + + /** + * @brief Applies the preconditioner to solve the system Mx = rhs + * + * Computes x = M^(-1) * rhs where M is the preconditioner matrix. + * + * @param[in] rhs - Right-hand-side vector + * @param[in] x - Solution vector + * + * @return int 0 if successful, 1 if fails + */ + int PreconditionerLU::apply(vector_type* rhs, vector_type* x) + { + if (solver_ == nullptr) + { + return 1; + } + solver_->solve(rhs, x); + + return 0; + } + + /** + * @brief Resets the preconditioner with the given matrix + * + * @param[in] A - System matrix to reset the preconditioner with + * + * @return int 0 if successful, 1 if it fails + */ + int PreconditionerLU::reset(matrix_type* A) + { + if (A == nullptr) + { + return 1; + } + solver_->reset(A); + + return 0; + } +} // namespace ReSolve diff --git a/resolve/PreconditionerLU.hpp b/resolve/PreconditionerLU.hpp new file mode 100644 index 00000000..14ede135 --- /dev/null +++ b/resolve/PreconditionerLU.hpp @@ -0,0 +1,41 @@ +/** + * @file PreconditionerILU0.cpp + * @author Kakeru Ueda (k.ueda.2290@m.isct.ac.jp) + * @brief Declaration of preconditioner ILU0 class. + * + */ + +#include + +namespace ReSolve +{ + // Forward declaration of workspace + class LinSolverDirect; + + namespace matrix + { + class Sparse; + } // namespace matrix + + namespace vector + { + class Vector; + } // namespace vector + + class PreconditionerLU : public Preconditioner + { + public: + using vector_type = vector::Vector; + using matrix_type = matrix::Sparse; + + PreconditionerLU(LinSolverDirect* solver); + ~PreconditionerLU(); + + int setup(matrix_type* A) override; + int apply(vector_type* rhs, vector_type* x) override; + int reset(matrix_type* A) override; + + private: + LinSolverDirect* solver_{nullptr}; + }; +} // namespace ReSolve diff --git a/resolve/SystemSolver.cpp b/resolve/SystemSolver.cpp index 884c3fbd..ecf927c2 100644 --- a/resolve/SystemSolver.cpp +++ b/resolve/SystemSolver.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -170,13 +171,14 @@ namespace ReSolve if (precondition_method_ != "none") { delete preconditioner_; + delete preconditionSolver_; } delete matrixHandler_; delete vectorHandler_; } - int SystemSolver::setMatrix(matrix::Sparse* A) + int SystemSolver::setMatrix(matrix_type* A) { int status = 0; A_ = A; @@ -229,6 +231,11 @@ namespace ReSolve delete refactorizationSolver_; refactorizationSolver_ = nullptr; } + if (preconditionSolver_) + { + delete preconditionSolver_; + preconditionSolver_ = nullptr; + } if (preconditioner_) { delete preconditioner_; @@ -312,19 +319,21 @@ namespace ReSolve { if (memspace_ == "cpu") { - // preconditioner_ = new LinSolverDirectSerialILU0(workspaceCpu_); - preconditioner_ = new LinSolverDirectCpuILU0(workspaceCpu_); + preconditionSolver_ = new LinSolverDirectSerialILU0(workspaceCpu_); + preconditioner_ = new PreconditionerLU(preconditionSolver_); #ifdef RESOLVE_USE_CUDA } else if (memspace_ == "cuda") { - preconditioner_ = new LinSolverDirectCuSparseILU0(workspaceCuda_); + preconditionSolver_ = new LinSolverDirectCuSparseILU0(workspaceCuda_); + preconditioner_ = new PreconditionerLU(preconditionSolver_); #endif #ifdef RESOLVE_USE_HIP } else if (memspace_ == "hip") { - preconditioner_ = new LinSolverDirectRocSparseILU0(workspaceHip_); + preconditionSolver_ = new LinSolverDirectRocSparseILU0(workspaceHip_); + preconditioner_ = new PreconditionerLU(preconditionSolver_); #endif } else @@ -482,9 +491,16 @@ namespace ReSolve if (irMethod_ == "fgmres") { status += iterativeSolver_->setup(A_); - status += iterativeSolver_->setupPreconditioner("LU", refactorizationSolver_); - } + if (preconditioner_) + { + delete preconditioner_; + preconditioner_ = nullptr; + } + + preconditioner_ = new PreconditionerLU(refactorizationSolver_); + status += iterativeSolver_->setPreconditioner(preconditioner_); + } return status; } @@ -540,6 +556,13 @@ namespace ReSolve return status; } + /** + * @brief Sets up the preconditioner for the system solver + * + * Initializes and attaches the preconditioner to the iterative solver. + * + * @return int 0 if successful, 1 if it fails + */ int SystemSolver::preconditionerSetup() { int status = 0; @@ -550,7 +573,28 @@ namespace ReSolve { is_solve_on_device_ = true; } - iterativeSolver_->setupPreconditioner("LU", preconditioner_); + status += iterativeSolver_->setPreconditioner(preconditioner_); + } + + return status; + } + + /** + * @brief Reset the preconditioner with a new matrix. + * + * Assumes the matrix sparsity pattern does not change. + * + * @param[in] A New sparse matrix (values updated). + * + * @return int 0 if successful, 1 if it fails + */ + int SystemSolver::resetPreconditioner(matrix_type* A) + { + int status = 0; + A_ = A; + if (precondition_method_ == "ilu0") + { + status += preconditioner_->reset(A); } return status; diff --git a/resolve/SystemSolver.hpp b/resolve/SystemSolver.hpp index 0aa5173c..ca0c313d 100644 --- a/resolve/SystemSolver.hpp +++ b/resolve/SystemSolver.hpp @@ -10,6 +10,7 @@ namespace ReSolve class LinAlgWorkspaceCpu; class MatrixHandler; class VectorHandler; + class Preconditioner; namespace vector { @@ -27,9 +28,6 @@ namespace ReSolve using vector_type = vector::Vector; using matrix_type = matrix::Sparse; - /// @brief Temporary until abstract preconditioner class is created - using precond_type = LinSolverDirect; - SystemSolver(LinAlgWorkspaceCpu* workspaceCpu, std::string factor = "klu", std::string refactor = "klu", @@ -52,12 +50,13 @@ namespace ReSolve ~SystemSolver(); int initialize(); - int setMatrix(matrix::Sparse* A); + int setMatrix(matrix_type* A); int analyze(); // symbolic part int factorize(); // numeric part int refactorize(); int refactorizationSetup(); int preconditionerSetup(); + int resetPreconditioner(matrix_type* A); int solve(vector_type* rhs, vector_type* x); // for direct and iterative int refine(vector_type* rhs, vector_type* x); // for iterative refinement @@ -90,10 +89,10 @@ namespace ReSolve private: LinSolverDirect* factorizationSolver_{nullptr}; LinSolverDirect* refactorizationSolver_{nullptr}; + LinSolverDirect* preconditionSolver_{nullptr}; LinSolverIterative* iterativeSolver_{nullptr}; GramSchmidt* gs_{nullptr}; - - precond_type* preconditioner_{nullptr}; + Preconditioner* preconditioner_{nullptr}; LinAlgWorkspaceCUDA* workspaceCuda_{nullptr}; LinAlgWorkspaceHIP* workspaceHip_{nullptr}; diff --git a/tests/functionality/testKlu.cpp b/tests/functionality/testKlu.cpp index 9bb6933b..9ac09c3f 100644 --- a/tests/functionality/testKlu.cpp +++ b/tests/functionality/testKlu.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -150,7 +151,8 @@ int runTest(int argc, char* argv[], std::string& solver_name) status = FGMRES.setup(A); error_sum += status; - status = FGMRES.setupPreconditioner("LU", &KLU); + ReSolve::PreconditionerLU precond_lu(&KLU); + status = FGMRES.setPreconditioner(&precond_lu); error_sum += status; status = FGMRES.solve(&vec_rhs, &vec_x); error_sum += status; @@ -197,7 +199,8 @@ int runTest(int argc, char* argv[], std::string& solver_name) if (is_ir) { FGMRES.resetMatrix(A); - status = FGMRES.setupPreconditioner("LU", &KLU); + ReSolve::PreconditionerLU precond_lu(&KLU); + status = FGMRES.setPreconditioner(&precond_lu); error_sum += status; status = FGMRES.solve(&vec_rhs, &vec_x); error_sum += status; diff --git a/tests/functionality/testRandGmres.cpp b/tests/functionality/testRandGmres.cpp index 35ed2e9c..ca9adb46 100644 --- a/tests/functionality/testRandGmres.cpp +++ b/tests/functionality/testRandGmres.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -132,7 +133,8 @@ int runTest(int argc, char* argv[]) FGMRES.setRestart(200); FGMRES.setSketchingMethod(LinSolverIterativeRandFGMRES::cs); - status = FGMRES.setupPreconditioner("LU", &ILU); + PreconditionerLU precond_lu(&ILU); + status = FGMRES.setPreconditioner(&precond_lu); error_sum += status; FGMRES.setFlexible(true); diff --git a/tests/functionality/testRefactor.cpp b/tests/functionality/testRefactor.cpp index aff7c2de..c3e87c28 100644 --- a/tests/functionality/testRefactor.cpp +++ b/tests/functionality/testRefactor.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -190,7 +191,8 @@ int runTest(int argc, char* argv[], std::string& solver_name) status = FGMRES.setup(A); error_sum += status; - status = FGMRES.setupPreconditioner("LU", &Rf); + ReSolve::PreconditionerLU precond_lu(&Rf); + status = FGMRES.setPreconditioner(&precond_lu); error_sum += status; status = FGMRES.solve(&vec_rhs, &vec_x); @@ -239,7 +241,8 @@ int runTest(int argc, char* argv[], std::string& solver_name) if (is_ir) { FGMRES.resetMatrix(A); - status = FGMRES.setupPreconditioner("LU", &Rf); + ReSolve::PreconditionerLU precond_lu(&Rf); + status = FGMRES.setPreconditioner(&precond_lu); error_sum += status; status = FGMRES.solve(&vec_rhs, &vec_x);