1 /* _______________________________________________________________________ 2 3 DAKOTA: Design Analysis Kit for Optimization and Terascale Applications 4 Copyright 2014-2020 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 5 This software is distributed under the GNU Lesser General Public License. 6 For more information, see the README file in the top Dakota directory. 7 _______________________________________________________________________ */ 8 9 #ifndef DAKOTA_UTIL_LINEAR_SOLVERS_HPP 10 #define DAKOTA_UTIL_LINEAR_SOLVERS_HPP 11 12 #include "util_data_types.hpp" 13 14 #include <memory> 15 16 namespace dakota { 17 namespace util { 18 19 // -------------------------------------------------------------------------------- 20 21 /** 22 * \brief The LinearSolverBase class serves as an API for derived solvers. 23 */ 24 class LinearSolverBase 25 { 26 public: 27 28 /// How best to Doxygenate class enums? RWH 29 enum class SOLVER_TYPE { CHOLESKY, 30 EQ_CONS_LEAST_SQ_REGRESSION, 31 LASSO_REGRESSION, 32 LEAST_ANGLE_REGRESSION, 33 LU, 34 ORTHOG_MATCH_PURSUIT, 35 QR_LEAST_SQ_REGRESSION, 36 SVD_LEAST_SQ_REGRESSION 37 }; 38 39 /// Constructor 40 LinearSolverBase(); 41 42 /// Destructor 43 ~LinearSolverBase(); 44 45 /** 46 * \brief Convert solver name to enum type 47 * \param[in] solver_name LinearSolverBase name to map 48 * \returns Corresponding LinearSolverBase enum 49 */ 50 static SOLVER_TYPE solver_type(const std::string& solver_name); 51 52 /** 53 * \brief Query to determine if the matrix of the solver has been factored. 54 */ 55 virtual bool is_factorized() const; 56 57 /** 58 * \brief Perform the matrix factorization for the linear solver matrix. 59 * 60 * \param[in] A The incoming matrix to factorize. 61 */ 62 virtual void factorize(const MatrixXd &A); 63 64 /** 65 * \brief Find a solution to linear problem. 66 * 67 * \param[in] A The linear system left-hand-side matrix. 68 * \param[in] b The linear system right-hand-side (multi-)vector. 69 * \param[in] x The linear system solution (multi-)vector. 70 */ 71 virtual void solve(const MatrixXd &A, const MatrixXd &b, MatrixXd &x); 72 73 /** 74 * \brief Find a solution to linear problem where the LHS is already factorized. 75 * 76 * \param[in] b The linear system right-hand-side (multi-)vector. 77 * \param[in] x The linear system solution (multi-)vector. 78 */ 79 virtual void solve(const MatrixXd &b, MatrixXd &x); 80 }; 81 82 /** 83 * \brief Free function to construct LinearSolverBase 84 * 85 * \param[in] type Which solver to construct 86 * \returns Shared pointer to a LinearSolverBase 87 */ 88 std::shared_ptr<LinearSolverBase> solver_factory(LinearSolverBase::SOLVER_TYPE type); 89 90 // -------------------------------------------------------------------------------- 91 92 /** 93 * \brief The LUSolver class is used to solve linear systems with the 94 * LU decomposition. 95 */ 96 class LUSolver : public LinearSolverBase 97 { 98 public: 99 100 /// Constructor 101 LUSolver(); 102 103 /// Destructor 104 ~LUSolver(); 105 106 /** 107 * \brief Query to determine if the matrix of the solver has been factored. 108 */ 109 bool is_factorized() const override; 110 111 /** 112 * \brief Perform the matrix factorization for the linear solver matrix. 113 * 114 * \param[in] A The incoming matrix to factorize. 115 */ 116 void factorize(const MatrixXd &A) override; 117 118 /** 119 * \brief Find the solution to Ax = b. 120 * 121 * \param[in] A The linear system left-hand-side matrix. 122 * \param[in] b The linear system right-hand-side (multi-)vector. 123 * \param[in] x The linear system solution (multi-)vector. 124 */ 125 void solve(const MatrixXd &A, const MatrixXd &b, MatrixXd &x) override; 126 127 /** 128 * \brief Find the solution to Ax = b when A is already factorized. 129 * 130 * \param[in] b The linear system right-hand-side (multi-)vector. 131 * \param[in] x The linear system solution (multi-)vector. 132 */ 133 void solve(const MatrixXd &b, MatrixXd &x) override; 134 135 private: 136 137 std::shared_ptr<Eigen::FullPivLU<MatrixXd>> LU_Ptr; 138 }; 139 140 // -------------------------------------------------------------------------------- 141 142 /** 143 * \brief The SVDSolver class is used to solve linear systems with the 144 * singular value decomposition. 145 */ 146 class SVDSolver : public LinearSolverBase 147 { 148 public: 149 150 /// Constructor 151 SVDSolver(); 152 153 /// Destructor 154 ~SVDSolver(); 155 156 /** 157 * \brief Query to determine if the matrix of the solver has been factored. 158 */ 159 bool is_factorized() const override; 160 161 /** 162 * \brief Perform the matrix factorization for the linear solver matrix. 163 * 164 * \param[in] A The incoming matrix to factorize. 165 */ 166 void factorize(const MatrixXd &A) override; 167 168 /** 169 * \brief Find a solution to Ax = b. 170 * 171 * \param[in] A The linear system left-hand-side matrix. 172 * \param[in] b The linear system right-hand-side (multi-)vector. 173 * \param[in] x The linear system solution (multi-)vector. 174 */ 175 void solve(const MatrixXd &A, const MatrixXd &b, MatrixXd &x) override; 176 177 /** 178 * \brief Find a solution to Ax = b when A is already factorized. 179 * 180 * \param[in] b The linear system right-hand-side (multi-)vector. 181 * \param[in] x The linear system solution (multi-)vector. 182 */ 183 void solve(const MatrixXd &b, MatrixXd & x) override; 184 185 private: 186 187 std::shared_ptr<Eigen::BDCSVD<MatrixXd>> SVD_Ptr; 188 }; 189 190 // -------------------------------------------------------------------------------- 191 192 /** 193 * \brief The QRSolver class solves the linear least squares problem with a 194 * QR decomposition. 195 */ 196 class QRSolver : public LinearSolverBase 197 { 198 public: 199 200 /// Constructor 201 QRSolver(); 202 203 /// Destructor 204 ~QRSolver(); 205 206 /** 207 * \brief Query to determine if the matrix of the solver has been factored. 208 */ 209 bool is_factorized() const override; 210 211 /** 212 * \brief Perform the matrix factorization for the linear solver matrix. 213 * 214 * \param[in] A The incoming matrix to factorize. 215 */ 216 void factorize(const MatrixXd &A) override; 217 218 /** 219 * \brief Find the solution to (A^T*A)x = A^T*b . 220 * 221 * \param[in] A The matrix for the QR decomposition. 222 * \param[in] b The linear system right-hand-side (multi-)vector. 223 * \param[in] x The linear system solution (multi-)vector. 224 */ 225 void solve(const MatrixXd &A, const MatrixXd &b, MatrixXd &x) override; 226 227 /** 228 * \brief Find a solution to (A^T*A)x = A^T*b when A is already factorized. 229 * 230 * \param[in] b The linear system right-hand-side (multi-)vector. 231 * \param[in] x The linear system solution (multi-)vector. 232 */ 233 void solve(const MatrixXd &b, MatrixXd &x) override; 234 235 private: 236 237 std::shared_ptr<Eigen::ColPivHouseholderQR<MatrixXd>> QR_Ptr; 238 239 }; 240 241 // -------------------------------------------------------------------------------- 242 243 /** 244 * \brief The CholeskySolver class is used to solve linear systems with a 245 * symmetric matrix with a pivoted Cholesky decomposition. 246 */ 247 class CholeskySolver : public LinearSolverBase 248 { 249 public: 250 251 /// Constructor 252 CholeskySolver(); 253 254 /// Destructor 255 ~CholeskySolver(); 256 257 /** 258 * \brief Query to determine if the matrix of the solver has been factored. 259 */ 260 bool is_factorized() const override; 261 262 /** 263 * \brief Perform the matrix factorization for the linear solver matrix. 264 * 265 * \param[in] A The incoming matrix to factorize. 266 */ 267 void factorize(const MatrixXd &A) override; 268 269 /** 270 * \brief Find a solution to Ax = b. 271 * 272 * \param[in] A The linear system left-hand-side matrix. 273 * \param[in] b The linear system right-hand-side (multi-)vector. 274 * \param[in] x The linear system solution (multi-)vector. 275 */ 276 void solve(const MatrixXd &A, const MatrixXd &b, MatrixXd &x) override; 277 278 /** 279 * \brief Find a solution to Ax = b when A is already factorized. 280 * 281 * \param[in] b The linear system right-hand-side (multi-)vector. 282 * \param[in] x The linear system solution (multi-)vector. 283 */ 284 void solve(const MatrixXd &b, MatrixXd &x) override; 285 286 private: 287 288 /// Cached LDL^T factorization 289 std::shared_ptr<Eigen::LDLT<MatrixXd>> LDLT_Ptr; 290 }; 291 292 } // namespace util 293 } // namespace dakota 294 295 #endif // include guard 296