1 #ifndef ODEBASE_H_ 2 #define ODEBASE_H_ 3 4 #include <boost/property_tree/ptree.hpp> 5 6 #include "MUQ/Modeling/ModPiece.h" 7 #include "MUQ/Modeling/ODEData.h" 8 9 namespace muq { 10 namespace Modeling { 11 12 /// A bass class to integrate ODE's 13 class ODEBase : public ModPiece { 14 public: 15 16 /** 17 @param[in] rhs The right hand side of the ODE 18 @param[in] inputSizes The input input sizes 19 @param[in] outputSizes The output sizes 20 @param[in] pt A boost::property_tree::ptree with options/tolerances for the ODE integrator 21 */ 22 #if MUQ_HAS_PARCER==1 23 ODEBase(std::shared_ptr<ModPiece> const& rhs, Eigen::VectorXi const& inputSizes, Eigen::VectorXi const& outputSizes, boost::property_tree::ptree const& pt, std::shared_ptr<parcer::Communicator> const& comm = nullptr); 24 #else 25 ODEBase(std::shared_ptr<ModPiece> const& rhs, Eigen::VectorXi const& inputSizes, Eigen::VectorXi const& outputSizes, boost::property_tree::ptree const& pt); 26 #endif 27 28 virtual ~ODEBase(); 29 30 protected: 31 32 /// Are we computing the Jacobian, the action of the Jacobian, or the action of the Jacobian transpose 33 enum DerivativeMode { 34 /// The Jacobian 35 Jac, 36 /// The action of the Jacobian 37 JacAction, 38 /// The action of the Jacobian transpose 39 JacTransAction 40 }; 41 42 /// Check the return flag of a Sundials function 43 /** 44 @param[in] flagvalue The value of the Sundials flag 45 @param[in] funcname The name of the Sundials function 46 @param[in] opt An option to determine how to check the flag, 0: check if flag is nullptr, 1: flag is an int, check if flag<0 (indicates Sundials error) 47 \return false: failure, true: success 48 */ 49 bool CheckFlag(void* flagvalue, std::string const& funcname, unsigned int const opt) const; 50 51 /// Alloc memory and set up options for the Sundials solver 52 /** 53 @param[in] cvode_mem The Sundials solver 54 @param[in] state The initial state 55 @param[in] data An object that holds the RHS inputs and can evaluate the RHS 56 */ 57 void CreateSolverMemory(void* cvode_mem, N_Vector const& state, std::shared_ptr<ODEData> data) const; 58 59 int CreateSolverMemoryB(void* cvode_mem, double const timeFinal, N_Vector const& lambda, N_Vector const& nvGrad, std::shared_ptr<ODEData> data) const; 60 61 /// Deal with Sundials errors 62 /** 63 Sundials will call this function if it runs into a problem 64 @param[in] error_code Sundials error code 65 @param[in] module The name of the CVODES module reporting the error 66 @param[in] function The name of the function in which the error occured 67 @param[in] msg The error message 68 @param[in] user_data A pointer to an muq::Modeling::ODEData 69 */ 70 static void ErrorHandler(int error_code, const char *module, const char *function, char *msg, void *user_data); 71 72 /// Evaluate the right hand side 73 /** 74 @param[in] time The current time 75 @param[in] state The current state 76 @param[out] deriv The derivative of the state with respect to time 77 @param[in] user_data A pointer to an muq::Modeling::ODEData 78 */ 79 static int EvaluateRHS(realtype time, N_Vector state, N_Vector deriv, void *user_data); 80 81 static int AdjointRHS(realtype time, N_Vector state, N_Vector lambda, N_Vector deriv, void *user_data); 82 83 /// Evaluate the Jacobian of the right hand side 84 /** 85 @param[in] N 86 @param[in] time The current time 87 @param[in] state The current state 88 @param[in] rhs The derivative of the state with respect to time 89 @param[out] jac The Jacobian of the right hand side with respect to the current state 90 @param[in] user_data A pointer to an muq::Modeling::ODEData 91 @param[in] tmp1 92 @param[in] tmp2 93 @param[in] tmp3 94 */ 95 static int RHSJacobian(long int N, realtype time, N_Vector state, N_Vector rhs, DlsMat jac, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3); 96 97 static int AdjointJacobian(long int N, realtype time, N_Vector state, N_Vector lambda, N_Vector rhs, DlsMat jac, void *user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3); 98 99 /// Evaluate the action of the Jacobian of the right hand side 100 /** 101 @param[in] v The vector the Jacobian is acting on 102 @param[out] Jv The action of the Jacobian on v 103 @param[in] time The current time 104 @param[in] state The current state 105 @param[in] rhs The derivative of the state with respect to time 106 @param[in] user_data A pointer to an muq::Modeling::ODEData 107 @param[in] tmp 108 */ 109 static int RHSJacobianAction(N_Vector v, N_Vector Jv, realtype time, N_Vector state, N_Vector rhs, void *user_data, N_Vector tmp); 110 111 static int AdjointJacobianAction(N_Vector target, N_Vector output, realtype time, N_Vector state, N_Vector lambda, N_Vector adjRhs, void *user_data, N_Vector tmp); 112 113 /// Sundials uses this function to compute the derivative of the state at each timestep 114 /** 115 @param[in] Ns The number of sensitivities 116 @param[in] time The current time 117 @param[in] y The current state 118 @param[in] ydot The derivative of the current state with respect to time 119 @param[in] ys 120 @param[in] ySdot The sensitivties 121 @param[in] user_data A pointer to an muq::Modeling::ODEData 122 */ 123 static int ForwardSensitivityRHS(int Ns, realtype time, N_Vector y, N_Vector ydot, N_Vector *ys, N_Vector *ySdot, void *user_data, N_Vector tmp1, N_Vector tmp2); 124 125 static int AdjointQuad(realtype time, N_Vector state, N_Vector lambda, N_Vector quadRhs, void *user_data); 126 127 /// Set up the solver for sensitivity information 128 /** 129 @param[in] cvode_mem The Sundials solver 130 @param[in] paramSize The size of the input parameter we are differenating wrt 131 @param[in,out] sensState This will become the 'current' Jacobian 132 */ 133 void SetUpSensitivity(void *cvode_mem, unsigned int const paramSize, N_Vector *sensState) const; 134 135 /// Which linear solver should we use? 136 enum LinearSolver { 137 /// Dense solver 138 Dense, 139 /// SPGMR 140 SPGMR, 141 /// SPBCG 142 SPBCG, 143 /// SPTFQMR 144 SPTFQMR 145 }; 146 147 /// The right hand side of the ODE 148 std::shared_ptr<ModPiece> rhs; 149 150 /// Linear solver method 151 LinearSolver slvr; 152 153 /// The relative tolerance 154 const double reltol; 155 156 /// The absolute tolerance 157 const double abstol; 158 159 /// The maximum time step size 160 const double maxStepSize; 161 162 /// The maximum number of time steps 163 const unsigned int maxNumSteps; 164 165 /// Multistep method 166 int multiStep; 167 168 /// Nonlinear solver method 169 int solveMethod; 170 171 /// Is the RHS autonomous? 172 const bool autonomous; 173 174 /// Check point gap 175 const unsigned int checkPtGap; 176 177 #if MUQ_HAS_PARCER==1 178 /// The global size of the state vector 179 const unsigned int globalSize = std::numeric_limits<unsigned int>::quiet_NaN(); 180 181 std::shared_ptr<parcer::Communicator> comm = nullptr; 182 #endif 183 184 private: 185 186 }; 187 188 } // namespace Modeling 189 } // namespace muq 190 191 #endif 192