1 #ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ROWWISE_REDUCTION_HPP
2 #define STAN_MATH_OPENCL_KERNEL_GENERATOR_ROWWISE_REDUCTION_HPP
3 #ifdef STAN_OPENCL
4 
5 #include <stan/math/prim/meta.hpp>
6 #include <stan/math/opencl/matrix_cl_view.hpp>
7 #include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
8 #include <stan/math/opencl/kernel_generator/broadcast.hpp>
9 #include <stan/math/opencl/kernel_generator/binary_operation.hpp>
10 #include <stan/math/opencl/kernel_generator/name_generator.hpp>
11 #include <stan/math/opencl/kernel_generator/operation_cl.hpp>
12 #include <stan/math/opencl/kernel_generator/type_str.hpp>
13 #include <map>
14 #include <string>
15 #include <type_traits>
16 #include <utility>
17 
18 namespace stan {
19 namespace math {
20 namespace internal {
21 
22 /**
23  * Implementation of an optimization for usage of rowwise reduction in
24  * matrix-vector multiplication.
25  */
26 template <typename Arg>
27 struct matvec_mul_opt {
28   // in general the optimization is not possible
29   enum { is_possible = 0 };
30 
viewstan::math::internal::matvec_mul_opt31   static matrix_cl_view view(const Arg&) { return matrix_cl_view::Entire; }
32 
get_kernel_partsstan::math::internal::matvec_mul_opt33   static kernel_parts get_kernel_parts(
34       const Arg& a, std::map<const void*, const char*>& generated,
35       std::map<const void*, const char*>& generated_all,
36       name_generator& name_gen, const std::string& row_index_name,
37       const std::string& col_index_name) {
38     return {};
39   }
40 };
41 
42 template <typename Mat, typename VecT>
43 struct matvec_mul_opt<elt_multiply_<Mat, broadcast_<VecT, true, false>>> {
44   // if the argument of rowwise reduction is multiplication with a broadcast
45   // vector we can do the optimization
46   enum { is_possible = 1 };
47   using Arg = elt_multiply_<Mat, broadcast_<VecT, true, false>>;
48 
49   /**
50    * Return view of the vector.
51    * @param a argument to rowwise reduction (multiplication with second factor
52    * being broadcast vector)
53    * @return view
54    */
viewstan::math::internal::matvec_mul_opt55   static matrix_cl_view view(const Arg& a) {
56     return a.template get_arg<1>().template get_arg<0>().view();
57   }
58 
59   /**
60    * Generates kernel code for the argument of rowwise reduction, applying the
61    * optimization - ignoring the triangular view of the vector, as it is already
62    * handeled by rowwise reduction.
63    * @param mul argument of the rowwise reduction
64    * @param[in,out] generated map from (pointer to) already generated local
65    * operations to variable names
66    * @param[in,out] generated_all map from (pointer to) already generated all
67    * operations to variable names
68    * @param name_gen name generator for this kernel
69    * @param row_index_name row index variable name
70    * @param col_index_name column index variable name
71    * @return part of kernel with code for this and nested expressions
72    */
get_kernel_partsstan::math::internal::matvec_mul_opt73   static kernel_parts get_kernel_parts(
74       const Arg& mul, std::map<const void*, const char*>& generated,
75       std::map<const void*, const char*>& generated_all,
76       name_generator& name_gen, const std::string& row_index_name,
77       const std::string& col_index_name) {
78     kernel_parts res{};
79     if (generated.count(&mul) == 0) {
80       mul.var_name_ = name_gen.generate();
81       generated[&mul] = "";
82 
83       const auto& matrix = mul.template get_arg<0>();
84       const auto& broadcast = mul.template get_arg<1>();
85       res = matrix.get_kernel_parts(generated, generated_all, name_gen,
86                                     row_index_name, col_index_name, true);
87       if (generated.count(&broadcast) == 0) {
88         broadcast.var_name_ = name_gen.generate();
89         generated[&broadcast] = "";
90 
91         const auto& vec = broadcast.template get_arg<0>();
92         std::string row_index_name_bc = row_index_name;
93         std::string col_index_name_bc = col_index_name;
94         broadcast.modify_argument_indices(row_index_name_bc, col_index_name_bc);
95         res += vec.get_kernel_parts(generated, generated_all, name_gen,
96                                     row_index_name_bc, col_index_name_bc, true);
97         res += broadcast.generate(row_index_name, col_index_name, true,
98                                   vec.var_name_);
99       }
100       res += mul.generate(row_index_name, col_index_name, true,
101                           matrix.var_name_, broadcast.var_name_);
102     }
103     return res;
104   }
105 };
106 
107 }  // namespace internal
108 
109 /** \addtogroup opencl_kernel_generator
110  *  @{
111  */
112 
113 /**
114  * Represents a rowwise reduction in kernel generator expressions.
115  * @tparam Derived derived type
116  * @tparam T type of first argument
117  * @tparam operation type with member function generate that accepts two
118  * variable names and returns OpenCL source code for reduction operation_cl
119  * @tparam PassZero whether \c operation passes trough zeros
120  */
121 template <typename Derived, typename T, typename operation, bool PassZero>
122 class rowwise_reduction
123     : public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
124                           T> {
125  public:
126   using T_no_ref = std::remove_reference_t<T>;
127   using Scalar = typename T_no_ref::Scalar;
128   using base = operation_cl<Derived, Scalar, T>;
129   using base::var_name_;
130 
131  protected:
132   std::string init_;
133 
134  public:
135   using base::rows;
136   /**
137    * Constructor
138    * @param a the expression to reduce
139    * @param init OpenCL source code of initialization value for reduction
140    */
rowwise_reduction(T && a,const std::string & init)141   explicit rowwise_reduction(T&& a, const std::string& init)
142       : base(std::forward<T>(a)), init_(init) {}
143 
144   /**
145    * Generates kernel code for this and nested expressions.
146    * @param[in,out] generated map from (pointer to) already generated local
147    * operations to variable names
148    * @param[in,out] generated_all map from (pointer to) already generated all
149    * operations to variable names
150    * @param name_gen name generator for this kernel
151    * @param row_index_name row index variable name
152    * @param col_index_name column index variable name
153    * @param view_handled whether caller already handled matrix view
154    * @return part of kernel with code for this and nested expressions
155    */
get_kernel_parts(std::map<const void *,const char * > & generated,std::map<const void *,const char * > & generated_all,name_generator & name_gen,const std::string & row_index_name,const std::string & col_index_name,bool view_handled) const156   inline kernel_parts get_kernel_parts(
157       std::map<const void*, const char*>& generated,
158       std::map<const void*, const char*>& generated_all,
159       name_generator& name_gen, const std::string& row_index_name,
160       const std::string& col_index_name, bool view_handled) const {
161     kernel_parts res{};
162     if (generated.count(this) == 0) {
163       this->var_name_ = name_gen.generate();
164       generated[this] = "";
165 
166       std::map<const void*, const char*> generated2;
167       if (PassZero && internal::matvec_mul_opt<T_no_ref>::is_possible) {
168         res = internal::matvec_mul_opt<T_no_ref>::get_kernel_parts(
169             this->template get_arg<0>(), generated2, generated_all, name_gen,
170             row_index_name, var_name_ + "_j");
171       } else {
172         res = this->template get_arg<0>().get_kernel_parts(
173             generated2, generated_all, name_gen, row_index_name,
174             var_name_ + "_j", view_handled || PassZero);
175       }
176       kernel_parts my_part
177           = generate(row_index_name, col_index_name, view_handled,
178                      this->template get_arg<0>().var_name_);
179       res += my_part;
180       res.body = res.body_prefix + res.body;
181       res.body_prefix = "";
182     }
183     return res;
184   }
185 
186   /**
187    * Generates kernel code for this expression.
188    * @param row_index_name row index variable name
189    * @param col_index_name column index variable name
190    * @param view_handled whether whether caller already handled matrix view
191    * @param var_name_arg name of the variable in kernel that holds argument to
192    * this expression
193    * @return part of kernel with code for this expression
194    */
generate(const std::string & row_index_name,const std::string & col_index_name,const bool view_handled,const std::string & var_name_arg) const195   inline kernel_parts generate(const std::string& row_index_name,
196                                const std::string& col_index_name,
197                                const bool view_handled,
198                                const std::string& var_name_arg) const {
199     kernel_parts res;
200     res.body_prefix
201         = type_str<Scalar>() + " " + var_name_ + " = " + init_ + ";\n";
202     if (PassZero) {
203       res.body_prefix += "int " + var_name_ + "_start = contains_nonzero("
204                          + var_name_ + "_view, LOWER) ? 0 : " + row_index_name
205                          + ";\n";
206       if (internal::matvec_mul_opt<T_no_ref>::is_possible) {
207         res.body_prefix += "int " + var_name_ + "_end_temp = contains_nonzero("
208                            + var_name_ + "_view, UPPER) ? " + var_name_
209                            + "_cols : min(" + var_name_ + "_cols, "
210                            + row_index_name + " + 1);\n";
211         res.body_prefix += "int " + var_name_ + "_end = contains_nonzero("
212                            + var_name_ + "_vec_view, UPPER) ? " + var_name_
213                            + "_end_temp : min(1, " + var_name_
214                            + "_end_temp);\n";
215       } else {
216         res.body_prefix += "int " + var_name_ + "_end = contains_nonzero("
217                            + var_name_ + "_view, UPPER) ? " + var_name_
218                            + "_cols : min(" + var_name_ + "_cols, "
219                            + row_index_name + " + 1);\n";
220       }
221       res.body_prefix += "for(int " + var_name_ + "_j = " + var_name_
222                          + "_start; " + var_name_ + "_j < " + var_name_
223                          + "_end; " + var_name_ + "_j++){\n";
224     } else {
225       res.body_prefix += "for(int " + var_name_ + "_j = 0; " + var_name_
226                          + "_j < " + var_name_ + "_cols; " + var_name_
227                          + "_j++){\n";
228     }
229     res.body += var_name_ + " = " + operation::generate(var_name_, var_name_arg)
230                 + ";\n}\n";
231     res.args = "int " + var_name_ + "_view, int " + var_name_ + "_cols, ";
232     if (PassZero && internal::matvec_mul_opt<T_no_ref>::is_possible) {
233       res.args += "int " + var_name_ + "_vec_view, ";
234     }
235     return res;
236   }
237 
238   /**
239    * Sets kernel arguments for this and nested expressions.
240    * @param[in,out] generated map from (pointer to) already generated local
241    * operations to variable names
242    * @param[in,out] generated_all map from (pointer to) already generated all
243    * operations to variable names
244    * @param kernel kernel to set arguments on
245    * @param[in,out] arg_num consecutive number of the first argument to set.
246    * This is incremented for each argument set by this function.
247    */
set_args(std::map<const void *,const char * > & generated,std::map<const void *,const char * > & generated_all,cl::Kernel & kernel,int & arg_num) const248   inline void set_args(std::map<const void*, const char*>& generated,
249                        std::map<const void*, const char*>& generated_all,
250                        cl::Kernel& kernel, int& arg_num) const {
251     if (generated.count(this) == 0) {
252       generated[this] = "";
253       std::map<const void*, const char*> generated2;
254       this->template get_arg<0>().set_args(generated2, generated_all, kernel,
255                                            arg_num);
256       kernel.setArg(arg_num++, this->template get_arg<0>().view());
257       kernel.setArg(arg_num++, this->template get_arg<0>().cols());
258       if (PassZero && internal::matvec_mul_opt<T>::is_possible) {
259         kernel.setArg(arg_num++, internal::matvec_mul_opt<T_no_ref>::view(
260                                      this->template get_arg<0>()));
261       }
262     }
263   }
264 
265   /**
266    * Number of columns of a matrix that would be the result of evaluating this
267    * expression.
268    * @return number of columns
269    */
cols() const270   inline int cols() const { return 1; }
271 
272   /**
273    * Determine indices of extreme sub- and superdiagonals written.
274    * @return pair of indices - bottom and top diagonal
275    */
extreme_diagonals() const276   inline std::pair<int, int> extreme_diagonals() const {
277     return {-rows() + 1, cols() - 1};
278   }
279 };
280 
/**
 * Reduction operation: sum.
 */
struct sum_op {
  /**
   * Builds the OpenCL expression that adds two variables.
   * @param a first variable
   * @param b second variable
   * @return reduction code of the form <code>a + b</code>
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    std::string code = a;
    code += " + ";
    code += b;
    return code;
  }
};
296 
297 /**
298  * Represents rowwise sum reduction in kernel generator expressions.
299  * @tparam T type of expression
300  */
301 template <typename T>
302 class rowwise_sum_
303     : public rowwise_reduction<rowwise_sum_<T>, T, sum_op, true> {
304   using base = rowwise_reduction<rowwise_sum_<T>, T, sum_op, true>;
305   using base::arguments_;
306 
307  public:
rowwise_sum_(T && a)308   explicit rowwise_sum_(T&& a) : base(std::forward<T>(a), "0") {}
309 
310   /**
311    * Creates a deep copy of this expression.
312    * @return copy of \c *this
313    */
deep_copy() const314   inline auto deep_copy() const {
315     auto&& arg_copy = this->template get_arg<0>().deep_copy();
316     return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
317         std::move(arg_copy));
318   }
319 };
320 
321 /**
322  * Rowwise sum reduction of a kernel generator expression.
323  * @tparam T type of input expression
324  * @param a the expression to reduce
325  * @return sum
326  */
327 template <typename T,
328           typename = require_all_kernel_expressions_and_none_scalar_t<T>>
rowwise_sum(T && a)329 inline auto rowwise_sum(T&& a) {
330   auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
331   return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
332       std::move(arg_copy));
333 }
334 
/**
 * Reduction operation: product.
 */
struct prod_op {
  /**
   * Builds the OpenCL expression that multiplies two variables.
   * @param a first variable
   * @param b second variable
   * @return reduction code of the form <code>a * b</code>
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    std::string code = a;
    code += " * ";
    code += b;
    return code;
  }
};
350 
351 /**
352  * Represents rowwise product reduction in kernel generator expressions.
353  * @tparam T type of expression
354  */
355 template <typename T>
356 class rowwise_prod_
357     : public rowwise_reduction<rowwise_prod_<T>, T, prod_op, false> {
358   using base = rowwise_reduction<rowwise_prod_<T>, T, prod_op, false>;
359   using base::arguments_;
360 
361  public:
rowwise_prod_(T && a)362   explicit rowwise_prod_(T&& a) : base(std::forward<T>(a), "1") {}
363 
364   /**
365    * Creates a deep copy of this expression.
366    * @return copy of \c *this
367    */
deep_copy() const368   inline auto deep_copy() const {
369     auto&& arg_copy = this->template get_arg<0>().deep_copy();
370     return rowwise_prod_<std::remove_reference_t<decltype(arg_copy)>>(
371         std::move(arg_copy));
372   }
373 };
374 
375 /**
376  * Rowwise product reduction of a kernel generator expression.
377  * @tparam T type of input expression
378  * @param a the expression to reduce
379  * @return prod
380  */
381 template <typename T,
382           typename = require_all_kernel_expressions_and_none_scalar_t<T>>
rowwise_prod(T && a)383 inline auto rowwise_prod(T&& a) {
384   auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
385   return rowwise_prod_<std::remove_reference_t<decltype(arg_copy)>>(
386       std::move(arg_copy));
387 }
388 
/**
 * Reduction operation: maximum.
 * @tparam T type to reduce
 */
template <typename T>
struct max_op {
  /**
   * Builds the OpenCL expression for the maximum of two variables:
   * \c fmax for floating point \c T, \c max otherwise.
   * @param a first variable
   * @param b second variable
   * @return reduction code
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    const char* call_open
        = std::is_floating_point<T>::value ? "fmax(" : "max(";
    return call_open + a + ", " + b + ")";
  }

  /**
   * OpenCL source code of the initial accumulator value.
   * @return "-INFINITY" for floating point \c T, "INT_MIN" otherwise
   */
  inline static std::string init() {
    return std::is_floating_point<T>::value ? "-INFINITY" : "INT_MIN";
  }
};
416 
417 /**
418  * Represents rowwise max reduction in kernel generator expressions.
419  * @tparam T type of expression
420  */
421 template <typename T>
422 class rowwise_max_
423     : public rowwise_reduction<
424           rowwise_max_<T>, T,
425           max_op<typename std::remove_reference_t<T>::Scalar>, false> {
426   using op = max_op<typename std::remove_reference_t<T>::Scalar>;
427   using base = rowwise_reduction<rowwise_max_<T>, T, op, false>;
428   using base::arguments_;
429 
430  public:
rowwise_max_(T && a)431   explicit rowwise_max_(T&& a) : base(std::forward<T>(a), op::init()) {}
432   /**
433    * Creates a deep copy of this expression.
434    * @return copy of \c *this
435    */
deep_copy() const436   inline auto deep_copy() const {
437     auto&& arg_copy = this->template get_arg<0>().deep_copy();
438     return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
439         std::move(arg_copy));
440   }
441 };
442 
443 /**
444  * Rowwise max reduction of a kernel generator expression.
445  * @tparam T type of input expression
446  * @param a the expression to reduce
447  * @return max
448  */
449 template <typename T,
450           typename = require_all_kernel_expressions_and_none_scalar_t<T>>
rowwise_max(T && a)451 inline auto rowwise_max(T&& a) {
452   auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
453   return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
454       std::move(arg_copy));
455 }
/**
 * Reduction operation: minimum.
 * @tparam T type to reduce
 */
template <typename T>
struct min_op {
  /**
   * Builds the OpenCL expression for the minimum of two variables:
   * \c fmin for floating point \c T, \c min otherwise.
   * @param a first variable
   * @param b second variable
   * @return reduction code
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    const char* call_open
        = std::is_floating_point<T>::value ? "fmin(" : "min(";
    return call_open + a + ", " + b + ")";
  }

  /**
   * OpenCL source code of the initial accumulator value.
   * @return "INFINITY" for floating point \c T, "INT_MAX" otherwise
   */
  inline static std::string init() {
    return std::is_floating_point<T>::value ? "INFINITY" : "INT_MAX";
  }
};
483 
484 /**
485  * Represents rowwise min reduction in kernel generator expressions.
486  * @tparam T type of expression
487  */
488 template <typename T>
489 class rowwise_min_
490     : public rowwise_reduction<
491           rowwise_min_<T>, T,
492           min_op<typename std::remove_reference_t<T>::Scalar>, false> {
493   using op = min_op<typename std::remove_reference_t<T>::Scalar>;
494   using base = rowwise_reduction<rowwise_min_<T>, T, op, false>;
495   using base::arguments_;
496 
497  public:
rowwise_min_(T && a)498   explicit rowwise_min_(T&& a) : base(std::forward<T>(a), op::init()) {}
499   /**
500    * Creates a deep copy of this expression.
501    * @return copy of \c *this
502    */
deep_copy() const503   inline auto deep_copy() const {
504     auto&& arg_copy = this->template get_arg<0>().deep_copy();
505     return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
506         std::move(arg_copy));
507   }
508 };
509 
510 /**
511  * Min reduction of a kernel generator expression.
512  * @tparam T type of input expression
513  * @param a the expression to reduce
514  * @return min
515  */
516 template <typename T,
517           typename = require_all_kernel_expressions_and_none_scalar_t<T>>
rowwise_min(T && a)518 inline auto rowwise_min(T&& a) {
519   auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
520   return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
521       std::move(arg_copy));
522 }
523 /** @}*/
524 }  // namespace math
525 }  // namespace stan
526 
527 #endif
528 #endif
529