1 /*  _______________________________________________________________________
2 
3     PECOS: Parallel Environment for Creation Of Stochastics
4     Copyright (c) 2011, Sandia National Laboratories.
5     This software is distributed under the GNU Lesser General Public License.
6     For more information, see the README file in the top Pecos directory.
7     _______________________________________________________________________ */
8 
9 #include "linear_solvers.hpp"
10 
11 namespace Pecos {
12 namespace util {
13 
regression_solver_factory(OptionsList & opts)14 std::shared_ptr<LinearSystemSolver> regression_solver_factory(OptionsList &opts){
15   std::string name = "regression_type";
16   RegressionType regression_type =
17     get_enum_enforce_existance<RegressionType>(opts, name);
18 
19   bool use_cross_validation = opts.get("use-cross-validation", false);
20   if (use_cross_validation){
21     std::shared_ptr<CrossValidatedSolver> cv_solver(new CrossValidatedSolver);
22     cv_solver->set_linear_system_solver(regression_type);
23     return cv_solver;
24   }
25 
26   switch (regression_type){
27   case ORTHOG_MATCH_PURSUIT : {
28     std::shared_ptr<OMPSolver> omp_solver(new OMPSolver);
29       return omp_solver;
30   }
31   case LEAST_ANGLE_REGRESSION : {
32     std::shared_ptr<LARSolver> lars_solver(new LARSolver);
33     lars_solver->set_sub_solver(LEAST_ANGLE_REGRESSION);
34     return lars_solver;
35   }
36   case LASSO_REGRESSION : {
37     std::shared_ptr<LARSolver> lars_solver(new LARSolver);
38     lars_solver->set_sub_solver(LASSO_REGRESSION);
39     return lars_solver;
40   }
41   case EQ_CONS_LEAST_SQ_REGRESSION : {
42     std::shared_ptr<EqConstrainedLSQSolver>
43       eqlsq_solver(new EqConstrainedLSQSolver);
44     return eqlsq_solver;
45   }
46   case SVD_LEAST_SQ_REGRESSION: case LU_LEAST_SQ_REGRESSION:
47   case QR_LEAST_SQ_REGRESSION:
48   {
49     //\todo add set_lsq_solver to lsqsolver class so we can switch
50     // between svd, qr and lu factorization methods
51     std::shared_ptr<LSQSolver> lsq_solver(new LSQSolver);
52     return lsq_solver;
53   }
54   default: {
55     throw(std::runtime_error("Incorrect \"regression-type\""));
56   }
57   }
58 }
59 
cast_linear_system_solver_to_ompsolver(std::shared_ptr<LinearSystemSolver> & solver)60 std::shared_ptr<OMPSolver> cast_linear_system_solver_to_ompsolver(std::shared_ptr<LinearSystemSolver> &solver){
61   std::shared_ptr<OMPSolver> solver_cast =
62     std::dynamic_pointer_cast<OMPSolver>(solver);
63   if (!solver_cast)
64     throw(std::runtime_error("Could not cast to OMPSolver shared_ptr"));
65   return solver_cast;
66 }
67 
cast_linear_system_solver_to_larssolver(std::shared_ptr<LinearSystemSolver> & solver)68 std::shared_ptr<LARSolver> cast_linear_system_solver_to_larssolver(std::shared_ptr<LinearSystemSolver> &solver){
69   std::shared_ptr<LARSolver> solver_cast =
70     std::dynamic_pointer_cast<LARSolver>(solver);
71   if (!solver_cast)
72     throw(std::runtime_error("Could not cast to LARSolver shared_ptr"));
73   return solver_cast;
74 }
75 
cast_linear_system_solver_to_lsqsolver(std::shared_ptr<LinearSystemSolver> & solver)76 std::shared_ptr<LSQSolver> cast_linear_system_solver_to_lsqsolver(std::shared_ptr<LinearSystemSolver> &solver){
77   std::shared_ptr<LSQSolver> solver_cast =
78     std::dynamic_pointer_cast<LSQSolver>(solver);
79   if (!solver_cast)
80     throw(std::runtime_error("Could not cast to LSQSolver shared_ptr"));
81   return solver_cast;
82 }
83 
cast_linear_system_solver_to_eqconstrainedlsqsolver(std::shared_ptr<LinearSystemSolver> & solver)84 std::shared_ptr<EqConstrainedLSQSolver> cast_linear_system_solver_to_eqconstrainedlsqsolver(std::shared_ptr<LinearSystemSolver> &solver){
85   std::shared_ptr<EqConstrainedLSQSolver> solver_cast =
86     std::dynamic_pointer_cast<EqConstrainedLSQSolver>(solver);
87   if (!solver_cast)
88     throw(std::runtime_error("Could not cast to EqConstrainedLSQSolver shared_ptr"));
89   return solver_cast;
90 }
91 
92 }  // namespace util
93 }  // namespace Pecos
94