1 /*
2  *    This file is part of CasADi.
3  *
4  *    CasADi -- A symbolic framework for dynamic optimization.
5  *    Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl,
6  *                            K.U. Leuven. All rights reserved.
7  *    Copyright (C) 2011-2014 Greg Horn
8  *
9  *    CasADi is free software; you can redistribute it and/or
10  *    modify it under the terms of the GNU Lesser General Public
11  *    License as published by the Free Software Foundation; either
12  *    version 3 of the License, or (at your option) any later version.
13  *
14  *    CasADi is distributed in the hope that it will be useful,
15  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  *    Lesser General Public License for more details.
18  *
19  *    You should have received a copy of the GNU Lesser General Public
20  *    License along with CasADi; if not, write to the Free Software
21  *    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
22  *
23  */
24 
25 
26 #include "dple_impl.hpp"
27 #include <typeinfo>
28 
29 using namespace std;
30 namespace casadi {
31 
has_dple(const string & name)32   bool has_dple(const string& name) {
33     return Dple::has_plugin(name);
34   }
35 
load_dple(const string & name)36   void load_dple(const string& name) {
37     Dple::load_plugin(name);
38   }
39 
doc_dple(const string & name)40   string doc_dple(const string& name) {
41     return Dple::getPlugin(name).doc;
42   }
43 
dplesol(const MX & A,const MX & V,const std::string & solver,const Dict & opts)44   MX dplesol(const MX& A, const MX& V, const std::string& solver, const Dict& opts) {
45     SpDict sp;
46     sp["a"] = A.sparsity();
47     sp["v"] = V.sparsity();
48     Function f = dplesol("dplesol", solver, sp, opts);
49     MXDict f_in;
50     f_in["a"] = A;
51     f_in["v"] = V;
52     MXDict f_out = f(f_in);
53     return f_out["p"];
54   }
55 
dplesol(const MXVector & A,const MXVector & V,const std::string & solver,const Dict & opts)56   CASADI_EXPORT MXVector dplesol(const MXVector& A, const MXVector& V, const std::string& solver,
57     const Dict& opts) {
58       casadi_assert(A.size()==V.size(),
59         "dplesol: sizes of A vector (" + str(A.size()) + ") and V vector "
60         "(" + str(V.size()) + ") must match.");
61       std::vector<MX> Adense, Vdense;
62 
63       for (casadi_int i=0;i<A.size();++i) {
64         Adense.push_back(densify(A[i]));
65         Vdense.push_back(densify(V[i]));
66       }
67 
68       MX ret = dplesol(diagcat(Adense), diagcat(Vdense), solver, opts);
69       return diagsplit(ret, ret.size1()/A.size());
70   }
71 
dplesol(const DMVector & A,const DMVector & V,const std::string & solver,const Dict & opts)72   CASADI_EXPORT DMVector dplesol(const DMVector& A, const DMVector& V, const std::string& solver,
73     const Dict& opts) {
74       casadi_assert(A.size()==V.size(),
75         "dplesol: sizes of A vector (" + str(A.size()) + ") and V vector "
76         "(" + str(V.size()) + ") must match.");
77       std::vector<DM> Adense, Vdense;
78 
79       for (casadi_int i=0;i<A.size();++i) {
80         Adense.push_back(densify(A[i]));
81         Vdense.push_back(densify(V[i]));
82       }
83 
84       DM Afull = diagcat(Adense);
85       DM Vfull = diagcat(Vdense);
86 
87       SpDict sp;
88       sp["a"] = Afull.sparsity();
89       sp["v"] = Vfull.sparsity();
90       Function f = dplesol("dplesol", solver, sp, opts);
91       DMDict f_in;
92       f_in["a"] = Afull;
93       f_in["v"] = Vfull;
94       DMDict f_out = f(f_in);
95       return diagsplit(f_out["p"], f_out["p"].size1()/A.size());
96   }
97 
dplesol(const string & name,const string & solver,const SpDict & st,const Dict & opts)98   Function dplesol(const string& name, const string& solver,
99                 const SpDict& st, const Dict& opts) {
100     return Function::create(Dple::instantiate(name, solver, st), opts);
101   }
102 
dple_in()103   vector<string> dple_in() {
104     vector<string> ret(dple_n_in());
105     for (size_t i=0; i<ret.size(); ++i) ret[i]=dple_in(i);
106     return ret;
107   }
108 
dple_out()109   vector<string> dple_out() {
110     vector<string> ret(dple_n_out());
111     for (size_t i=0; i<ret.size(); ++i) ret[i]=dple_out(i);
112     return ret;
113   }
114 
dple_in(casadi_int ind)115   string dple_in(casadi_int ind) {
116     switch (static_cast<DpleInput>(ind)) {
117     case DPLE_A:      return "a";
118     case DPLE_V:      return "v";
119     case DPLE_NUM_IN: break;
120     }
121     return string();
122   }
123 
dple_out(casadi_int ind)124   string dple_out(casadi_int ind) {
125     switch (static_cast<DpleOutput>(ind)) {
126       case DPLE_P:      return "p";
127       case DPLE_NUM_OUT: break;
128     }
129     return string();
130   }
131 
dple_n_in()132   casadi_int dple_n_in() {
133     return DPLE_NUM_IN;
134   }
135 
dple_n_out()136   casadi_int dple_n_out() {
137     return DPLE_NUM_OUT;
138   }
139 
140   // Constructor
Dple(const std::string & name,const SpDict & st)141   Dple::Dple(const std::string& name, const SpDict &st)
142     : FunctionInternal(name) {
143     for (auto i=st.begin(); i!=st.end(); ++i) {
144       if (i->first=="a") {
145         A_ = i->second;
146       } else if (i->first=="v") {
147         V_ = i->second;
148       } else {
149         casadi_error("Unrecognized field in Dple structure: " + str(i->first));
150       }
151     }
152 
153   }
154 
get_sparsity_in(casadi_int i)155   Sparsity Dple::get_sparsity_in(casadi_int i) {
156     switch (static_cast<DpleInput>(i)) {
157       case DPLE_A:
158         return A_;
159       case DPLE_V:
160         return V_;
161       case DPLE_NUM_IN: break;
162     }
163     return Sparsity();
164   }
165 
get_sparsity_out(casadi_int i)166   Sparsity Dple::get_sparsity_out(casadi_int i) {
167     switch (static_cast<DpleOutput>(i)) {
168       case DPLE_P:
169         return V_;
170       case DPLE_NUM_OUT: break;
171     }
172     return Sparsity();
173   }
174 
175   const Options Dple::options_
176   = {{&FunctionInternal::options_},
177      {{"const_dim",
178        {OT_BOOL,
179         "Assume constant dimension of P"}},
180       {"pos_def",
181         {OT_BOOL,
182          "Assume P positive definite"}},
183       {"error_unstable",
184         {OT_BOOL,
185         "Throw an exception when it is detected that Product(A_i, i=N..1)"
186         "has eigenvalues greater than 1-eps_unstable"}},
187       {"eps_unstable",
188         {OT_DOUBLE,
189         "A margin for unstability detection"}}
190      }
191   };
192 
init(const Dict & opts)193   void Dple::init(const Dict& opts) {
194     // Call the init method of the base class
195     FunctionInternal::init(opts);
196 
197     // Default options
198     const_dim_ = true;
199     pos_def_ = false;
200     error_unstable_ = false;
201     eps_unstable_ = 1e-4;
202 
203     // Read options
204     for (auto&& op : opts) {
205       if (op.first=="const_dim") {
206         const_dim_ = op.second;
207       } else if  (op.first=="pos_def") {
208         pos_def_ = op.second;
209       } else if  (op.first=="error_unstable") {
210         error_unstable_ = op.second;
211       } else if  (op.first=="eps_unstable") {
212         eps_unstable_ = op.second;
213       }
214     }
215 
216     casadi_assert_dev(V_.size2() % V_.size1() == 0);
217     nrhs_ = V_.size2() / V_.size1();
218     casadi_assert_dev(nrhs_>=1);
219 
220     std::vector<Sparsity> Vs = horzsplit(V_, V_.size1());
221     Sparsity Vref = Vs[0];
222     casadi_assert(Vref.is_symmetric(),
223       "V must be symmetric but got " + Vref.dim() + ".");
224 
225     for (auto&& s : Vs)
226       casadi_assert_dev(s==Vref);
227 
228     casadi_assert(const_dim_, "Not implemented");
229 
230     casadi_int blocksize = Vref.colind()[1];
231     K_ = Vref.size1()/blocksize;
232     Sparsity block = Sparsity::dense(blocksize, blocksize);
233 
234     std::vector<Sparsity> blocks(K_, block);
235     casadi_assert(Vref==diagcat(blocks), "Structure not recognised.");
236     casadi_assert(A_==Vref, "Structure not recognised.");
237 
238 
239   }
240 
  /** \brief Build the forward-mode derivative Function for \a nfwd directions.
   *
   * The returned Function takes {A, V, P, Adot, Vdot} (nominal inputs, nominal
   * output P, then the forward seeds stacked horizontally) and returns Pdot.
   * Pdot is obtained by solving a second DPLE with the same A and a right-hand
   * side assembled from the seeds.
   *
   * \param nfwd   number of forward directions
   * \param name   name of the returned Function
   * \param inames input names of the returned Function
   * \param onames output names of the returned Function
   * \param opts   passed to the inner dplesol call (not to the returned Function)
   */
  Function Dple::get_forward(casadi_int nfwd, const std::string& name,
                               const std::vector<std::string>& inames,
                               const std::vector<std::string>& onames,
                               const Dict& opts) const {
    // Symbolic A
    MX A = MX::sym("A", A_);
    // Vdotf: per right-hand side, symmetrized (Adot*P*A' + A*P*Adot' + Vdot).
    // NOTE(review): this matches differentiating P = A*P*A' + V, presumably the
    // equation the plugin solves — confirm. The symbols Adot/Vdot are named "P"
    // and the function "PAVbar" (same name as in get_reverse); the names are
    // cosmetic but misleading when printing.
    Function Vdotf;
    {
      MX P = MX::sym("P", A_);
      MX Adot = MX::sym("P", A_);
      MX Vdot = MX::sym("P", A_);

      MX temp = mtimes(std::vector<MX>{Adot, P, A.T()}) +
                mtimes(std::vector<MX>{A, P, Adot.T()}) + Vdot;
      Vdotf = Function("PAVbar", {A, P, Adot, Vdot},
                { (temp+temp.T())/2});
    }

    // Nominal output and forward seeds (seeds repeated nfwd times horizontally)
    MX P = MX::sym("P", V_);
    MX Adot = MX::sym("Adot", repmat(A_, 1, nfwd));
    MX Vdot = MX::sym("Vdot", repmat(V_, 1, nfwd));
    // Inner map: over the nrhs_ right-hand sides, inputs 0 (A) and 2 (Adot) held fixed.
    // Outer map: over the nfwd directions, inputs 0 (A) and 1 (P) held fixed.
    MX Qdot = Vdotf.map("map", "serial", nrhs_, {0, 2}, std::vector<casadi_int>{})
         .map("map", "serial", nfwd, {0, 1}, std::vector<casadi_int>{})({A, P, Adot, Vdot})[0];
    // One DPLE solve with the same A for all nrhs_*nfwd right-hand sides
    MX Pdot = dplesol(A, Qdot, plugin_name(), opts);
    // Placeholder with V's dimensions; it does not enter the computation
    MX V = MX::sym("V", Sparsity(size_in(DPLE_V))); // We dont need V
    return Function(name, {A, V, P, Adot, Vdot}, {Pdot}, inames, onames);

  }
269 
  /** \brief Build the reverse-mode derivative Function for \a nadj directions.
   *
   * The returned Function takes {A, V, P, Pbar} (nominal inputs, nominal output
   * P, then adjoint seeds for P stacked horizontally) and returns {Abar, Vbar}.
   * Vbar is obtained by solving a DPLE with block-reversed, block-transposed A;
   * Abar is then assembled from P, the reversed A and that solution.
   *
   * \param nadj   number of adjoint directions
   * \param name   name of the returned Function
   * \param inames input names of the returned Function
   * \param onames output names of the returned Function
   * \param opts   passed to the inner dplesol call (not to the returned Function)
   */
  Function Dple::get_reverse(casadi_int nadj, const std::string& name,
                               const std::vector<std::string>& inames,
                               const std::vector<std::string>& onames,
                               const Dict& opts) const {

    // Symbolic A
    MX A = MX::sym("A", A_);

    // Helper functions acting on one block-diagonal matrix of K_ blocks of size n:
    //   rev  - reverse the block order
    //   revT - reverse the block order and transpose each block
    //   revS - reverse the block order and symmetrize each block, (e+e')/2
    casadi_int n = A_.size1()/K_;
    std::vector<MX> ret = diagsplit(A, n);
    std::reverse(ret.begin(), ret.end());
    std::vector<MX> retT;
    std::vector<MX> retS;
    for (auto & e : ret) retT.push_back(e.T());
    for (auto & e : ret) retS.push_back((e+e.T())/2);
    Function revS = Function("revS", {A}, {diagcat(retS)});
    Function revT = Function("revT", {A}, {diagcat(retT)});
    Function rev  = Function("rev", {A}, {diagcat(ret)});

    // Function to compute the formula for Abar, per right-hand side:
    //   2 * revT(rev(P) * A_rev * Vbar_rev)
    Function Abarf;
    {
      MX P = MX::sym("P", A_);
      MX Vbar_rev = MX::sym("Vbar", A_);
      MX A_rev = MX::sym("A", A_);

      Abarf = Function("PAVbar", {P, A_rev, Vbar_rev},
                {2*revT(mtimes(std::vector<MX>{rev(P)[0], A_rev, Vbar_rev}))[0]});
    }

    // original output
    MX P = MX::sym("P", V_);

    // Symbolic reverse seed for P
    MX Pbar = MX::sym("Pbar", repmat(V_, 1, nadj));
    // Symmetrize and block-reverse each of the nrhs_*nadj seed matrices
    MX Pbar_rev = revS.map(nrhs_).map(nadj)(Pbar)[0];

    // Reverse A (block order reversed, blocks transposed) for the new dple
    MX A_rev = revT(A)[0];

    // Solve a dple with nrhs*nadj right-hand sides
    MX Vbar_rev = dplesol(A_rev, Pbar_rev, plugin_name(), opts);

    // Undo the reversal for Vbar
    MX Vbar = rev.map(nrhs_).map(nadj)(Vbar_rev)[0];

    // Inner map: over the nrhs_ right-hand sides, input 1 (A_rev) held fixed and
    // output 0 reduced (summed) across them.
    // Outer map: over the nadj directions, inputs 0 (P) and 1 (A_rev) held fixed.
    MX Abar = Abarf.map("map", "serial", nrhs_, std::vector<casadi_int>{1}, {0}).
                    map("map", "serial", nadj, {0, 1}, std::vector<casadi_int>{})(
                      {P, A_rev, Vbar_rev})[0];

    // Placeholder with V's dimensions; it does not enter the computation
    MX V = MX::sym("V", Sparsity(size_in(DPLE_V))); // We dont need V
    return Function(name, {A, V, P, Pbar}, {Abar, Vbar}, inames, onames);
  }
326 
  // Destructor: Dple itself releases nothing explicitly; members clean up themselves.
  // NOTE(review): kept as an out-of-line definition — the header may declare it
  // pure virtual, which still requires a definition here.
  Dple::~Dple() {
  }
329 
  // Registry of Dple plugins, keyed by name (presumably the plugin name used by
  // has_plugin/getPlugin — confirm against PluginInterface).
  std::map<std::string, Dple::Plugin> Dple::solvers_;

  // Infix identifying this plugin family ("dple"); presumably used by the plugin
  // loader to build plugin library names — TODO confirm.
  const std::string Dple::infix_ = "dple";
333 
334 } // namespace casadi
335