1 /* 2 * This file is part of CasADi. 3 * 4 * CasADi -- A symbolic framework for dynamic optimization. 5 * Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl, 6 * K.U. Leuven. All rights reserved. 7 * Copyright (C) 2011-2014 Greg Horn 8 * 9 * CasADi is free software; you can redistribute it and/or 10 * modify it under the terms of the GNU Lesser General Public 11 * License as published by the Free Software Foundation; either 12 * version 3 of the License, or (at your option) any later version. 13 * 14 * CasADi is distributed in the hope that it will be useful, 15 * but WITHOUT ANY WARRANTY; without even the implied warranty of 16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 17 * Lesser General Public License for more details. 18 * 19 * You should have received a copy of the GNU Lesser General Public 20 * License along with CasADi; if not, write to the Free Software 21 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 22 * 23 */ 24 25 26 #ifndef CASADI_ROOTFINDER_IMPL_HPP 27 #define CASADI_ROOTFINDER_IMPL_HPP 28 29 #include "rootfinder.hpp" 30 #include "oracle_function.hpp" 31 #include "plugin_interface.hpp" 32 33 /// \cond INTERNAL 34 namespace casadi { 35 36 /** \brief Integrator memory */ 37 struct CASADI_EXPORT RootfinderMemory : public OracleMemory { 38 // Inputs 39 const double** iarg; 40 41 // Outputs 42 double** ires; 43 44 // Success? 45 bool success; 46 47 // Return status 48 FunctionInternal::UnifiedReturnStatus unified_return_status; 49 }; 50 51 /// Internal class 52 class CASADI_EXPORT 53 Rootfinder : public OracleFunction, public PluginInterface<Rootfinder> { 54 public: 55 /** \brief Constructor 56 * 57 * \param f Function mapping from (n+1) inputs to 1 output. 58 */ 59 Rootfinder(const std::string& name, const Function& oracle); 60 61 /// Destructor 62 ~Rootfinder() override = 0; 63 64 ///@{ 65 /** \brief Number of function inputs and outputs */ get_n_in()66 size_t get_n_in() override { return oracle_.n_in();} get_n_out()67 size_t get_n_out() override { return oracle_.n_out();} 68 ///@} 69 70 /// @{ 71 /** \brief Sparsities of function inputs and outputs */ get_sparsity_in(casadi_int i)72 Sparsity get_sparsity_in(casadi_int i) override { return oracle_.sparsity_in(i);} get_sparsity_out(casadi_int i)73 Sparsity get_sparsity_out(casadi_int i) override { return oracle_.sparsity_out(i);} 74 /// @} 75 76 ///@{ 77 /** \brief Names of function input and outputs */ get_name_in(casadi_int i)78 std::string get_name_in(casadi_int i) override { return oracle_.name_in(i);} get_name_out(casadi_int i)79 std::string get_name_out(casadi_int i) override { return oracle_.name_out(i);} 80 /// @} 81 82 ///@{ 83 /** \brief Options */ 84 static const Options options_; get_options() const85 const Options& get_options() const override { return options_;} 86 ///@} 87 88 /// Initialize 89 void init(const Dict& opts) override; 90 91 /** \brief Initalize memory block */ 92 int init_mem(void* mem) const override; 93 94 /** \brief Set the (persistent) work vectors */ 95 void set_work(void* mem, const double**& arg, double**& res, 96 casadi_int*& iw, double*& w) const override; 97 98 // Evaluate numerically 99 int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override; 100 101 // Solve the NLP 102 virtual int solve(void* mem) const = 0; 103 104 /// Get all statistics 105 Dict get_stats(void* mem) const override; 106 107 /** \brief Propagate sparsity forward */ 108 int sp_forward(const bvec_t** arg, bvec_t** res, 109 casadi_int* iw, bvec_t* w, void* mem) const override; 110 111 /** \brief Propagate sparsity backwards */ 112 int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w, void* mem) const override; 113 114 ///@{ 115 /// Is the class able to propagate seeds through the algorithm? has_spfwd() const116 bool has_spfwd() const override { return true;} has_sprev() const117 bool has_sprev() const override { return true;} 118 ///@} 119 120 /** \brief Do the derivative functions need nondifferentiated outputs? */ uses_output() const121 bool uses_output() const override {return true;} 122 123 ///@{ 124 /** \brief Generate a function that calculates \a nfwd forward derivatives */ has_forward(casadi_int nfwd) const125 bool has_forward(casadi_int nfwd) const override { return true;} 126 Function get_forward(casadi_int nfwd, const std::string& name, 127 const std::vector<std::string>& inames, 128 const std::vector<std::string>& onames, 129 const Dict& opts) const override; 130 ///@} 131 132 ///@{ 133 /** \brief Generate a function that calculates \a nadj adjoint derivatives */ has_reverse(casadi_int nadj) const134 bool has_reverse(casadi_int nadj) const override { return true;} 135 Function get_reverse(casadi_int nadj, const std::string& name, 136 const std::vector<std::string>& inames, 137 const std::vector<std::string>& onames, 138 const Dict& opts) const override; 139 ///@} 140 141 /** \brief Create call to (cached) derivative function, forward mode */ 142 virtual void ad_forward(const std::vector<MX>& arg, const std::vector<MX>& res, 143 const std::vector<std::vector<MX> >& fseed, 144 std::vector<std::vector<MX> >& fsens, 145 bool always_inline, bool never_inline) const; 146 147 /** \brief Create call to (cached) derivative function, reverse mode */ 148 virtual void ad_reverse(const std::vector<MX>& arg, const std::vector<MX>& res, 149 const std::vector<std::vector<MX> >& aseed, 150 std::vector<std::vector<MX> >& asens, 151 bool always_inline, bool never_inline) const; 152 153 /// Number of equations 154 casadi_int n_; 155 156 /// Linear solver 157 Linsol linsol_; 158 Sparsity sp_jac_; 159 160 /// Constraints on decision variables 161 std::vector<casadi_int> u_c_; 162 163 /// Indices of the input and output that correspond to the actual root-finding 164 casadi_int iin_, iout_; 165 166 /// Throw an exception on failure? 167 bool error_on_fail_; 168 169 // Creator function for internal class 170 typedef Rootfinder* (*Creator)(const std::string& name, const Function& oracle); 171 172 // No static functions exposed 173 struct Exposed{ }; 174 175 /// Collection of solvers 176 static std::map<std::string, Plugin> solvers_; 177 178 /// Short name shortname()179 static std::string shortname() { return "rootfinder";} 180 181 /// Infix 182 static const std::string infix_; 183 184 /// Convert dictionary to Problem 185 template<typename XType> 186 static Function create_oracle(const std::map<std::string, XType>& d, 187 const Dict& opts); 188 189 /** \brief Serialize an object without type information */ 190 void serialize_body(SerializingStream &s) const override; 191 /** \brief Serialize type information */ 192 void serialize_type(SerializingStream &s) const override; 193 194 /** \brief Deserialize into MX */ 195 static ProtoFunction* deserialize(DeserializingStream& s); 196 197 /** \brief String used to identify the immediate FunctionInternal subclass */ serialize_base_function() const198 std::string serialize_base_function() const override { return "Rootfinder"; } 199 200 protected: 201 /** \brief Deserializing constructor */ 202 explicit Rootfinder(DeserializingStream& s); 203 }; 204 205 206 207 } // namespace casadi 208 /// \endcond 209 210 #endif // CASADI_ROOTFINDER_IMPL_HPP 211