#ifndef VEXCL_MBA_HPP
#define VEXCL_MBA_HPP

/*
The MIT License

Copyright (c) 2012-2018 Denis Demidov <dennis.demidov@gmail.com>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

/**
 * \file   vexcl/mba.hpp
 * \author Denis Demidov <dennis.demidov@gmail.com>
 * \brief  Scattered data interpolation with multilevel B-Splines.
 */

#include <vector>
#include <array>
#include <string>
#include <sstream>
#include <memory>
#include <algorithm>
#include <numeric>
#include <functional>
#include <type_traits>
#include <cmath>
#include <cstddef>
#include <cassert>

#include <boost/tuple/tuple.hpp>
#include <boost/fusion/adapted/boost_tuple.hpp>

#include <vexcl/operations.hpp>

// Include boost.preprocessor header if variadic templates are not available.
// Also include it if we use gcc v4.6.
50 // This is required due to bug http://gcc.gnu.org/bugzilla/show_bug.cgi?id=35722 51 #if defined(BOOST_NO_VARIADIC_TEMPLATES) || (defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ == 6) 52 # include <boost/preprocessor/repetition.hpp> 53 # ifndef VEXCL_MAX_ARITY 54 # define VEXCL_MAX_ARITY BOOST_PROTO_MAX_ARITY 55 # endif 56 #endif 57 namespace vex { 58 59 struct mba_terminal {}; 60 61 typedef vector_expression< 62 typename boost::proto::terminal< mba_terminal >::type 63 > mba_terminal_expression; 64 65 template <class MBA, class ExprTuple> 66 struct mba_interp : public mba_terminal_expression { 67 typedef typename MBA::value_type value_type; 68 69 const MBA &cloud; 70 const ExprTuple coord; 71 mba_interpvex::mba_interp72 mba_interp(const MBA &cloud, const ExprTuple coord) 73 : cloud(cloud), coord(coord) {} 74 }; 75 76 namespace detail { 77 // Compile time value of N^M. 78 template <size_t N, size_t M> 79 struct power : std::integral_constant<size_t, N * power<N, M-1>::value> {}; 80 81 template <size_t N> 82 struct power<N, 0> : std::integral_constant<size_t, 1> {}; 83 84 // Nested loop counter of compile-time size (M loops of size N). 85 template <size_t N, size_t M> 86 class scounter { 87 public: scounter()88 scounter() : idx(0) { 89 std::fill(i.begin(), i.end(), static_cast<size_t>(0)); 90 } 91 operator [](size_t d) const92 size_t operator[](size_t d) const { 93 return i[d]; 94 } 95 operator ++()96 scounter& operator++() { 97 for(size_t d = M; d--; ) { 98 if (++i[d] < N) break; 99 i[d] = 0; 100 } 101 102 ++idx; 103 104 return *this; 105 } 106 operator size_t() const107 operator size_t() const { 108 return idx; 109 } 110 valid() const111 bool valid() const { 112 return idx < power<N, M>::value; 113 } 114 private: 115 size_t idx; 116 std::array<size_t, M> i; 117 }; 118 119 // Nested loop counter of run-time size (M loops of given sizes). 
120 template <size_t M> 121 class dcounter { 122 public: dcounter(const std::array<size_t,M> & N)123 dcounter(const std::array<size_t, M> &N) 124 : idx(0), 125 size(std::accumulate(N.begin(), N.end(), 126 static_cast<size_t>(1), std::multiplies<size_t>())), 127 N(N) 128 { 129 std::fill(i.begin(), i.end(), static_cast<size_t>(0)); 130 } 131 operator [](size_t d) const132 size_t operator[](size_t d) const { 133 return i[d]; 134 } 135 operator ++()136 dcounter& operator++() { 137 for(size_t d = M; d--; ) { 138 if (++i[d] < N[d]) break; 139 i[d] = 0; 140 } 141 142 ++idx; 143 144 return *this; 145 } 146 operator size_t() const147 operator size_t() const { 148 return idx; 149 } 150 valid() const151 bool valid() const { 152 return idx < size; 153 } 154 private: 155 size_t idx, size; 156 std::array<size_t, M> N, i; 157 }; 158 } // namespace detail 159 160 /// Scattered data interpolation with multilevel B-Splines. 161 template <size_t NDIM, typename real = double> 162 class mba { 163 public: 164 typedef real value_type; 165 typedef std::array<real, NDIM> point; 166 typedef std::array<size_t, NDIM> index; 167 168 static const size_t ndim = NDIM; 169 170 std::vector< backend::command_queue > queue; 171 std::vector< backend::device_vector<real> > phi; 172 point xmin, hinv; 173 index n, stride; 174 175 /** Creates the approximation functor. 176 * `cmin` and `cmax` specify the domain boundaries, `coo` and `val` 177 * contain coordinates and values of the data points. `grid` is the 178 * initial control grid size. The approximation hierarchy will have at 179 * most `levels` and will stop when the desired approximation precision 180 * `tol` will be reached. 
181 */ mba(const std::vector<backend::command_queue> & queue,const point & cmin,const point & cmax,const std::vector<point> & coo,std::vector<real> val,std::array<size_t,NDIM> grid,size_t levels=8,real tol=1e-8)182 mba( 183 const std::vector<backend::command_queue> &queue, 184 const point &cmin, const point &cmax, 185 const std::vector<point> &coo, std::vector<real> val, 186 std::array<size_t, NDIM> grid, size_t levels = 8, real tol = 1e-8 187 ) : queue(queue) 188 { 189 init(cmin, cmax, coo.begin(), coo.end(), val.begin(), grid, levels, tol); 190 } 191 192 /** Creates the approximation functor. 193 * `cmin` and `cmax` specify the domain boundaries. Coordinates and 194 * values of the data points are passed as iterator ranges. `grid` is 195 * the initial control grid size. The approximation hierarchy will have 196 * at most `levels` and will stop when the desired approximation 197 * precision `tol` will be reached. 198 */ 199 template <class CooIter, class ValIter> mba(const std::vector<backend::command_queue> & queue,const point & cmin,const point & cmax,CooIter coo_begin,CooIter coo_end,ValIter val_begin,std::array<size_t,NDIM> grid,size_t levels=8,real tol=1e-8)200 mba( 201 const std::vector<backend::command_queue> &queue, 202 const point &cmin, const point &cmax, 203 CooIter coo_begin, CooIter coo_end, ValIter val_begin, 204 std::array<size_t, NDIM> grid, size_t levels = 8, real tol = 1e-8 205 ) : queue(queue) 206 { 207 init(cmin, cmax, coo_begin, coo_end, val_begin, grid, levels, tol); 208 } 209 210 #if !defined(BOOST_NO_VARIADIC_TEMPLATES) && ((!defined(__GNUC__) || (__GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ > 6)) || defined(__clang__)) 211 /// Provide interpolated values at given coordinates. 212 template <class... Expr> operator ()(const Expr &...expr) const213 auto operator()(const Expr&... 
expr) const -> 214 mba_interp< mba, boost::tuple<const Expr&...> > 215 { 216 static_assert(sizeof...(Expr) == NDIM, "Wrong number of parameters"); 217 return mba_interp< mba, boost::tuple<const Expr&...> >(*this, boost::tie(expr...)); 218 } 219 #else 220 221 #define VEXCL_FUNCALL_OPERATOR(z, n, data) \ 222 template <BOOST_PP_ENUM_PARAMS(n, class Expr)> \ 223 mba_interp<mba, boost::tuple<BOOST_PP_ENUM_BINARY_PARAMS( \ 224 n, const Expr, &BOOST_PP_INTERCEPT)> > \ 225 operator()(BOOST_PP_ENUM_BINARY_PARAMS(n, const Expr, &expr)) { \ 226 return mba_interp<mba, boost::tuple<BOOST_PP_ENUM_BINARY_PARAMS( \ 227 n, const Expr, &BOOST_PP_INTERCEPT)> >( \ 228 *this, boost::tie(BOOST_PP_ENUM_PARAMS(n, expr))); \ 229 } 230 231 BOOST_PP_REPEAT_FROM_TO(1, 10, VEXCL_FUNCALL_OPERATOR, ~) 232 233 #undef VEXCL_FUNCALL_OPERATOR 234 #endif 235 private: 236 template <class CooIter, class ValIter> init(const point & cmin,const point & cmax,CooIter coo_begin,CooIter coo_end,ValIter val_begin,std::array<size_t,NDIM> grid,size_t levels=8,real tol=1e-8)237 void init( 238 const point &cmin, const point &cmax, 239 CooIter coo_begin, CooIter coo_end, ValIter val_begin, 240 std::array<size_t, NDIM> grid, size_t levels = 8, real tol = 1e-8 241 ) 242 { 243 for(size_t k = 0; k < NDIM; ++k) 244 assert(grid[k] > 1); 245 246 double res0 = std::accumulate( 247 val_begin, val_begin + (coo_end - coo_begin), 248 static_cast<real>(0), 249 [](real sum, real v) { return sum + v * v; } 250 ); 251 252 std::unique_ptr<lattice> psi( 253 new lattice(cmin, cmax, grid, coo_begin, coo_end, val_begin) 254 ); 255 double res = psi->update_data(coo_begin, coo_end, val_begin); 256 #ifdef VEXCL_MBA_VERBOSE 257 std::cout << "level 0: res = " << std::scientific << res << std::endl; 258 #endif 259 260 for (size_t k = 1; (res > res0 * tol) && (k < levels); ++k) { 261 for(size_t d = 0; d < NDIM; ++d) grid[d] = 2 * grid[d] - 1; 262 263 std::unique_ptr<lattice> f( 264 new lattice(cmin, cmax, grid, coo_begin, coo_end, val_begin) 
265 ); 266 res = f->update_data(coo_begin, coo_end, val_begin); 267 #ifdef VEXCL_MBA_VERBOSE 268 std::cout << "level " << k << std::scientific << ": res = " << res << std::endl; 269 #endif 270 271 f->append_refined(*psi); 272 psi = std::move(f); 273 } 274 275 xmin = psi->xmin; 276 hinv = psi->hinv; 277 n = psi->n; 278 stride = psi->stride; 279 280 phi.reserve(queue.size()); 281 282 for(auto q = queue.begin(); q != queue.end(); ++q) 283 phi.push_back( backend::device_vector<real>( 284 *q, psi->phi.size(), psi->phi.data(), backend::MEM_READ_ONLY 285 ) ); 286 } 287 288 // Control lattice. 289 struct lattice { 290 point xmin, hinv; 291 index n, stride; 292 std::vector<real> phi; 293 294 template <class CooIter, class ValIter> latticevex::mba::lattice295 lattice( 296 const point &cmin, const point &cmax, std::array<size_t, NDIM> grid, 297 CooIter coo_begin, CooIter coo_end, ValIter val_begin 298 ) : xmin(cmin), n(grid) 299 { 300 for(size_t d = 0; d < NDIM; ++d) { 301 hinv[d] = (grid[d] - 1) / (cmax[d] - cmin[d]); 302 xmin[d] -= 1 / hinv[d]; 303 n[d] += 2; 304 } 305 306 stride[NDIM - 1] = 1; 307 for(size_t d = NDIM - 1; d--; ) 308 stride[d] = stride[d + 1] * n[d + 1]; 309 310 std::vector<real> delta(n[0] * stride[0], 0.0); 311 std::vector<real> omega(n[0] * stride[0], 0.0); 312 313 auto p = coo_begin; 314 auto v = val_begin; 315 for(; p != coo_end; ++p, ++v) { 316 if (!contained(cmin, cmax, *p)) continue; 317 318 index i; 319 point s; 320 321 for(size_t d = 0; d < NDIM; ++d) { 322 real u = ((*p)[d] - xmin[d]) * hinv[d]; 323 i[d] = static_cast<size_t>(std::floor(u) - 1); 324 s[d] = u - std::floor(u); 325 } 326 327 std::array<real, detail::power<4, NDIM>::value> w; 328 real sw2 = 0; 329 330 for(detail::scounter<4, NDIM> d; d.valid(); ++d) { 331 real buf = 1; 332 for(size_t k = 0; k < NDIM; ++k) 333 buf *= B(d[k], s[k]); 334 335 w[d] = buf; 336 sw2 += buf * buf; 337 } 338 339 for(detail::scounter<4, NDIM> d; d.valid(); ++d) { 340 real phi = (*v) * w[d] / sw2; 341 342 size_t 
idx = 0; 343 for(size_t k = 0; k < NDIM; ++k) { 344 assert(i[k] + d[k] < n[k]); 345 346 idx += (i[k] + d[k]) * stride[k]; 347 } 348 349 real w2 = w[d] * w[d]; 350 351 assert(idx < delta.size()); 352 353 delta[idx] += w2 * phi; 354 omega[idx] += w2; 355 } 356 } 357 358 phi.resize(omega.size()); 359 360 for(auto w = omega.begin(), d = delta.begin(), f = phi.begin(); 361 w != omega.end(); 362 ++w, ++d, ++f 363 ) 364 { 365 if (std::fabs(*w) < 1e-32) 366 *f = 0; 367 else 368 *f = (*d) / (*w); 369 } 370 } 371 372 // Get interpolated value at given position. operator ()vex::mba::lattice373 real operator()(const point &p) const { 374 index i; 375 point s; 376 377 for(size_t d = 0; d < NDIM; ++d) { 378 real u = (p[d] - xmin[d]) * hinv[d]; 379 i[d] = static_cast<size_t>(std::floor(u) - 1); 380 s[d] = u - std::floor(u); 381 } 382 383 real f = 0; 384 385 for(detail::scounter<4, NDIM> d; d.valid(); ++d) { 386 real w = 1; 387 for(size_t k = 0; k < NDIM; ++k) 388 w *= B(d[k], s[k]); 389 390 f += w * get(i, d); 391 } 392 393 return f; 394 } 395 396 // Subtract interpolated values from data points. 397 template <class CooIter, class ValIter> update_datavex::mba::lattice398 real update_data( 399 CooIter coo_begin, CooIter coo_end, ValIter val_begin 400 ) const 401 { 402 auto c = coo_begin; 403 auto v = val_begin; 404 405 real res = 0; 406 407 for(; c != coo_end; ++c, ++v) { 408 *v -= (*this)(*c); 409 410 res += (*v) * (*v); 411 } 412 413 return res; 414 } 415 416 // Refine r and append it to the current control lattice. 
append_refinedvex::mba::lattice417 void append_refined(const lattice &r) { 418 static const std::array<real, 5> s = {{ 419 0.125, 0.500, 0.750, 0.500, 0.125 420 }}; 421 422 for(detail::dcounter<NDIM> i(r.n); i.valid(); ++i) { 423 real f = r.phi[i]; 424 for(detail::scounter<5, NDIM> d; d.valid(); ++d) { 425 index j; 426 bool skip = false; 427 size_t idx = 0; 428 for(size_t k = 0; k < NDIM; ++k) { 429 j[k] = 2 * i[k] + d[k] - 3; 430 if (j[k] >= n[k]) { skip = true; break; } 431 432 idx += j[k] * stride[k]; 433 } 434 435 if (skip) continue; 436 437 real c = 1; 438 for(size_t k = 0; k < NDIM; ++k) c *= s[d[k]]; 439 440 phi[idx] += f * c; 441 } 442 } 443 } 444 445 private: 446 // Value of k-th B-Spline at t. Bvex::mba::lattice447 static inline real B(size_t k, real t) { 448 assert(0 <= t && t < 1); 449 assert(k < 4); 450 451 switch (k) { 452 case 0: 453 return (t * (t * (-t + 3) - 3) + 1) / 6; 454 case 1: 455 return (t * t * (3 * t - 6) + 4) / 6; 456 case 2: 457 return (t * (t * (-3 * t + 3) + 3) + 1) / 6; 458 case 3: 459 default: 460 return t * t * t / 6; 461 } 462 } 463 464 // x is within [xmin, xmax]. containedvex::mba::lattice465 static bool contained( 466 const point &xmin, const point &xmax, const point &x) 467 { 468 for(size_t d = 0; d < NDIM; ++d) { 469 static const real eps = 1e-12; 470 471 if (x[d] - eps < xmin[d]) return false; 472 if (x[d] + eps >= xmax[d]) return false; 473 } 474 475 return true; 476 } 477 478 // Get value of phi at index (i + d). 
479 template <class Shift> getvex::mba::lattice480 inline real get(const index &i, const Shift &d) const { 481 size_t idx = 0; 482 483 for(size_t k = 0; k < NDIM; ++k) { 484 size_t j = i[k] + d[k]; 485 486 if (j >= n[k]) return 0; 487 idx += j * stride[k]; 488 } 489 490 return phi[idx]; 491 } 492 }; 493 }; 494 495 namespace traits { 496 497 template <> 498 struct is_vector_expr_terminal< mba_terminal > : std::true_type {}; 499 500 template <> 501 struct proto_terminal_is_value< mba_terminal > : std::true_type {}; 502 503 template <class MBA, class ExprTuple> 504 struct terminal_preamble< mba_interp<MBA, ExprTuple> > { getvex::traits::terminal_preamble505 static void get(backend::source_generator &src, 506 const mba_interp<MBA, ExprTuple>&, 507 const backend::command_queue&, const std::string &prm_name, 508 detail::kernel_generator_state_ptr) 509 { 510 typedef typename MBA::value_type real; 511 512 std::string B = prm_name + "_B"; 513 514 src.begin_function<real>(B + "0"); 515 src.begin_function_parameters(); 516 src.template parameter<real>("t"); 517 src.end_function_parameters(); 518 src.new_line() << "return (t * (t * (-t + 3) - 3) + 1) / 6;"; 519 src.end_function(); 520 521 src.begin_function<real>(B + "1"); 522 src.begin_function_parameters(); 523 src.template parameter<real>("t"); 524 src.end_function_parameters(); 525 src.new_line() << "return (t * t * (3 * t - 6) + 4) / 6;"; 526 src.end_function(); 527 528 src.begin_function<real>(B + "2"); 529 src.begin_function_parameters(); 530 src.template parameter<real>("t"); 531 src.end_function_parameters(); 532 src.new_line() << "return (t * (t * (-3 * t + 3) + 3) + 1) / 6;"; 533 src.end_function(); 534 535 src.begin_function<real>(B + "3"); 536 src.begin_function_parameters(); 537 src.template parameter<real>("t"); 538 src.end_function_parameters(); 539 src.new_line() << "return t * t * t / 6;"; 540 src.end_function(); 541 542 src.begin_function<real>(prm_name + "_mba"); 543 src.begin_function_parameters(); 544 545 
for(size_t k = 0; k < MBA::ndim; ++k) 546 src.template parameter<real>("x") << k; 547 548 for(size_t k = 0; k < MBA::ndim; ++k) { 549 src.template parameter<real>("c" + std::to_string(k)); 550 src.template parameter<real>("h" + std::to_string(k)); 551 src.parameter<size_t>("n" + std::to_string(k)); 552 src.parameter<size_t>("m" + std::to_string(k)); 553 } 554 555 src.template parameter< global_ptr<const real> >("phi"); 556 src.end_function_parameters(); 557 src.new_line() << type_name<real>() << " u;"; 558 for(size_t k = 0; k < MBA::ndim; ++k) { 559 src.new_line() << "u = (x" << k << " - c" << k << ") * h" << k << ";"; 560 src.new_line() << type_name<size_t>() << " i" << k << " = floor(u) - 1;"; 561 src.new_line() << type_name<real>() << " s" << k << " = u - floor(u);"; 562 } 563 src.new_line() << type_name<real>() << " f = 0;"; 564 src.new_line() << type_name<size_t>() << " j, idx;"; 565 566 for(detail::scounter<4,MBA::ndim> d; d.valid(); ++d) { 567 src.new_line() << "idx = 0;"; 568 for(size_t k = 0; k < MBA::ndim; ++k) { 569 src.new_line() << "j = i" << k << " + " << d[k] << ";"; 570 src.new_line() << "if (j < n" << k << ")"; 571 src.open("{").new_line() << "idx += j * m" << k << ";"; 572 } 573 574 src.new_line() << "f += "; 575 for(size_t k = 0; k < MBA::ndim; ++k) { 576 if (k) src << " * "; 577 src << B << d[k] << "(s" << k << ")"; 578 } 579 580 src << " * phi[idx];"; 581 582 for(size_t k = 0; k < MBA::ndim; ++k) 583 src.close("}"); 584 } 585 src.new_line() << "return f;"; 586 src.end_function(); 587 } 588 }; 589 590 template <class MBA, class ExprTuple> 591 struct kernel_param_declaration< mba_interp<MBA, ExprTuple> > { getvex::traits::kernel_param_declaration592 static void get(backend::source_generator &src, 593 const mba_interp<MBA, ExprTuple> &term, 594 const backend::command_queue &queue, const std::string &prm_name, 595 detail::kernel_generator_state_ptr state) 596 { 597 typedef typename MBA::value_type real; 598 boost::fusion::for_each(term.coord, 
prmdecl(src, queue, prm_name, state)); 599 600 for(size_t k = 0; k < MBA::ndim; ++k) { 601 src.template parameter<real>(prm_name + "_c" + std::to_string(k)); 602 src.template parameter<real>(prm_name + "_h" + std::to_string(k)); 603 src.parameter<size_t>(prm_name + "_n" + std::to_string(k)); 604 src.parameter<size_t>(prm_name + "_m" + std::to_string(k)); 605 } 606 607 src.parameter< global_ptr<const real> >(prm_name + "_phi"); 608 } 609 610 struct prmdecl { 611 backend::source_generator &s; 612 const backend::command_queue &queue; 613 const std::string &prm_name; 614 detail::kernel_generator_state_ptr state; 615 mutable int pos; 616 prmdeclvex::traits::kernel_param_declaration::prmdecl617 prmdecl(backend::source_generator &s, 618 const backend::command_queue &queue, const std::string &prm_name, 619 detail::kernel_generator_state_ptr state 620 ) : s(s), queue(queue), prm_name(prm_name), state(state), pos(0) 621 {} 622 623 template <class Expr> operator ()vex::traits::kernel_param_declaration::prmdecl624 void operator()(const Expr &expr) const { 625 std::ostringstream prefix; 626 prefix << prm_name << "_x" << pos; 627 detail::declare_expression_parameter ctx(s, queue, prefix.str(), state); 628 detail::extract_terminals()(boost::proto::as_child(expr), ctx); 629 630 pos++; 631 } 632 }; 633 }; 634 635 template <class MBA, class ExprTuple> 636 struct local_terminal_init< mba_interp<MBA, ExprTuple> > { getvex::traits::local_terminal_init637 static void get(backend::source_generator &src, 638 const mba_interp<MBA, ExprTuple> &term, 639 const backend::command_queue &queue, const std::string &prm_name, 640 detail::kernel_generator_state_ptr state) 641 { 642 boost::fusion::for_each(term.coord, local_init(src, queue, prm_name, state)); 643 } 644 645 struct local_init { 646 backend::source_generator &s; 647 const backend::command_queue &queue; 648 const std::string &prm_name; 649 detail::kernel_generator_state_ptr state; 650 mutable int pos; 651 
local_initvex::traits::local_terminal_init::local_init652 local_init(backend::source_generator &s, 653 const backend::command_queue &queue, const std::string &prm_name, 654 detail::kernel_generator_state_ptr state 655 ) : s(s), queue(queue), prm_name(prm_name), state(state), pos(0) 656 {} 657 658 template <class Expr> operator ()vex::traits::local_terminal_init::local_init659 void operator()(const Expr &expr) const { 660 std::ostringstream prefix; 661 prefix << prm_name << "_x" << pos; 662 663 detail::output_local_preamble init_ctx(s, queue, prefix.str(), state); 664 boost::proto::eval(boost::proto::as_child(expr), init_ctx); 665 666 pos++; 667 } 668 }; 669 }; 670 671 template <class MBA, class ExprTuple> 672 struct partial_vector_expr< mba_interp<MBA, ExprTuple> > { getvex::traits::partial_vector_expr673 static void get(backend::source_generator &src, 674 const mba_interp<MBA, ExprTuple> &term, 675 const backend::command_queue &queue, const std::string &prm_name, 676 detail::kernel_generator_state_ptr state) 677 { 678 src << prm_name << "_mba("; 679 680 boost::fusion::for_each(term.coord, buildexpr(src, queue, prm_name, state)); 681 682 for(size_t k = 0; k < MBA::ndim; ++k) { 683 src << ", " << prm_name << "_c" << k 684 << ", " << prm_name << "_h" << k 685 << ", " << prm_name << "_n" << k 686 << ", " << prm_name << "_m" << k; 687 } 688 689 src << ", " << prm_name << "_phi)"; 690 } 691 692 struct buildexpr { 693 backend::source_generator &s; 694 const backend::command_queue &queue; 695 const std::string &prm_name; 696 detail::kernel_generator_state_ptr state; 697 mutable int pos; 698 buildexprvex::traits::partial_vector_expr::buildexpr699 buildexpr(backend::source_generator &s, 700 const backend::command_queue &queue, const std::string &prm_name, 701 detail::kernel_generator_state_ptr state 702 ) : s(s), queue(queue), prm_name(prm_name), state(state), pos(0) 703 {} 704 705 template <class Expr> operator ()vex::traits::partial_vector_expr::buildexpr706 void 
operator()(const Expr &expr) const { 707 if(pos) s << ", "; 708 709 std::ostringstream prefix; 710 prefix << prm_name << "_x" << pos; 711 712 detail::vector_expr_context ctx(s, queue, prefix.str(), state); 713 boost::proto::eval(boost::proto::as_child(expr), ctx); 714 715 pos++; 716 } 717 }; 718 }; 719 720 template <class MBA, class ExprTuple> 721 struct kernel_arg_setter< mba_interp<MBA, ExprTuple> > { setvex::traits::kernel_arg_setter722 static void set(const mba_interp<MBA, ExprTuple> &term, 723 backend::kernel &kernel, unsigned part, size_t index_offset, 724 detail::kernel_generator_state_ptr state) 725 { 726 727 boost::fusion::for_each(term.coord, 728 setargs(kernel, part, index_offset, state)); 729 730 for(size_t k = 0; k < MBA::ndim; ++k) { 731 kernel.push_arg(term.cloud.xmin[k]); 732 kernel.push_arg(term.cloud.hinv[k]); 733 kernel.push_arg(term.cloud.n[k]); 734 kernel.push_arg(term.cloud.stride[k]); 735 } 736 kernel.push_arg(term.cloud.phi[part]); 737 } 738 739 struct setargs { 740 backend::kernel &kernel; 741 unsigned part; 742 size_t index_offset; 743 detail::kernel_generator_state_ptr state; 744 setargsvex::traits::kernel_arg_setter::setargs745 setargs( 746 backend::kernel &kernel, unsigned part, size_t index_offset, 747 detail::kernel_generator_state_ptr state 748 ) 749 : kernel(kernel), part(part), index_offset(index_offset), state(state) 750 {} 751 752 template <class Expr> operator ()vex::traits::kernel_arg_setter::setargs753 void operator()(const Expr &expr) const { 754 detail::set_expression_argument ctx(kernel, part, index_offset, state); 755 detail::extract_terminals()( boost::proto::as_child(expr), ctx); 756 } 757 }; 758 }; 759 760 template <class MBA, class ExprTuple> 761 struct expression_properties< mba_interp<MBA, ExprTuple> > { getvex::traits::expression_properties762 static void get(const mba_interp<MBA, ExprTuple> &term, 763 std::vector<backend::command_queue> &queue_list, 764 std::vector<size_t> &partition, 765 size_t &size 766 ) 767 { 
768 boost::fusion::for_each(term.coord, extrprop(queue_list, partition, size)); 769 } 770 771 struct extrprop { 772 std::vector<backend::command_queue> &queue_list; 773 std::vector<size_t> &partition; 774 size_t &size; 775 extrpropvex::traits::expression_properties::extrprop776 extrprop(std::vector<backend::command_queue> &queue_list, 777 std::vector<size_t> &partition, size_t &size 778 ) : queue_list(queue_list), partition(partition), size(size) 779 {} 780 781 template <class Expr> operator ()vex::traits::expression_properties::extrprop782 void operator()(const Expr &expr) const { 783 if (queue_list.empty()) { 784 detail::get_expression_properties prop; 785 detail::extract_terminals()(boost::proto::as_child(expr), prop); 786 787 queue_list = prop.queue; 788 partition = prop.part; 789 size = prop.size; 790 } 791 } 792 }; 793 }; 794 795 } //namespace traits 796 797 } // namespace vex 798 799 800 #endif 801