1 /* _______________________________________________________________________ 2 3 PECOS: Parallel Environment for Creation Of Stochastics 4 Copyright (c) 2011, Sandia National Laboratories. 5 This software is distributed under the GNU Lesser General Public License. 6 For more information, see the README file in the top Pecos directory. 7 _______________________________________________________________________ */ 8 9 #include "linear_solvers.hpp" 10 11 namespace Pecos { 12 namespace util { 13 regression_solver_factory(OptionsList & opts)14std::shared_ptr<LinearSystemSolver> regression_solver_factory(OptionsList &opts){ 15 std::string name = "regression_type"; 16 RegressionType regression_type = 17 get_enum_enforce_existance<RegressionType>(opts, name); 18 19 bool use_cross_validation = opts.get("use-cross-validation", false); 20 if (use_cross_validation){ 21 std::shared_ptr<CrossValidatedSolver> cv_solver(new CrossValidatedSolver); 22 cv_solver->set_linear_system_solver(regression_type); 23 return cv_solver; 24 } 25 26 switch (regression_type){ 27 case ORTHOG_MATCH_PURSUIT : { 28 std::shared_ptr<OMPSolver> omp_solver(new OMPSolver); 29 return omp_solver; 30 } 31 case LEAST_ANGLE_REGRESSION : { 32 std::shared_ptr<LARSolver> lars_solver(new LARSolver); 33 lars_solver->set_sub_solver(LEAST_ANGLE_REGRESSION); 34 return lars_solver; 35 } 36 case LASSO_REGRESSION : { 37 std::shared_ptr<LARSolver> lars_solver(new LARSolver); 38 lars_solver->set_sub_solver(LASSO_REGRESSION); 39 return lars_solver; 40 } 41 case EQ_CONS_LEAST_SQ_REGRESSION : { 42 std::shared_ptr<EqConstrainedLSQSolver> 43 eqlsq_solver(new EqConstrainedLSQSolver); 44 return eqlsq_solver; 45 } 46 case SVD_LEAST_SQ_REGRESSION: case LU_LEAST_SQ_REGRESSION: 47 case QR_LEAST_SQ_REGRESSION: 48 { 49 //\todo add set_lsq_solver to lsqsolver class so we can switch 50 // between svd, qr and lu factorization methods 51 std::shared_ptr<LSQSolver> lsq_solver(new LSQSolver); 52 return lsq_solver; 53 } 54 default: { 55 throw(std::runtime_error("Incorrect \"regression-type\"")); 56 } 57 } 58 } 59 cast_linear_system_solver_to_ompsolver(std::shared_ptr<LinearSystemSolver> & solver)60std::shared_ptr<OMPSolver> cast_linear_system_solver_to_ompsolver(std::shared_ptr<LinearSystemSolver> &solver){ 61 std::shared_ptr<OMPSolver> solver_cast = 62 std::dynamic_pointer_cast<OMPSolver>(solver); 63 if (!solver_cast) 64 throw(std::runtime_error("Could not cast to OMPSolver shared_ptr")); 65 return solver_cast; 66 } 67 cast_linear_system_solver_to_larssolver(std::shared_ptr<LinearSystemSolver> & solver)68std::shared_ptr<LARSolver> cast_linear_system_solver_to_larssolver(std::shared_ptr<LinearSystemSolver> &solver){ 69 std::shared_ptr<LARSolver> solver_cast = 70 std::dynamic_pointer_cast<LARSolver>(solver); 71 if (!solver_cast) 72 throw(std::runtime_error("Could not cast to LARSolver shared_ptr")); 73 return solver_cast; 74 } 75 cast_linear_system_solver_to_lsqsolver(std::shared_ptr<LinearSystemSolver> & solver)76std::shared_ptr<LSQSolver> cast_linear_system_solver_to_lsqsolver(std::shared_ptr<LinearSystemSolver> &solver){ 77 std::shared_ptr<LSQSolver> solver_cast = 78 std::dynamic_pointer_cast<LSQSolver>(solver); 79 if (!solver_cast) 80 throw(std::runtime_error("Could not cast to LSQSolver shared_ptr")); 81 return solver_cast; 82 } 83 cast_linear_system_solver_to_eqconstrainedlsqsolver(std::shared_ptr<LinearSystemSolver> & solver)84std::shared_ptr<EqConstrainedLSQSolver> cast_linear_system_solver_to_eqconstrainedlsqsolver(std::shared_ptr<LinearSystemSolver> &solver){ 85 std::shared_ptr<EqConstrainedLSQSolver> solver_cast = 86 std::dynamic_pointer_cast<EqConstrainedLSQSolver>(solver); 87 if (!solver_cast) 88 throw(std::runtime_error("Could not cast to EqConstrainedLSQSolver shared_ptr")); 89 return solver_cast; 90 } 91 92 } // namespace util 93 } // namespace Pecos 94