1 /*
2  *    This file is part of CasADi.
3  *
4  *    CasADi -- A symbolic framework for dynamic optimization.
5  *    Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl,
6  *                            K.U. Leuven. All rights reserved.
7  *    Copyright (C) 2011-2014 Greg Horn
8  *
9  *    CasADi is free software; you can redistribute it and/or
10  *    modify it under the terms of the GNU Lesser General Public
11  *    License as published by the Free Software Foundation; either
12  *    version 3 of the License, or (at your option) any later version.
13  *
14  *    CasADi is distributed in the hope that it will be useful,
15  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
16  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  *    Lesser General Public License for more details.
18  *
19  *    You should have received a copy of the GNU Lesser General Public
20  *    License along with CasADi; if not, write to the Free Software
21  *    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
22  *
23  */
24 
25 
26 #ifndef CASADI_BINARY_MX_IMPL_HPP
27 #define CASADI_BINARY_MX_IMPL_HPP
28 
29 #include "binary_mx.hpp"
30 #include "casadi_misc.hpp"
31 #include "global_options.hpp"
32 #include "serializing_stream.hpp"
33 #include <sstream>
34 #include <vector>
35 
36 using namespace std;
37 
38 namespace casadi {
39 
40   template<bool ScX, bool ScY>
BinaryMX(Operation op,const MX & x,const MX & y)41   BinaryMX<ScX, ScY>::BinaryMX(Operation op, const MX& x, const MX& y) : op_(op) {
42     set_dep(x, y);
43     if (ScX) {
44       set_sparsity(y.sparsity());
45     } else {
46       set_sparsity(x.sparsity());
47     }
48   }
49 
50   template<bool ScX, bool ScY>
~BinaryMX()51   BinaryMX<ScX, ScY>::~BinaryMX() {
52   }
53 
54   template<bool ScX, bool ScY>
disp(const std::vector<std::string> & arg) const55   std::string BinaryMX<ScX, ScY>::disp(const std::vector<std::string>& arg) const {
56     return casadi_math<double>::print(op_, arg.at(0), arg.at(1));
57   }
58 
59   template<bool ScX, bool ScY>
eval_mx(const std::vector<MX> & arg,std::vector<MX> & res) const60   void BinaryMX<ScX, ScY>::eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const {
61     casadi_math<MX>::fun(op_, arg[0], arg[1], res[0]);
62   }
63 
64   template<bool ScX, bool ScY>
ad_forward(const std::vector<std::vector<MX>> & fseed,std::vector<std::vector<MX>> & fsens) const65   void BinaryMX<ScX, ScY>::ad_forward(const std::vector<std::vector<MX> >& fseed,
66                                    std::vector<std::vector<MX> >& fsens) const {
67     // Get partial derivatives
68     MX pd[2];
69     casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
70 
71     // Propagate forward seeds
72     for (casadi_int d=0; d<fsens.size(); ++d) {
73       fsens[d][0] = pd[0]*fseed[d][0] + pd[1]*fseed[d][1];
74     }
75   }
76 
77   template<bool ScX, bool ScY>
ad_reverse(const std::vector<std::vector<MX>> & aseed,std::vector<std::vector<MX>> & asens) const78   void BinaryMX<ScX, ScY>::ad_reverse(const std::vector<std::vector<MX> >& aseed,
79                                    std::vector<std::vector<MX> >& asens) const {
80     // Get partial derivatives
81     MX pd[2];
82     casadi_math<MX>::der(op_, dep(0), dep(1), shared_from_this<MX>(), pd);
83 
84     // Propagate adjoint seeds
85     for (casadi_int d=0; d<aseed.size(); ++d) {
86       MX s = aseed[d][0];
87       for (casadi_int c=0; c<2; ++c) {
88         // Get increment of sensitivity c
89         MX t = pd[c]*s;
90 
91         // If dimension mismatch (i.e. one argument is scalar), then sum all the entries
92         if (!t.is_scalar() && t.size() != dep(c).size()) {
93           if (pd[c].size()!=s.size()) pd[c] = MX(s.sparsity(), pd[c]);
94           t = dot(pd[c], s);
95         }
96 
97         // Propagate the seeds
98         asens[d][c] += t;
99       }
100     }
101   }
102 
103   template<bool ScX, bool ScY>
104   void BinaryMX<ScX, ScY>::
generate(CodeGenerator & g,const std::vector<casadi_int> & arg,const std::vector<casadi_int> & res) const105   generate(CodeGenerator& g,
106            const std::vector<casadi_int>& arg, const std::vector<casadi_int>& res) const {
107     // Quick return if nothing to do
108     if (nnz()==0) return;
109 
110     // Check if inplace
111     bool inplace;
112     switch (op_) {
113     case OP_ADD:
114     case OP_SUB:
115     case OP_MUL:
116     case OP_DIV:
117       inplace = res[0]==arg[0];
118       break;
119     default:
120       inplace = false;
121       break;
122     }
123 
124     // Scalar names of arguments (start assuming all scalars)
125     string r = g.workel(res[0]);
126     string x = g.workel(arg[0]);
127     string y = g.workel(arg[1]);
128 
129     // Avoid emitting '/*' which will be mistaken for a comment
130     if (op_==OP_DIV && g.codegen_scalars && dep(1).nnz()==1) {
131       y = "(" + y + ")";
132     }
133 
134     // Codegen loop, if needed
135     if (nnz()>1) {
136       // Iterate over result
137       g.local("rr", "casadi_real", "*");
138       g.local("i", "casadi_int");
139       g << "for (i=0, " << "rr=" << g.work(res[0], nnz());
140       r = "(*rr++)";
141 
142       // Iterate over first argument?
143       if (!ScX && !inplace) {
144         g.local("cr", "const casadi_real", "*");
145         g << ", cr=" << g.work(arg[0], dep(0).nnz());
146         if (op_==OP_OR || op_==OP_AND) {
147           // Avoid short-circuiting with side effects
148           x = "cr[i]";
149         } else {
150           x = "(*cr++)";
151         }
152 
153       }
154 
155       // Iterate over second argument?
156       if (!ScY) {
157         g.local("cs", "const casadi_real", "*");
158         g << ", cs=" << g.work(arg[1], dep(1).nnz());
159         if (op_==OP_OR || op_==OP_AND) {
160           // Avoid short-circuiting with side effects
161           y = "cs[i]";
162         } else {
163           y = "(*cs++)";
164         }
165       }
166 
167       // Close loop
168       g << "; i<" << nnz() << "; ++i) ";
169     }
170 
171     // Perform operation
172     g << r << " ";
173     if (inplace) {
174       g << casadi_math<double>::sep(op_) << "= " << y;
175     } else {
176       g << " = " << g.print_op(op_, x, y);
177     }
178     g << ";\n";
179   }
180 
181   template<bool ScX, bool ScY>
182   int BinaryMX<ScX, ScY>::
eval(const double ** arg,double ** res,casadi_int * iw,double * w) const183   eval(const double** arg, double** res, casadi_int* iw, double* w) const {
184     return eval_gen<double>(arg, res, iw, w);
185   }
186 
187   template<bool ScX, bool ScY>
188   int BinaryMX<ScX, ScY>::
eval_sx(const SXElem ** arg,SXElem ** res,casadi_int * iw,SXElem * w) const189   eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const {
190     return eval_gen<SXElem>(arg, res, iw, w);
191   }
192 
193   template<bool ScX, bool ScY>
194   template<typename T>
195   int BinaryMX<ScX, ScY>::
eval_gen(const T * const * arg,T * const * res,casadi_int * iw,T * w) const196   eval_gen(const T* const* arg, T* const* res, casadi_int* iw, T* w) const {
197     // Get data
198     T* output0 = res[0];
199     const T* input0 = arg[0];
200     const T* input1 = arg[1];
201 
202     if (!ScX && !ScY) {
203       casadi_math<T>::fun(op_, input0, input1, output0, nnz());
204     } else if (ScX) {
205       casadi_math<T>::fun(op_, *input0, input1, output0, nnz());
206     } else {
207       casadi_math<T>::fun(op_, input0, *input1, output0, nnz());
208     }
209     return 0;
210   }
211 
212   template<bool ScX, bool ScY>
213   int BinaryMX<ScX, ScY>::
sp_forward(const bvec_t ** arg,bvec_t ** res,casadi_int * iw,bvec_t * w) const214   sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
215     const bvec_t *a0=arg[0], *a1=arg[1];
216     bvec_t *r=res[0];
217     casadi_int n=nnz();
218     for (casadi_int i=0; i<n; ++i) {
219       if (ScX && ScY)
220         *r++ = *a0 | *a1;
221       else if (ScX && !ScY)
222         *r++ = *a0 | *a1++;
223       else if (!ScX && ScY)
224         *r++ = *a0++ | *a1;
225       else
226         *r++ = *a0++ | *a1++;
227     }
228     return 0;
229   }
230 
231   template<bool ScX, bool ScY>
232   int BinaryMX<ScX, ScY>::
sp_reverse(bvec_t ** arg,bvec_t ** res,casadi_int * iw,bvec_t * w) const233   sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const {
234     bvec_t *a0=arg[0], *a1=arg[1], *r = res[0];
235     casadi_int n=nnz();
236     for (casadi_int i=0; i<n; ++i) {
237       bvec_t s = *r;
238       *r++ = 0;
239       if (ScX)
240         *a0 |= s;
241       else
242         *a0++ |= s;
243       if (ScY)
244         *a1 |= s;
245       else
246         *a1++ |= s;
247     }
248     return 0;
249   }
250 
251   template<bool ScX, bool ScY>
get_unary(casadi_int op) const252   MX BinaryMX<ScX, ScY>::get_unary(casadi_int op) const {
253     //switch (op_) {
254     //default: break; // no rule
255     //}
256 
257     // Fallback to default implementation
258     return MXNode::get_unary(op);
259   }
260 
261   template<bool ScX, bool ScY>
_get_binary(casadi_int op,const MX & y,bool scX,bool scY) const262   MX BinaryMX<ScX, ScY>::_get_binary(casadi_int op, const MX& y, bool scX, bool scY) const {
263     if (!GlobalOptions::simplification_on_the_fly) return MXNode::_get_binary(op, y, scX, scY);
264 
265     switch (op_) {
266     case OP_ADD:
267       if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth())) return dep(1);
268       if (op==OP_SUB && MX::is_equal(y, dep(1), maxDepth())) return dep(0);
269       break;
270     case OP_SUB:
271       if (op==OP_SUB && MX::is_equal(y, dep(0), maxDepth())) return -dep(1);
272       if (op==OP_ADD && MX::is_equal(y, dep(1), maxDepth())) return dep(0);
273       break;
274     default: break; // no rule
275     }
276 
277     // Fallback to default implementation
278     return MXNode::_get_binary(op, y, scX, scY);
279   }
280 
281   template<bool ScX, bool ScY>
serialize_body(SerializingStream & s) const282   void BinaryMX<ScX, ScY>::serialize_body(SerializingStream& s) const {
283     MXNode::serialize_body(s);
284     s.pack("BinaryMX::op", static_cast<int>(op_));
285   }
286 
287   template<bool ScX, bool ScY>
serialize_type(SerializingStream & s) const288   void BinaryMX<ScX, ScY>::serialize_type(SerializingStream& s) const {
289     MXNode::serialize_type(s);
290     char type_x = ScX;
291     char type_y = ScY;
292     char type = type_x | (type_y << 1);
293     s.pack("BinaryMX::scalar_flags", type);
294   }
295 
296   template<bool ScX, bool ScY>
deserialize(DeserializingStream & s)297   MXNode* BinaryMX<ScX, ScY>::deserialize(DeserializingStream& s) {
298     char t;
299     s.unpack("BinaryMX::scalar_flags", t);
300     bool scX = t & 1;
301     bool scY = t & 2;
302 
303     if (scX) {
304       if (scY) return new BinaryMX<true, true>(s);
305       return new BinaryMX<true, false>(s);
306     } else {
307       if (scY) return new BinaryMX<false, true>(s);
308       return new BinaryMX<false, false>(s);
309     }
310   }
311 
312   template<bool ScX, bool ScY>
BinaryMX(DeserializingStream & s)313   BinaryMX<ScX, ScY>::BinaryMX(DeserializingStream& s) : MXNode(s) {
314     int op;
315     s.unpack("BinaryMX::op", op);
316     op_ = Operation(op);
317   }
318 
319 } // namespace casadi
320 
321 #endif // CASADI_BINARY_MX_IMPL_HPP
322