1 /* 2 * This file is part of CasADi. 3 * 4 * CasADi -- A symbolic framework for dynamic optimization. 5 * Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl, 6 * K.U. Leuven. All rights reserved. 7 * Copyright (C) 2011-2014 Greg Horn 8 * 9 * CasADi is free software; you can redistribute it and/or 10 * modify it under the terms of the GNU Lesser General Public 11 * License as published by the Free Software Foundation; either 12 * version 3 of the License, or (at your option) any later version. 13 * 14 * CasADi is distributed in the hope that it will be useful, 15 * but WITHOUT ANY WARRANTY; without even the implied warranty of 16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 17 * Lesser General Public License for more details. 18 * 19 * You should have received a copy of the GNU Lesser General Public 20 * License along with CasADi; if not, write to the Free Software 21 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 22 * 23 */ 24 25 26 #include "transpose.hpp" 27 #include "serializing_stream.hpp" 28 29 using namespace std; 30 31 namespace casadi { 32 Transpose(const MX & x)33 Transpose::Transpose(const MX& x) { 34 set_dep(x); 35 set_sparsity(x.sparsity().T()); 36 } 37 serialize_type(SerializingStream & s) const38 void Transpose::serialize_type(SerializingStream& s) const { 39 MXNode::serialize_type(s); 40 s.pack("Transpose::dense", false); 41 } 42 serialize_type(SerializingStream & s) const43 void DenseTranspose::serialize_type(SerializingStream& s) const { 44 MXNode::serialize_type(s); // NOLINT 45 s.pack("Transpose::dense", true); 46 } 47 deserialize(DeserializingStream & s)48 MXNode* Transpose::deserialize(DeserializingStream& s) { 49 bool t; 50 s.unpack("Transpose::dense", t); 51 if (t) { 52 return new DenseTranspose(s); 53 } else { 54 return new Transpose(s); 55 } 56 } 57 eval(const double ** arg,double ** res,casadi_int * iw,double * w) const58 int Transpose::eval(const double** arg, double** res, casadi_int* iw, double* w) const { 59 return eval_gen<double>(arg, res, iw, w); 60 } 61 eval(const double ** arg,double ** res,casadi_int * iw,double * w) const62 int DenseTranspose::eval(const double** arg, double** res, casadi_int* iw, double* w) const { 63 return eval_gen<double>(arg, res, iw, w); 64 } 65 66 int Transpose:: eval_sx(const SXElem ** arg,SXElem ** res,casadi_int * iw,SXElem * w) const67 eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const { 68 return eval_gen<SXElem>(arg, res, iw, w); 69 } 70 71 int DenseTranspose:: eval_sx(const SXElem ** arg,SXElem ** res,casadi_int * iw,SXElem * w) const72 eval_sx(const SXElem** arg, SXElem** res, casadi_int* iw, SXElem* w) const { 73 return eval_gen<SXElem>(arg, res, iw, w); 74 } 75 76 template<typename T> eval_gen(const T * const * arg,T * const * res,casadi_int * iw,T * w) const77 int Transpose::eval_gen(const T* const* arg, T* const* res, 78 casadi_int* iw, T* w) const { 79 // Get sparsity patterns 80 //const vector<casadi_int>& x_colind = input[0]->colind(); 81 const casadi_int* x_row = dep(0).row(); 82 casadi_int x_sz = dep(0).nnz(); 83 const casadi_int* xT_colind = sparsity().colind(); 84 casadi_int xT_ncol = sparsity().size2(); 85 86 const T* x = arg[0]; 87 T* xT = res[0]; 88 89 // Transpose 90 copy(xT_colind, xT_colind+xT_ncol+1, iw); 91 for (casadi_int el=0; el<x_sz; ++el) { 92 xT[iw[x_row[el]]++] = x[el]; 93 } 94 return 0; 95 } 96 97 template<typename T> eval_gen(const T * const * arg,T * const * res,casadi_int * iw,T * w) const98 int DenseTranspose::eval_gen(const T* const* arg, T* const* res, 99 casadi_int* iw, T* w) const { 100 // Get sparsity patterns 101 casadi_int x_nrow = dep().size1(); 102 casadi_int x_ncol = dep().size2(); 103 104 const T* x = arg[0]; 105 T* xT = res[0]; 106 for (casadi_int i=0; i<x_ncol; ++i) { 107 for (casadi_int j=0; j<x_nrow; ++j) { 108 xT[i+j*x_ncol] = x[j+i*x_nrow]; 109 } 110 } 111 return 0; 112 } 113 114 int Transpose:: sp_forward(const bvec_t ** arg,bvec_t ** res,casadi_int * iw,bvec_t * w) const115 sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const { 116 // Shortands 117 const bvec_t *x = arg[0]; 118 bvec_t *xT = res[0]; 119 120 // Get sparsity 121 casadi_int nz = nnz(); 122 const casadi_int* x_row = dep().row(); 123 const casadi_int* xT_colind = sparsity().colind(); 124 casadi_int xT_ncol = sparsity().size2(); 125 126 // Loop over the nonzeros of the argument 127 copy(xT_colind, xT_colind+xT_ncol+1, iw); 128 for (casadi_int el=0; el<nz; ++el) { 129 xT[iw[*x_row++]++] = *x++; 130 } 131 return 0; 132 } 133 134 int Transpose:: sp_reverse(bvec_t ** arg,bvec_t ** res,casadi_int * iw,bvec_t * w) const135 sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const { 136 // Shortands 137 bvec_t *x = arg[0]; 138 bvec_t *xT = res[0]; 139 140 // Get sparsity 141 casadi_int nz = nnz(); 142 const casadi_int* x_row = dep().row(); 143 const casadi_int* xT_colind = sparsity().colind(); 144 casadi_int xT_ncol = sparsity().size2(); 145 146 // Loop over the nonzeros of the argument 147 copy(xT_colind, xT_colind+xT_ncol+1, iw); 148 for (casadi_int el=0; el<nz; ++el) { 149 casadi_int elT = iw[*x_row++]++; 150 *x++ |= xT[elT]; 151 xT[elT] = 0; 152 } 153 return 0; 154 } 155 156 int DenseTranspose:: sp_forward(const bvec_t ** arg,bvec_t ** res,casadi_int * iw,bvec_t * w) const157 sp_forward(const bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const { 158 // Shorthands 159 const bvec_t *x = arg[0]; 160 bvec_t *xT = res[0]; 161 casadi_int x_nrow = dep().size1(); 162 casadi_int x_ncol = dep().size2(); 163 164 // Loop over the elements 165 for (casadi_int rr=0; rr<x_nrow; ++rr) { 166 for (casadi_int cc=0; cc<x_ncol; ++cc) { 167 *xT++ = x[rr+cc*x_nrow]; 168 } 169 } 170 return 0; 171 } 172 173 int DenseTranspose:: sp_reverse(bvec_t ** arg,bvec_t ** res,casadi_int * iw,bvec_t * w) const174 sp_reverse(bvec_t** arg, bvec_t** res, casadi_int* iw, bvec_t* w) const { 175 // Shorthands 176 bvec_t *x = arg[0]; 177 bvec_t *xT = res[0]; 178 casadi_int x_nrow = dep().size1(); 179 casadi_int x_ncol = dep().size2(); 180 181 // Loop over the elements 182 for (casadi_int rr=0; rr<x_nrow; ++rr) { 183 for (casadi_int cc=0; cc<x_ncol; ++cc) { 184 x[rr+cc*x_nrow] |= *xT; 185 *xT++ = 0; 186 } 187 } 188 return 0; 189 } 190 disp(const std::vector<std::string> & arg) const191 std::string Transpose::disp(const std::vector<std::string>& arg) const { 192 return arg.at(0) + "'"; 193 } 194 eval_mx(const std::vector<MX> & arg,std::vector<MX> & res) const195 void Transpose::eval_mx(const std::vector<MX>& arg, std::vector<MX>& res) const { 196 res[0] = arg[0].T(); 197 } 198 ad_forward(const std::vector<std::vector<MX>> & fseed,std::vector<std::vector<MX>> & fsens) const199 void Transpose::ad_forward(const std::vector<std::vector<MX> >& fseed, 200 std::vector<std::vector<MX> >& fsens) const { 201 for (casadi_int d=0; d<fsens.size(); ++d) { 202 fsens[d][0] = fseed[d][0].T(); 203 } 204 } 205 ad_reverse(const std::vector<std::vector<MX>> & aseed,std::vector<std::vector<MX>> & asens) const206 void Transpose::ad_reverse(const std::vector<std::vector<MX> >& aseed, 207 std::vector<std::vector<MX> >& asens) const { 208 for (casadi_int d=0; d<aseed.size(); ++d) { 209 asens[d][0] += aseed[d][0].T(); 210 } 211 } 212 generate(CodeGenerator & g,const std::vector<casadi_int> & arg,const std::vector<casadi_int> & res) const213 void Transpose::generate(CodeGenerator& g, 214 const std::vector<casadi_int>& arg, 215 const std::vector<casadi_int>& res) const { 216 g << g.trans(g.work(arg[0], nnz()), dep().sparsity(), 217 g.work(res[0], nnz()), sparsity(), "iw") << ";\n"; 218 } 219 generate(CodeGenerator & g,const std::vector<casadi_int> & arg,const std::vector<casadi_int> & res) const220 void DenseTranspose::generate(CodeGenerator& g, 221 const std::vector<casadi_int>& arg, 222 const std::vector<casadi_int>& res) const { 223 g.local("cs", "const casadi_real", "*"); 224 g.local("rr", "casadi_real", "*"); 225 g.local("i", "casadi_int"); 226 g.local("j", "casadi_int"); 227 g << "for (i=0, rr=" << g.work(res[0], nnz()) << ", " 228 << "cs=" << g.work(arg[0], nnz()) << "; i<" << dep().size2() << "; ++i) " 229 << "for (j=0; j<" << dep().size1() << "; ++j) " 230 << "rr[i+j*" << dep().size2() << "] = *cs++;\n"; 231 } 232 233 } // namespace casadi 234