1 /*
2  *    This file is part of CasADi.
3  *
4  *    CasADi -- A symbolic framework for dynamic optimization.
5  *    Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl,
6  *                            K.U. Leuven. All rights reserved.
7  *    Copyright (C) 2011-2014 Greg Horn
8  *
9  *    CasADi is free software; you can redistribute it and/or
10  *    modify it under the terms of the GNU Lesser General Public
11  *    License as published by the Free Software Foundation; either
12  *    version 3 of the License, or (at your option) any later version.
13  *
14  *    CasADi is distributed in the hope that it will be useful,
15  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  *    Lesser General Public License for more details.
18  *
19  *    You should have received a copy of the GNU Lesser General Public
20  *    License along with CasADi; if not, write to the Free Software
21  *    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
22  *
23  */
24 
25 
26 #ifndef CASADI_ROOTFINDER_IMPL_HPP
27 #define CASADI_ROOTFINDER_IMPL_HPP
28 
29 #include "rootfinder.hpp"
30 #include "oracle_function.hpp"
31 #include "plugin_interface.hpp"
32 
33 /// \cond INTERNAL
34 namespace casadi {
35 
36   /** \brief Integrator memory */
37   struct CASADI_EXPORT RootfinderMemory : public OracleMemory {
38     // Inputs
39     const double** iarg;
40 
41     // Outputs
42     double** ires;
43 
44     // Success?
45     bool success;
46 
47     // Return status
48     FunctionInternal::UnifiedReturnStatus unified_return_status;
49   };
50 
51   /// Internal class
52   class CASADI_EXPORT
53   Rootfinder : public OracleFunction, public PluginInterface<Rootfinder> {
54   public:
55     /** \brief Constructor
56      *
57      * \param f   Function mapping from (n+1) inputs to 1 output.
58      */
59     Rootfinder(const std::string& name, const Function& oracle);
60 
61     /// Destructor
62     ~Rootfinder() override = 0;
63 
64     ///@{
65     /** \brief Number of function inputs and outputs */
get_n_in()66     size_t get_n_in() override { return oracle_.n_in();}
get_n_out()67     size_t get_n_out() override { return oracle_.n_out();}
68     ///@}
69 
70     /// @{
71     /** \brief Sparsities of function inputs and outputs */
get_sparsity_in(casadi_int i)72     Sparsity get_sparsity_in(casadi_int i) override { return oracle_.sparsity_in(i);}
get_sparsity_out(casadi_int i)73     Sparsity get_sparsity_out(casadi_int i) override { return oracle_.sparsity_out(i);}
74     /// @}
75 
76     ///@{
77     /** \brief Names of function input and outputs */
get_name_in(casadi_int i)78     std::string get_name_in(casadi_int i) override { return oracle_.name_in(i);}
get_name_out(casadi_int i)79     std::string get_name_out(casadi_int i) override { return oracle_.name_out(i);}
80     /// @}
81 
82     ///@{
83     /** \brief Options */
84     static const Options options_;
get_options() const85     const Options& get_options() const override { return options_;}
86     ///@}
87 
88     /// Initialize
89     void init(const Dict& opts) override;
90 
91     /** \brief Initalize memory block */
92     int init_mem(void* mem) const override;
93 
94     /** \brief Set the (persistent) work vectors */
95     void set_work(void* mem, const double**& arg, double**& res,
96                           casadi_int*& iw, double*& w) const override;
97 
98     // Evaluate numerically
99     int eval(const double** arg, double** res, casadi_int* iw, double* w, void* mem) const override;
100 
101     // Solve the NLP
102     virtual int solve(void* mem) const = 0;
103 
104     /// Get all statistics
105     Dict get_stats(void* mem) const override;
106 
107     /** \brief  Propagate sparsity forward */
108     int sp_forward(const bvec_t** arg, bvec_t** res,
109                     casadi_int* iw, bvec_t* w, void* mem) const override;
110 
111     /** \brief  Propagate sparsity backwards */
112     int sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w, void* mem) const override;
113 
114     ///@{
115     /// Is the class able to propagate seeds through the algorithm?
has_spfwd() const116     bool has_spfwd() const override { return true;}
has_sprev() const117     bool has_sprev() const override { return true;}
118     ///@}
119 
120     /** \brief Do the derivative functions need nondifferentiated outputs? */
uses_output() const121     bool uses_output() const override {return true;}
122 
123     ///@{
124     /** \brief Generate a function that calculates \a nfwd forward derivatives */
has_forward(casadi_int nfwd) const125     bool has_forward(casadi_int nfwd) const override { return true;}
126     Function get_forward(casadi_int nfwd, const std::string& name,
127                          const std::vector<std::string>& inames,
128                          const std::vector<std::string>& onames,
129                          const Dict& opts) const override;
130     ///@}
131 
132     ///@{
133     /** \brief Generate a function that calculates \a nadj adjoint derivatives */
has_reverse(casadi_int nadj) const134     bool has_reverse(casadi_int nadj) const override { return true;}
135     Function get_reverse(casadi_int nadj, const std::string& name,
136                          const std::vector<std::string>& inames,
137                          const std::vector<std::string>& onames,
138                          const Dict& opts) const override;
139     ///@}
140 
141     /** \brief Create call to (cached) derivative function, forward mode  */
142     virtual void ad_forward(const std::vector<MX>& arg, const std::vector<MX>& res,
143                          const std::vector<std::vector<MX> >& fseed,
144                          std::vector<std::vector<MX> >& fsens,
145                          bool always_inline, bool never_inline) const;
146 
147     /** \brief Create call to (cached) derivative function, reverse mode  */
148     virtual void ad_reverse(const std::vector<MX>& arg, const std::vector<MX>& res,
149                          const std::vector<std::vector<MX> >& aseed,
150                          std::vector<std::vector<MX> >& asens,
151                          bool always_inline, bool never_inline) const;
152 
153     /// Number of equations
154     casadi_int n_;
155 
156     /// Linear solver
157     Linsol linsol_;
158     Sparsity sp_jac_;
159 
160     /// Constraints on decision variables
161     std::vector<casadi_int> u_c_;
162 
163     /// Indices of the input and output that correspond to the actual root-finding
164     casadi_int iin_, iout_;
165 
166     /// Throw an exception on failure?
167     bool error_on_fail_;
168 
169     // Creator function for internal class
170     typedef Rootfinder* (*Creator)(const std::string& name, const Function& oracle);
171 
172     // No static functions exposed
173     struct Exposed{ };
174 
175     /// Collection of solvers
176     static std::map<std::string, Plugin> solvers_;
177 
178     /// Short name
shortname()179     static std::string shortname() { return "rootfinder";}
180 
181     /// Infix
182     static const std::string infix_;
183 
184     /// Convert dictionary to Problem
185     template<typename XType>
186       static Function create_oracle(const std::map<std::string, XType>& d,
187                                     const Dict& opts);
188 
189     /** \brief Serialize an object without type information */
190     void serialize_body(SerializingStream &s) const override;
191     /** \brief Serialize type information */
192     void serialize_type(SerializingStream &s) const override;
193 
194     /** \brief Deserialize into MX */
195     static ProtoFunction* deserialize(DeserializingStream& s);
196 
197     /** \brief String used to identify the immediate FunctionInternal subclass */
serialize_base_function() const198     std::string serialize_base_function() const override { return "Rootfinder"; }
199 
200   protected:
201     /** \brief Deserializing constructor */
202     explicit Rootfinder(DeserializingStream& s);
203   };
204 
205 
206 
207 } // namespace casadi
208 /// \endcond
209 
210 #endif // CASADI_ROOTFINDER_IMPL_HPP
211