1 #ifndef VEXCL_MBA_HPP
2 #define VEXCL_MBA_HPP
3 
4 /*
5 The MIT License
6 
7 Copyright (c) 2012-2018 Denis Demidov <dennis.demidov@gmail.com>
8 
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15 
16 The above copyright notice and this permission notice shall be included in
17 all copies or substantial portions of the Software.
18 
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25 THE SOFTWARE.
26 */
27 
28 /**
29  * \file   vexcl/mba.hpp
30  * \author Denis Demidov <dennis.demidov@gmail.com>
31  * \brief  Scattered data interpolation with multilevel B-Splines.
32  */
33 
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>

#include <boost/tuple/tuple.hpp>
#include <boost/fusion/adapted/boost_tuple.hpp>

#include <vexcl/operations.hpp>
47 
48 // Include boost.preprocessor header if variadic templates are not available.
49 // Also include it if we use gcc v4.6.
50 // This is required due to bug http://gcc.gnu.org/bugzilla/show_bug.cgi?id=35722
51 #if defined(BOOST_NO_VARIADIC_TEMPLATES) || (defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ == 6)
52 #  include <boost/preprocessor/repetition.hpp>
53 #  ifndef VEXCL_MAX_ARITY
54 #    define VEXCL_MAX_ARITY BOOST_PROTO_MAX_ARITY
55 #  endif
56 #endif
57 namespace vex {
58 
59 struct mba_terminal {};
60 
61 typedef vector_expression<
62     typename boost::proto::terminal< mba_terminal >::type
63     > mba_terminal_expression;
64 
65 template <class MBA, class ExprTuple>
66 struct mba_interp : public mba_terminal_expression {
67     typedef typename MBA::value_type value_type;
68 
69     const MBA      &cloud;
70     const ExprTuple coord;
71 
mba_interpvex::mba_interp72     mba_interp(const MBA &cloud, const ExprTuple coord)
73         : cloud(cloud), coord(coord) {}
74 };
75 
namespace detail {
    /// Compile-time integer power: power<Base, Exp>::value equals Base^Exp.
    template <size_t Base, size_t Exp>
    struct power
        : std::integral_constant<size_t, Base * power<Base, Exp - 1>::value>
    {};

    /// Recursion terminator: anything to the power of zero is one.
    template <size_t Base>
    struct power<Base, 0>
        : std::integral_constant<size_t, 1>
    {};

    /// Counter over M nested loops, each of compile-time length N.
    /**
     * The last index varies fastest. The counter converts to the flat
     * (linear) iteration number and exposes the individual loop indices
     * through operator[].
     */
    template <size_t N, size_t M>
    class scounter {
        public:
            /// Starts at the all-zero multi-index.
            scounter() : flat(0) {
                digit.fill(0);
            }

            /// Current index of the d-th loop.
            size_t operator[](size_t d) const {
                return digit[d];
            }

            /// Advances to the next multi-index (odometer-style carry).
            scounter& operator++() {
                size_t d = M;
                while (d > 0) {
                    --d;
                    // No carry needed once a digit stays below its bound.
                    if (++digit[d] < N) break;
                    digit[d] = 0;
                }

                ++flat;

                return *this;
            }

            /// Number of increments applied so far.
            operator size_t() const {
                return flat;
            }

            /// True until all N^M combinations have been visited.
            bool valid() const {
                return flat < power<N, M>::value;
            }
        private:
            size_t flat;
            std::array<size_t, M> digit;
    };

    /// Counter over M nested loops whose lengths are given at run time.
    /**
     * Behaves like scounter, but the bound of the d-th loop is taken from
     * the array passed to the constructor.
     */
    template <size_t M>
    class dcounter {
        public:
            /// \param extent loop lengths, one per nesting level.
            dcounter(const std::array<size_t, M> &extent)
                : flat(0), total(1), bound(extent)
            {
                // Total number of iterations is the product of all extents.
                for(size_t d = 0; d < M; ++d)
                    total *= extent[d];

                digit.fill(0);
            }

            /// Current index of the d-th loop.
            size_t operator[](size_t d) const {
                return digit[d];
            }

            /// Advances to the next multi-index (odometer-style carry).
            dcounter& operator++() {
                size_t d = M;
                while (d > 0) {
                    --d;
                    if (++digit[d] < bound[d]) break;
                    digit[d] = 0;
                }

                ++flat;

                return *this;
            }

            /// Number of increments applied so far.
            operator size_t() const {
                return flat;
            }

            /// True until every combination has been visited.
            bool valid() const {
                return flat < total;
            }
        private:
            size_t flat, total;
            std::array<size_t, M> bound, digit;
    };
} // namespace detail
159 
160 /// Scattered data interpolation with multilevel B-Splines.
161 template <size_t NDIM, typename real = double>
162 class mba {
163     public:
164         typedef real value_type;
165         typedef std::array<real,   NDIM> point;
166         typedef std::array<size_t, NDIM> index;
167 
168         static const size_t ndim = NDIM;
169 
170         std::vector< backend::command_queue >       queue;
171         std::vector< backend::device_vector<real> > phi;
172         point xmin, hinv;
173         index n, stride;
174 
175         /** Creates the approximation functor.
176          * `cmin` and `cmax` specify the domain boundaries, `coo` and `val`
177          * contain coordinates and values of the data points. `grid` is the
178          * initial control grid size. The approximation hierarchy will have at
179          * most `levels` and will stop when the desired approximation precision
180          * `tol` will be reached.
181          */
mba(const std::vector<backend::command_queue> & queue,const point & cmin,const point & cmax,const std::vector<point> & coo,std::vector<real> val,std::array<size_t,NDIM> grid,size_t levels=8,real tol=1e-8)182         mba(
183                 const std::vector<backend::command_queue> &queue,
184                 const point &cmin, const point &cmax,
185                 const std::vector<point> &coo, std::vector<real> val,
186                 std::array<size_t, NDIM> grid, size_t levels = 8, real tol = 1e-8
187            ) : queue(queue)
188         {
189             init(cmin, cmax, coo.begin(), coo.end(), val.begin(), grid, levels, tol);
190         }
191 
192         /** Creates the approximation functor.
193          * `cmin` and `cmax` specify the domain boundaries. Coordinates and
194          * values of the data points are passed as iterator ranges. `grid` is
195          * the initial control grid size. The approximation hierarchy will have
196          * at most `levels` and will stop when the desired approximation
197          * precision `tol` will be reached.
198          */
199         template <class CooIter, class ValIter>
mba(const std::vector<backend::command_queue> & queue,const point & cmin,const point & cmax,CooIter coo_begin,CooIter coo_end,ValIter val_begin,std::array<size_t,NDIM> grid,size_t levels=8,real tol=1e-8)200         mba(
201                 const std::vector<backend::command_queue> &queue,
202                 const point &cmin, const point &cmax,
203                 CooIter coo_begin, CooIter coo_end, ValIter val_begin,
204                 std::array<size_t, NDIM> grid, size_t levels = 8, real tol = 1e-8
205            ) : queue(queue)
206         {
207             init(cmin, cmax, coo_begin, coo_end, val_begin, grid, levels, tol);
208         }
209 
210 #if !defined(BOOST_NO_VARIADIC_TEMPLATES) && ((!defined(__GNUC__) || (__GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ > 6)) || defined(__clang__))
211         /// Provide interpolated values at given coordinates.
212         template <class... Expr>
operator ()(const Expr &...expr) const213         auto operator()(const Expr&... expr) const ->
214             mba_interp< mba, boost::tuple<const Expr&...> >
215         {
216             static_assert(sizeof...(Expr) == NDIM, "Wrong number of parameters");
217             return mba_interp< mba, boost::tuple<const Expr&...> >(*this, boost::tie(expr...));
218         }
219 #else
220 
221 #define VEXCL_FUNCALL_OPERATOR(z, n, data)                                     \
222   template <BOOST_PP_ENUM_PARAMS(n, class Expr)>                               \
223   mba_interp<mba, boost::tuple<BOOST_PP_ENUM_BINARY_PARAMS(                    \
224                       n, const Expr, &BOOST_PP_INTERCEPT)> >                   \
225   operator()(BOOST_PP_ENUM_BINARY_PARAMS(n, const Expr, &expr)) {              \
226     return mba_interp<mba, boost::tuple<BOOST_PP_ENUM_BINARY_PARAMS(           \
227                                n, const Expr, &BOOST_PP_INTERCEPT)> >(         \
228         *this, boost::tie(BOOST_PP_ENUM_PARAMS(n, expr)));                     \
229   }
230 
231 BOOST_PP_REPEAT_FROM_TO(1, 10, VEXCL_FUNCALL_OPERATOR, ~)
232 
233 #undef VEXCL_FUNCALL_OPERATOR
234 #endif
235     private:
236         template <class CooIter, class ValIter>
init(const point & cmin,const point & cmax,CooIter coo_begin,CooIter coo_end,ValIter val_begin,std::array<size_t,NDIM> grid,size_t levels=8,real tol=1e-8)237         void init(
238                 const point &cmin, const point &cmax,
239                 CooIter coo_begin, CooIter coo_end, ValIter val_begin,
240                 std::array<size_t, NDIM> grid, size_t levels = 8, real tol = 1e-8
241                 )
242         {
243             for(size_t k = 0; k < NDIM; ++k)
244                 assert(grid[k] > 1);
245 
246             double res0 = std::accumulate(
247                     val_begin, val_begin + (coo_end - coo_begin),
248                     static_cast<real>(0),
249                     [](real sum, real v) { return sum + v * v; }
250                     );
251 
252             std::unique_ptr<lattice> psi(
253                     new lattice(cmin, cmax, grid, coo_begin, coo_end, val_begin)
254                     );
255             double res = psi->update_data(coo_begin, coo_end, val_begin);
256 #ifdef VEXCL_MBA_VERBOSE
257             std::cout << "level  0: res = " << std::scientific << res << std::endl;
258 #endif
259 
260             for (size_t k = 1; (res > res0 * tol) && (k < levels); ++k) {
261                 for(size_t d = 0; d < NDIM; ++d) grid[d] = 2 * grid[d] - 1;
262 
263                 std::unique_ptr<lattice> f(
264                         new lattice(cmin, cmax, grid, coo_begin, coo_end, val_begin)
265                         );
266                 res = f->update_data(coo_begin, coo_end, val_begin);
267 #ifdef VEXCL_MBA_VERBOSE
268                 std::cout << "level " << k << std::scientific << ": res = " << res << std::endl;
269 #endif
270 
271                 f->append_refined(*psi);
272                 psi = std::move(f);
273             }
274 
275             xmin   = psi->xmin;
276             hinv   = psi->hinv;
277             n      = psi->n;
278             stride = psi->stride;
279 
280             phi.reserve(queue.size());
281 
282             for(auto q = queue.begin(); q != queue.end(); ++q)
283                 phi.push_back( backend::device_vector<real>(
284                             *q, psi->phi.size(), psi->phi.data(), backend::MEM_READ_ONLY
285                             ) );
286         }
287 
288         // Control lattice.
289         struct lattice {
290             point xmin, hinv;
291             index n, stride;
292             std::vector<real> phi;
293 
294             template <class CooIter, class ValIter>
latticevex::mba::lattice295             lattice(
296                     const point &cmin, const point &cmax, std::array<size_t, NDIM> grid,
297                     CooIter coo_begin, CooIter coo_end, ValIter val_begin
298                    ) : xmin(cmin), n(grid)
299             {
300                 for(size_t d = 0; d < NDIM; ++d) {
301                     hinv[d] = (grid[d] - 1) / (cmax[d] - cmin[d]);
302                     xmin[d] -= 1 / hinv[d];
303                     n[d]    += 2;
304                 }
305 
306                 stride[NDIM - 1] = 1;
307                 for(size_t d = NDIM - 1; d--; )
308                     stride[d] = stride[d + 1] * n[d + 1];
309 
310                 std::vector<real> delta(n[0] * stride[0], 0.0);
311                 std::vector<real> omega(n[0] * stride[0], 0.0);
312 
313                 auto p = coo_begin;
314                 auto v = val_begin;
315                 for(; p != coo_end; ++p, ++v) {
316                     if (!contained(cmin, cmax, *p)) continue;
317 
318                     index i;
319                     point s;
320 
321                     for(size_t d = 0; d < NDIM; ++d) {
322                         real u = ((*p)[d] - xmin[d]) * hinv[d];
323                         i[d] = static_cast<size_t>(std::floor(u) - 1);
324                         s[d] = u - std::floor(u);
325                     }
326 
327                     std::array<real, detail::power<4, NDIM>::value> w;
328                     real sw2 = 0;
329 
330                     for(detail::scounter<4, NDIM> d; d.valid(); ++d) {
331                         real buf = 1;
332                         for(size_t k = 0; k < NDIM; ++k)
333                             buf *= B(d[k], s[k]);
334 
335                         w[d] = buf;
336                         sw2 += buf * buf;
337                     }
338 
339                     for(detail::scounter<4, NDIM> d; d.valid(); ++d) {
340                         real phi = (*v) * w[d] / sw2;
341 
342                         size_t idx = 0;
343                         for(size_t k = 0; k < NDIM; ++k) {
344                             assert(i[k] + d[k] < n[k]);
345 
346                             idx += (i[k] + d[k]) * stride[k];
347                         }
348 
349                         real w2 = w[d] * w[d];
350 
351                         assert(idx < delta.size());
352 
353                         delta[idx] += w2 * phi;
354                         omega[idx] += w2;
355                     }
356                 }
357 
358                 phi.resize(omega.size());
359 
360                 for(auto w = omega.begin(), d = delta.begin(), f = phi.begin();
361                         w != omega.end();
362                         ++w, ++d, ++f
363                    )
364                 {
365                     if (std::fabs(*w) < 1e-32)
366                         *f = 0;
367                     else
368                         *f = (*d) / (*w);
369                 }
370             }
371 
372             // Get interpolated value at given position.
operator ()vex::mba::lattice373             real operator()(const point &p) const {
374                 index i;
375                 point s;
376 
377                 for(size_t d = 0; d < NDIM; ++d) {
378                     real u = (p[d] - xmin[d]) * hinv[d];
379                     i[d] = static_cast<size_t>(std::floor(u) - 1);
380                     s[d] = u - std::floor(u);
381                 }
382 
383                 real f = 0;
384 
385                 for(detail::scounter<4, NDIM> d; d.valid(); ++d) {
386                     real w = 1;
387                     for(size_t k = 0; k < NDIM; ++k)
388                         w *= B(d[k], s[k]);
389 
390                     f += w * get(i, d);
391                 }
392 
393                 return f;
394             }
395 
396             // Subtract interpolated values from data points.
397             template <class CooIter, class ValIter>
update_datavex::mba::lattice398             real update_data(
399                     CooIter coo_begin, CooIter coo_end, ValIter val_begin
400                     ) const
401             {
402                 auto c = coo_begin;
403                 auto v = val_begin;
404 
405                 real res = 0;
406 
407                 for(; c != coo_end; ++c, ++v) {
408                     *v -= (*this)(*c);
409 
410                     res += (*v) * (*v);
411                 }
412 
413                 return res;
414             }
415 
416             // Refine r and append it to the current control lattice.
append_refinedvex::mba::lattice417             void append_refined(const lattice &r) {
418                 static const std::array<real, 5> s = {{
419                     0.125, 0.500, 0.750, 0.500, 0.125
420                 }};
421 
422                 for(detail::dcounter<NDIM> i(r.n); i.valid(); ++i) {
423                     real f = r.phi[i];
424                     for(detail::scounter<5, NDIM> d; d.valid(); ++d) {
425                         index j;
426                         bool skip = false;
427                         size_t idx = 0;
428                         for(size_t k = 0; k < NDIM; ++k) {
429                             j[k] = 2 * i[k] + d[k] - 3;
430                             if (j[k] >= n[k]) { skip = true; break; }
431 
432                             idx += j[k] * stride[k];
433                         }
434 
435                         if (skip) continue;
436 
437                         real c = 1;
438                         for(size_t k = 0; k < NDIM; ++k) c *= s[d[k]];
439 
440                         phi[idx] += f * c;
441                     }
442                 }
443             }
444 
445             private:
446                 // Value of k-th B-Spline at t.
Bvex::mba::lattice447                 static inline real B(size_t k, real t) {
448                     assert(0 <= t && t < 1);
449                     assert(k < 4);
450 
451                     switch (k) {
452                         case 0:
453                             return (t * (t * (-t + 3) - 3) + 1) / 6;
454                         case 1:
455                             return (t * t * (3 * t - 6) + 4) / 6;
456                         case 2:
457                             return (t * (t * (-3 * t + 3) + 3) + 1) / 6;
458                         case 3:
459                         default:
460                             return t * t * t / 6;
461                     }
462                 }
463 
464                 // x is within [xmin, xmax].
containedvex::mba::lattice465                 static bool contained(
466                         const point &xmin, const point &xmax, const point &x)
467                 {
468                     for(size_t d = 0; d < NDIM; ++d) {
469                         static const real eps = 1e-12;
470 
471                         if (x[d] - eps <  xmin[d]) return false;
472                         if (x[d] + eps >= xmax[d]) return false;
473                     }
474 
475                     return true;
476                 }
477 
478                 // Get value of phi at index (i + d).
479                 template <class Shift>
getvex::mba::lattice480                 inline real get(const index &i, const Shift &d) const {
481                     size_t idx = 0;
482 
483                     for(size_t k = 0; k < NDIM; ++k) {
484                         size_t j = i[k] + d[k];
485 
486                         if (j >= n[k]) return 0;
487                         idx += j * stride[k];
488                     }
489 
490                     return phi[idx];
491                 }
492         };
493 };
494 
495 namespace traits {
496 
497 template <>
498 struct is_vector_expr_terminal< mba_terminal > : std::true_type {};
499 
500 template <>
501 struct proto_terminal_is_value< mba_terminal > : std::true_type {};
502 
503 template <class MBA, class ExprTuple>
504 struct terminal_preamble< mba_interp<MBA, ExprTuple> > {
getvex::traits::terminal_preamble505     static void get(backend::source_generator &src,
506             const mba_interp<MBA, ExprTuple>&,
507             const backend::command_queue&, const std::string &prm_name,
508             detail::kernel_generator_state_ptr)
509     {
510         typedef typename MBA::value_type real;
511 
512         std::string B = prm_name + "_B";
513 
514         src.begin_function<real>(B + "0");
515         src.begin_function_parameters();
516         src.template parameter<real>("t");
517         src.end_function_parameters();
518         src.new_line() << "return (t * (t * (-t + 3) - 3) + 1) / 6;";
519         src.end_function();
520 
521         src.begin_function<real>(B + "1");
522         src.begin_function_parameters();
523         src.template parameter<real>("t");
524         src.end_function_parameters();
525         src.new_line() << "return (t * t * (3 * t - 6) + 4) / 6;";
526         src.end_function();
527 
528         src.begin_function<real>(B + "2");
529         src.begin_function_parameters();
530         src.template parameter<real>("t");
531         src.end_function_parameters();
532         src.new_line() << "return (t * (t * (-3 * t + 3) + 3) + 1) / 6;";
533         src.end_function();
534 
535         src.begin_function<real>(B + "3");
536         src.begin_function_parameters();
537         src.template parameter<real>("t");
538         src.end_function_parameters();
539         src.new_line() << "return t * t * t / 6;";
540         src.end_function();
541 
542         src.begin_function<real>(prm_name + "_mba");
543         src.begin_function_parameters();
544 
545         for(size_t k = 0; k < MBA::ndim; ++k)
546             src.template parameter<real>("x") << k;
547 
548         for(size_t k = 0; k < MBA::ndim; ++k) {
549             src.template parameter<real>("c" + std::to_string(k));
550             src.template parameter<real>("h" + std::to_string(k));
551             src.parameter<size_t>("n" + std::to_string(k));
552             src.parameter<size_t>("m" + std::to_string(k));
553         }
554 
555         src.template parameter< global_ptr<const real> >("phi");
556         src.end_function_parameters();
557         src.new_line() << type_name<real>() << " u;";
558         for(size_t k = 0; k < MBA::ndim; ++k) {
559             src.new_line() << "u = (x" << k << " - c" << k << ") * h" << k << ";";
560             src.new_line() << type_name<size_t>() << " i" << k << " = floor(u) - 1;";
561             src.new_line() << type_name<real>() << " s" << k << " = u - floor(u);";
562         }
563         src.new_line() << type_name<real>() << " f = 0;";
564         src.new_line() << type_name<size_t>() << " j, idx;";
565 
566         for(detail::scounter<4,MBA::ndim> d; d.valid(); ++d) {
567             src.new_line() << "idx = 0;";
568             for(size_t k = 0; k < MBA::ndim; ++k) {
569                 src.new_line() << "j = i" << k << " + " << d[k] << ";";
570                 src.new_line() << "if (j < n" << k << ")";
571                 src.open("{").new_line() << "idx += j * m" << k << ";";
572             }
573 
574             src.new_line() << "f += ";
575             for(size_t k = 0; k < MBA::ndim; ++k) {
576                 if (k) src << " * ";
577                 src << B << d[k] << "(s" << k << ")";
578             }
579 
580             src << " * phi[idx];";
581 
582             for(size_t k = 0; k < MBA::ndim; ++k)
583                 src.close("}");
584         }
585         src.new_line() << "return f;";
586         src.end_function();
587     }
588 };
589 
590 template <class MBA, class ExprTuple>
591 struct kernel_param_declaration< mba_interp<MBA, ExprTuple> > {
getvex::traits::kernel_param_declaration592     static void get(backend::source_generator &src,
593             const mba_interp<MBA, ExprTuple> &term,
594             const backend::command_queue &queue, const std::string &prm_name,
595             detail::kernel_generator_state_ptr state)
596     {
597         typedef typename MBA::value_type real;
598         boost::fusion::for_each(term.coord, prmdecl(src, queue, prm_name, state));
599 
600         for(size_t k = 0; k < MBA::ndim; ++k) {
601             src.template parameter<real>(prm_name + "_c" + std::to_string(k));
602             src.template parameter<real>(prm_name + "_h" + std::to_string(k));
603             src.parameter<size_t>(prm_name + "_n" + std::to_string(k));
604             src.parameter<size_t>(prm_name + "_m" + std::to_string(k));
605         }
606 
607         src.parameter< global_ptr<const real> >(prm_name + "_phi");
608     }
609 
610     struct prmdecl {
611         backend::source_generator &s;
612         const backend::command_queue &queue;
613         const std::string &prm_name;
614         detail::kernel_generator_state_ptr state;
615         mutable int pos;
616 
prmdeclvex::traits::kernel_param_declaration::prmdecl617         prmdecl(backend::source_generator &s,
618                 const backend::command_queue &queue, const std::string &prm_name,
619                 detail::kernel_generator_state_ptr state
620             ) : s(s), queue(queue), prm_name(prm_name), state(state), pos(0)
621         {}
622 
623         template <class Expr>
operator ()vex::traits::kernel_param_declaration::prmdecl624         void operator()(const Expr &expr) const {
625             std::ostringstream prefix;
626             prefix << prm_name << "_x" << pos;
627             detail::declare_expression_parameter ctx(s, queue, prefix.str(), state);
628             detail::extract_terminals()(boost::proto::as_child(expr), ctx);
629 
630             pos++;
631         }
632     };
633 };
634 
635 template <class MBA, class ExprTuple>
636 struct local_terminal_init< mba_interp<MBA, ExprTuple> > {
getvex::traits::local_terminal_init637     static void get(backend::source_generator &src,
638             const mba_interp<MBA, ExprTuple> &term,
639             const backend::command_queue &queue, const std::string &prm_name,
640             detail::kernel_generator_state_ptr state)
641     {
642         boost::fusion::for_each(term.coord, local_init(src, queue, prm_name, state));
643     }
644 
645     struct local_init {
646         backend::source_generator &s;
647         const backend::command_queue &queue;
648         const std::string &prm_name;
649         detail::kernel_generator_state_ptr state;
650         mutable int pos;
651 
local_initvex::traits::local_terminal_init::local_init652         local_init(backend::source_generator &s,
653                 const backend::command_queue &queue, const std::string &prm_name,
654                 detail::kernel_generator_state_ptr state
655             ) : s(s), queue(queue), prm_name(prm_name), state(state), pos(0)
656         {}
657 
658         template <class Expr>
operator ()vex::traits::local_terminal_init::local_init659         void operator()(const Expr &expr) const {
660             std::ostringstream prefix;
661             prefix << prm_name << "_x" << pos;
662 
663             detail::output_local_preamble init_ctx(s, queue, prefix.str(), state);
664             boost::proto::eval(boost::proto::as_child(expr), init_ctx);
665 
666             pos++;
667         }
668     };
669 };
670 
671 template <class MBA, class ExprTuple>
672 struct partial_vector_expr< mba_interp<MBA, ExprTuple> > {
getvex::traits::partial_vector_expr673     static void get(backend::source_generator &src,
674             const mba_interp<MBA, ExprTuple> &term,
675             const backend::command_queue &queue, const std::string &prm_name,
676             detail::kernel_generator_state_ptr state)
677     {
678         src << prm_name << "_mba(";
679 
680         boost::fusion::for_each(term.coord, buildexpr(src, queue, prm_name, state));
681 
682         for(size_t k = 0; k < MBA::ndim; ++k) {
683             src << ", " << prm_name << "_c" << k
684                 << ", " << prm_name << "_h" << k
685                 << ", " << prm_name << "_n" << k
686                 << ", " << prm_name << "_m" << k;
687         }
688 
689         src << ", " << prm_name << "_phi)";
690     }
691 
692     struct buildexpr {
693         backend::source_generator &s;
694         const backend::command_queue &queue;
695         const std::string &prm_name;
696         detail::kernel_generator_state_ptr state;
697         mutable int pos;
698 
buildexprvex::traits::partial_vector_expr::buildexpr699         buildexpr(backend::source_generator &s,
700                 const backend::command_queue &queue, const std::string &prm_name,
701                 detail::kernel_generator_state_ptr state
702             ) : s(s), queue(queue), prm_name(prm_name), state(state), pos(0)
703         {}
704 
705         template <class Expr>
operator ()vex::traits::partial_vector_expr::buildexpr706         void operator()(const Expr &expr) const {
707             if(pos) s << ", ";
708 
709             std::ostringstream prefix;
710             prefix << prm_name << "_x" << pos;
711 
712             detail::vector_expr_context ctx(s, queue, prefix.str(), state);
713             boost::proto::eval(boost::proto::as_child(expr), ctx);
714 
715             pos++;
716         }
717     };
718 };
719 
720 template <class MBA, class ExprTuple>
721 struct kernel_arg_setter< mba_interp<MBA, ExprTuple> > {
setvex::traits::kernel_arg_setter722     static void set(const mba_interp<MBA, ExprTuple> &term,
723             backend::kernel &kernel, unsigned part, size_t index_offset,
724             detail::kernel_generator_state_ptr state)
725     {
726 
727         boost::fusion::for_each(term.coord,
728                 setargs(kernel, part, index_offset, state));
729 
730         for(size_t k = 0; k < MBA::ndim; ++k) {
731             kernel.push_arg(term.cloud.xmin[k]);
732             kernel.push_arg(term.cloud.hinv[k]);
733             kernel.push_arg(term.cloud.n[k]);
734             kernel.push_arg(term.cloud.stride[k]);
735         }
736         kernel.push_arg(term.cloud.phi[part]);
737     }
738 
739     struct setargs {
740         backend::kernel &kernel;
741         unsigned part;
742         size_t index_offset;
743         detail::kernel_generator_state_ptr state;
744 
setargsvex::traits::kernel_arg_setter::setargs745         setargs(
746                 backend::kernel &kernel, unsigned part, size_t index_offset,
747                 detail::kernel_generator_state_ptr state
748                )
749             : kernel(kernel), part(part), index_offset(index_offset), state(state)
750         {}
751 
752         template <class Expr>
operator ()vex::traits::kernel_arg_setter::setargs753         void operator()(const Expr &expr) const {
754             detail::set_expression_argument ctx(kernel, part, index_offset, state);
755             detail::extract_terminals()( boost::proto::as_child(expr), ctx);
756         }
757     };
758 };
759 
760 template <class MBA, class ExprTuple>
761 struct expression_properties< mba_interp<MBA, ExprTuple> > {
getvex::traits::expression_properties762     static void get(const mba_interp<MBA, ExprTuple> &term,
763             std::vector<backend::command_queue> &queue_list,
764             std::vector<size_t> &partition,
765             size_t &size
766             )
767     {
768         boost::fusion::for_each(term.coord, extrprop(queue_list, partition, size));
769     }
770 
771     struct extrprop {
772         std::vector<backend::command_queue> &queue_list;
773         std::vector<size_t> &partition;
774         size_t &size;
775 
extrpropvex::traits::expression_properties::extrprop776         extrprop(std::vector<backend::command_queue> &queue_list,
777             std::vector<size_t> &partition, size_t &size
778             ) : queue_list(queue_list), partition(partition), size(size)
779         {}
780 
781         template <class Expr>
operator ()vex::traits::expression_properties::extrprop782         void operator()(const Expr &expr) const {
783             if (queue_list.empty()) {
784                 detail::get_expression_properties prop;
785                 detail::extract_terminals()(boost::proto::as_child(expr), prop);
786 
787                 queue_list = prop.queue;
788                 partition  = prop.part;
789                 size       = prop.size;
790             }
791         }
792     };
793 };
794 
795 } //namespace traits
796 
797 } // namespace vex
798 
799 
800 #endif
801