1 /* 2 * This file is part of CasADi. 3 * 4 * CasADi -- A symbolic framework for dynamic optimization. 5 * Copyright (C) 2010-2014 Joel Andersson, Joris Gillis, Moritz Diehl, 6 * K.U. Leuven. All rights reserved. 7 * Copyright (C) 2011-2014 Greg Horn 8 * 9 * CasADi is free software; you can redistribute it and/or 10 * modify it under the terms of the GNU Lesser General Public 11 * License as published by the Free Software Foundation; either 12 * version 3 of the License, or (at your option) any later version. 13 * 14 * CasADi is distributed in the hope that it will be useful, 15 * but WITHOUT ANY WARRANTY; without even the implied warranty of 16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 17 * Lesser General Public License for more details. 18 * 19 * You should have received a copy of the GNU Lesser General Public 20 * License along with CasADi; if not, write to the Free Software 21 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 22 * 23 */ 24 25 26 #include "interpolant_impl.hpp" 27 #include "casadi_misc.hpp" 28 #include "mx_node.hpp" 29 #include "casadi_low.hpp" 30 #include <typeinfo> 31 32 using namespace std; 33 namespace casadi { 34 has_interpolant(const string & name)35 bool has_interpolant(const string& name) { 36 return Interpolant::has_plugin(name); 37 } 38 load_interpolant(const string & name)39 void load_interpolant(const string& name) { 40 Interpolant::load_plugin(name); 41 } 42 doc_interpolant(const string & name)43 string doc_interpolant(const string& name) { 44 return Interpolant::getPlugin(name).doc; 45 } 46 stack_grid(const std::vector<std::vector<double>> & grid,std::vector<casadi_int> & offset,std::vector<double> & stacked)47 void Interpolant::stack_grid(const std::vector< std::vector<double> >& grid, 48 std::vector<casadi_int>& offset, std::vector<double>& stacked) { 49 50 // Get offset for each input dimension 51 offset.clear(); 52 offset.reserve(grid.size()+1); 53 offset.push_back(0); 54 for (auto&& g : grid) offset.push_back(offset.back()+g.size()); 55 56 // Stack input grids 57 stacked.clear(); 58 stacked.reserve(offset.back()); 59 for (auto&& g : grid) stacked.insert(stacked.end(), g.begin(), g.end()); 60 } 61 check_grid(const std::vector<std::vector<double>> & grid)62 void Interpolant::check_grid(const std::vector< std::vector<double> >& grid) { 63 // Dimension at least 1 64 casadi_assert(!grid.empty(), "At least one input required"); 65 66 // Grid must be strictly increasing 67 for (auto&& g : grid) { 68 casadi_assert(is_increasing(g), "Gridpoints must be strictly increasing"); 69 casadi_assert(is_regular(g), "Gridpoints must be regular"); 70 casadi_assert(g.size()>=2, "Need at least two grid points for every input"); 71 } 72 } 73 check_grid(const std::vector<casadi_int> & grid_dims)74 void Interpolant::check_grid(const std::vector<casadi_int> & grid_dims) { 75 // Dimension at least 1 76 casadi_assert(!grid_dims.empty(), "At least one dimension required"); 77 78 // Grid must be strictly increasing 79 for (casadi_int d : grid_dims) { 80 casadi_assert(d>=2, "Need at least two grid points for every input"); 81 } 82 } 83 meshgrid(const std::vector<std::vector<double>> & grid)84 std::vector<double> Interpolant::meshgrid(const std::vector< std::vector<double> >& grid) { 85 std::vector<casadi_int> cnts(grid.size()+1, 0); 86 std::vector<casadi_int> sizes(grid.size(), 0); 87 for (casadi_int k=0;k<grid.size();++k) sizes[k]= grid[k].size(); 88 89 casadi_int total_iter = 1; 90 for (casadi_int k=0;k<grid.size();++k) total_iter*= sizes[k]; 91 92 casadi_int n_dims = grid.size(); 93 94 std::vector<double> ret(total_iter*n_dims); 95 for (casadi_int i=0;i<total_iter;++i) { 96 97 for (casadi_int j=0;j<grid.size();++j) { 98 ret[i*n_dims+j] = grid[j][cnts[j]]; 99 } 100 101 cnts[0]++; 102 casadi_int j = 0; 103 while (j<n_dims && cnts[j]==sizes[j]) { 104 cnts[j] = 0; 105 j++; 106 cnts[j]++; 107 } 108 109 } 110 111 return ret; 112 } 113 coeff_size() const114 casadi_int Interpolant::coeff_size() const { 115 return coeff_size(offset_, m_); 116 } 117 coeff_size(const std::vector<casadi_int> & offset,casadi_int m)118 casadi_int Interpolant::coeff_size(const std::vector<casadi_int>& offset, casadi_int m) { 119 casadi_int ret = 1; 120 for (casadi_int k=0;k<offset.size()-1;++k) { 121 ret *= offset[k+1]-offset[k]; 122 } 123 return m*ret; 124 } 125 interpolant(const std::string & name,const std::string & solver,const std::vector<std::vector<double>> & grid,const std::vector<double> & values,const Dict & opts)126 Function interpolant(const std::string& name, 127 const std::string& solver, 128 const std::vector<std::vector<double> >& grid, 129 const std::vector<double>& values, 130 const Dict& opts) { 131 Interpolant::check_grid(grid); 132 // Get offset for each input dimension 133 vector<casadi_int> offset; 134 // Stack input grids 135 vector<double> stacked; 136 137 // Consistency check, number of elements 138 casadi_uint nel=1; 139 for (auto&& g : grid) nel *= g.size(); 140 casadi_assert(values.size() % nel== 0, 141 "Inconsistent number of elements. Must be a multiple of " + 142 str(nel) + ", but got " + str(values.size()) + " instead."); 143 144 Interpolant::stack_grid(grid, offset, stacked); 145 146 casadi_int m = values.size()/nel; 147 return Interpolant::construct(solver, name, stacked, offset, values, m, opts); 148 } 149 construct(const std::string & solver,const std::string & name,const std::vector<double> & grid,const std::vector<casadi_int> & offset,const std::vector<double> & values,casadi_int m,const Dict & opts)150 Function Interpolant::construct(const std::string& solver, 151 const std::string& name, 152 const std::vector<double>& grid, 153 const std::vector<casadi_int>& offset, 154 const std::vector<double>& values, 155 casadi_int m, 156 const Dict& opts) { 157 bool do_inline = false; 158 Dict options = extract_from_dict(opts, "inline", do_inline); 159 if (do_inline && !Interpolant::getPlugin(solver).exposed.do_inline) { 160 options["inline"] = true; 161 do_inline = false; 162 } 163 if (do_inline && Interpolant::getPlugin(solver).exposed.do_inline) { 164 return Interpolant::getPlugin(solver).exposed. 165 do_inline(name, grid, offset, values, m, options); 166 } else { 167 return Function::create(Interpolant::getPlugin(solver) 168 .creator(name, grid, offset, values, m), options); 169 } 170 } 171 interpolant(const std::string & name,const std::string & solver,const std::vector<casadi_int> & grid_dims,const std::vector<double> & values,const Dict & opts)172 Function interpolant(const std::string& name, 173 const std::string& solver, 174 const std::vector<casadi_int>& grid_dims, 175 const std::vector<double>& values, 176 const Dict& opts) { 177 Interpolant::check_grid(grid_dims); 178 179 // Consistency check, number of elements 180 casadi_uint nel = product(grid_dims); 181 casadi_assert(values.size() % nel== 0, 182 "Inconsistent number of elements. Must be a multiple of " + 183 str(nel) + ", but got " + str(values.size()) + " instead."); 184 185 casadi_int m = values.size()/nel; 186 return Interpolant::construct(solver, name, std::vector<double>{}, 187 cumsum0(grid_dims), values, m, opts); 188 } 189 interpolant(const std::string & name,const std::string & solver,const std::vector<std::vector<double>> & grid,casadi_int m,const Dict & opts)190 Function interpolant(const std::string& name, 191 const std::string& solver, 192 const std::vector<std::vector<double> >& grid, 193 casadi_int m, 194 const Dict& opts) { 195 Interpolant::check_grid(grid); 196 197 // Get offset for each input dimension 198 vector<casadi_int> offset; 199 // Stack input grids 200 vector<double> stacked; 201 202 Interpolant::stack_grid(grid, offset, stacked); 203 return Interpolant::construct(solver, name, stacked, offset, std::vector<double>{}, m, opts); 204 } 205 interpolant(const std::string & name,const std::string & solver,const std::vector<casadi_int> & grid_dims,casadi_int m,const Dict & opts)206 Function interpolant(const std::string& name, 207 const std::string& solver, 208 const std::vector<casadi_int>& grid_dims, 209 casadi_int m, 210 const Dict& opts) { 211 Interpolant::check_grid(grid_dims); 212 return Interpolant::construct(solver, name, std::vector<double>{}, 213 cumsum0(grid_dims), std::vector<double>{}, m, opts); 214 } 215 216 Interpolant:: Interpolant(const std::string & name,const std::vector<double> & grid,const std::vector<casadi_int> & offset,const std::vector<double> & values,casadi_int m)217 Interpolant(const std::string& name, 218 const std::vector<double>& grid, 219 const std::vector<casadi_int>& offset, 220 const std::vector<double>& values, 221 casadi_int m) 222 : FunctionInternal(name), m_(m), grid_(grid), offset_(offset), values_(values) { 223 // Number of grid points 224 ndim_ = offset_.size()-1; 225 } 226 ~Interpolant()227 Interpolant::~Interpolant() { 228 } 229 get_sparsity_in(casadi_int i)230 Sparsity Interpolant::get_sparsity_in(casadi_int i) { 231 if (i==0) return Sparsity::dense(ndim_, batch_x_); 232 if (arg_values(i)) return Sparsity::dense(coeff_size()); 233 if (arg_grid(i)) return Sparsity::dense(offset_.back()); 234 casadi_assert_dev(false); 235 } 236 get_sparsity_out(casadi_int i)237 Sparsity Interpolant::get_sparsity_out(casadi_int i) { 238 casadi_assert_dev(i==0); 239 return Sparsity::dense(m_, batch_x_); 240 } 241 get_name_in(casadi_int i)242 std::string Interpolant::get_name_in(casadi_int i) { 243 if (i==0) return "x"; 244 if (arg_values(i)) return "c"; 245 if (arg_grid(i)) return "g"; 246 casadi_assert_dev(false); 247 } 248 get_name_out(casadi_int i)249 std::string Interpolant::get_name_out(casadi_int i) { 250 casadi_assert_dev(i==0); 251 return "f"; 252 } 253 254 std::map<std::string, Interpolant::Plugin> Interpolant::solvers_; 255 256 const std::string Interpolant::infix_ = "interpolant"; 257 258 const Options Interpolant::options_ 259 = {{&FunctionInternal::options_}, 260 {{"lookup_mode", 261 {OT_STRINGVECTOR, 262 "Specifies, for each grid dimenion, the lookup algorithm used to find the correct index. " 263 "'linear' uses a for-loop + break; (default when #knots<=100), " 264 "'exact' uses floored division (only for uniform grids), " 265 "'binary' uses a binary search. (default when #knots>100)."}}, 266 {"inline", 267 {OT_BOOL, 268 "Implement the lookup table in MX primitives. " 269 "Useful when you need derivatives with respect to grid and/or coefficients. " 270 "Such derivatives are fundamentally dense, so use with caution."}}, 271 {"batch_x", 272 {OT_INT, 273 "Evaluate a batch of different inputs at once (default 1)."}} 274 } 275 }; 276 arg_values(casadi_int i) const277 bool Interpolant::arg_values(casadi_int i) const { 278 if (!has_parametric_values()) return false; 279 return arg_values()==i; 280 } arg_grid(casadi_int i) const281 bool Interpolant::arg_grid(casadi_int i) const { 282 if (!has_parametric_grid()) return false; 283 return arg_grid()==i; 284 } 285 arg_values() const286 casadi_int Interpolant::arg_values() const { 287 casadi_assert_dev(has_parametric_values()); 288 return 1+has_parametric_grid(); 289 } arg_grid() const290 casadi_int Interpolant::arg_grid() const { 291 casadi_assert_dev(has_parametric_grid()); 292 return 1; 293 } 294 init(const Dict & opts)295 void Interpolant::init(const Dict& opts) { 296 297 batch_x_ = 1; 298 299 // Read options 300 for (auto&& op : opts) { 301 if (op.first=="lookup_mode") { 302 lookup_modes_ = op.second; 303 } else if (op.first=="batch_x") { 304 batch_x_ = op.second; 305 } 306 } 307 308 // Call the base class initializer 309 FunctionInternal::init(opts); 310 311 // Needed by casadi_interpn 312 alloc_w(ndim_, true); 313 alloc_iw(2*ndim_, true); 314 } 315 interpret_lookup_mode(const std::vector<std::string> & modes,const std::vector<double> & knots,const std::vector<casadi_int> & offset,const std::vector<casadi_int> & margin_left,const std::vector<casadi_int> & margin_right)316 std::vector<casadi_int> Interpolant::interpret_lookup_mode( 317 const std::vector<std::string>& modes, const std::vector<double>& knots, 318 const std::vector<casadi_int>& offset, 319 const std::vector<casadi_int>& margin_left, const std::vector<casadi_int>& margin_right) { 320 casadi_assert_dev(modes.empty() || modes.size()==offset.size()-1); 321 322 std::vector<casadi_int> ret; 323 for (casadi_int i=0;i<offset.size()-1;++i) { 324 casadi_int n = offset[i+1]-offset[i]; 325 ret.push_back(Low::interpret_lookup_mode(modes.empty() ? "auto": modes[i], n)); 326 } 327 328 for (casadi_int i=0;i<offset.size()-1;++i) { 329 if (ret[i]==LOOKUP_EXACT) { 330 if (!knots.empty()) { 331 casadi_int m_left = margin_left.empty() ? 0 : margin_left[i]; 332 casadi_int m_right = margin_right.empty() ? 0 : margin_right[i]; 333 334 std::vector<double> grid( 335 knots.begin()+offset[i]+m_left, 336 knots.begin()+offset[i+1]-m_right); 337 casadi_assert_dev(is_increasing(grid) && is_equally_spaced(grid)); 338 } 339 } 340 } 341 return ret; 342 } 343 serialize_body(SerializingStream & s) const344 void Interpolant::serialize_body(SerializingStream &s) const { 345 FunctionInternal::serialize_body(s); 346 s.version("Interpolant", 2); 347 s.pack("Interpolant::ndim", ndim_); 348 s.pack("Interpolant::m", m_); 349 s.pack("Interpolant::grid", grid_); 350 s.pack("Interpolant::offset", offset_); 351 s.pack("Interpolant::values", values_); 352 s.pack("Interpolant::lookup_modes", lookup_modes_); 353 s.pack("Interpolant::batch_x", batch_x_); 354 } 355 serialize_type(SerializingStream & s) const356 void Interpolant::serialize_type(SerializingStream &s) const { 357 FunctionInternal::serialize_type(s); 358 PluginInterface<Interpolant>::serialize_type(s); 359 } 360 deserialize(DeserializingStream & s)361 ProtoFunction* Interpolant::deserialize(DeserializingStream& s) { 362 return PluginInterface<Interpolant>::deserialize(s); 363 } 364 Interpolant(DeserializingStream & s)365 Interpolant::Interpolant(DeserializingStream & s) : FunctionInternal(s) { 366 int version = s.version("Interpolant", 1, 2); 367 s.unpack("Interpolant::ndim", ndim_); 368 s.unpack("Interpolant::m", m_); 369 s.unpack("Interpolant::grid", grid_); 370 s.unpack("Interpolant::offset", offset_); 371 s.unpack("Interpolant::values", values_); 372 s.unpack("Interpolant::lookup_modes", lookup_modes_); 373 if (version==1) { 374 batch_x_ = 1; 375 } else { 376 s.unpack("Interpolant::batch_x", batch_x_); 377 } 378 } 379 380 } // namespace casadi 381