1 #ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ROWWISE_REDUCTION_HPP
2 #define STAN_MATH_OPENCL_KERNEL_GENERATOR_ROWWISE_REDUCTION_HPP
3 #ifdef STAN_OPENCL
4
5 #include <stan/math/prim/meta.hpp>
6 #include <stan/math/opencl/matrix_cl_view.hpp>
7 #include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
8 #include <stan/math/opencl/kernel_generator/broadcast.hpp>
9 #include <stan/math/opencl/kernel_generator/binary_operation.hpp>
10 #include <stan/math/opencl/kernel_generator/name_generator.hpp>
11 #include <stan/math/opencl/kernel_generator/operation_cl.hpp>
12 #include <stan/math/opencl/kernel_generator/type_str.hpp>
13 #include <map>
14 #include <string>
15 #include <type_traits>
16 #include <utility>
17
18 namespace stan {
19 namespace math {
20 namespace internal {
21
22 /**
23 * Implementation of an optimization for usage of rowwise reduction in
24 * matrix-vector multiplication.
25 */
26 template <typename Arg>
27 struct matvec_mul_opt {
28 // in general the optimization is not possible
29 enum { is_possible = 0 };
30
viewstan::math::internal::matvec_mul_opt31 static matrix_cl_view view(const Arg&) { return matrix_cl_view::Entire; }
32
get_kernel_partsstan::math::internal::matvec_mul_opt33 static kernel_parts get_kernel_parts(
34 const Arg& a, std::map<const void*, const char*>& generated,
35 std::map<const void*, const char*>& generated_all,
36 name_generator& name_gen, const std::string& row_index_name,
37 const std::string& col_index_name) {
38 return {};
39 }
40 };
41
42 template <typename Mat, typename VecT>
43 struct matvec_mul_opt<elt_multiply_<Mat, broadcast_<VecT, true, false>>> {
44 // if the argument of rowwise reduction is multiplication with a broadcast
45 // vector we can do the optimization
46 enum { is_possible = 1 };
47 using Arg = elt_multiply_<Mat, broadcast_<VecT, true, false>>;
48
49 /**
50 * Return view of the vector.
51 * @param a argument to rowwise reduction (multiplication with second factor
52 * being broadcast vector)
53 * @return view
54 */
viewstan::math::internal::matvec_mul_opt55 static matrix_cl_view view(const Arg& a) {
56 return a.template get_arg<1>().template get_arg<0>().view();
57 }
58
59 /**
60 * Generates kernel code for the argument of rowwise reduction, applying the
61 * optimization - ignoring the triangular view of the vector, as it is already
62 * handeled by rowwise reduction.
63 * @param mul argument of the rowwise reduction
64 * @param[in,out] generated map from (pointer to) already generated local
65 * operations to variable names
66 * @param[in,out] generated_all map from (pointer to) already generated all
67 * operations to variable names
68 * @param name_gen name generator for this kernel
69 * @param row_index_name row index variable name
70 * @param col_index_name column index variable name
71 * @return part of kernel with code for this and nested expressions
72 */
get_kernel_partsstan::math::internal::matvec_mul_opt73 static kernel_parts get_kernel_parts(
74 const Arg& mul, std::map<const void*, const char*>& generated,
75 std::map<const void*, const char*>& generated_all,
76 name_generator& name_gen, const std::string& row_index_name,
77 const std::string& col_index_name) {
78 kernel_parts res{};
79 if (generated.count(&mul) == 0) {
80 mul.var_name_ = name_gen.generate();
81 generated[&mul] = "";
82
83 const auto& matrix = mul.template get_arg<0>();
84 const auto& broadcast = mul.template get_arg<1>();
85 res = matrix.get_kernel_parts(generated, generated_all, name_gen,
86 row_index_name, col_index_name, true);
87 if (generated.count(&broadcast) == 0) {
88 broadcast.var_name_ = name_gen.generate();
89 generated[&broadcast] = "";
90
91 const auto& vec = broadcast.template get_arg<0>();
92 std::string row_index_name_bc = row_index_name;
93 std::string col_index_name_bc = col_index_name;
94 broadcast.modify_argument_indices(row_index_name_bc, col_index_name_bc);
95 res += vec.get_kernel_parts(generated, generated_all, name_gen,
96 row_index_name_bc, col_index_name_bc, true);
97 res += broadcast.generate(row_index_name, col_index_name, true,
98 vec.var_name_);
99 }
100 res += mul.generate(row_index_name, col_index_name, true,
101 matrix.var_name_, broadcast.var_name_);
102 }
103 return res;
104 }
105 };
106
107 } // namespace internal
108
109 /** \addtogroup opencl_kernel_generator
110 * @{
111 */
112
113 /**
114 * Represents a rowwise reduction in kernel generator expressions.
115 * @tparam Derived derived type
116 * @tparam T type of first argument
117 * @tparam operation type with member function generate that accepts two
118 * variable names and returns OpenCL source code for reduction operation_cl
119 * @tparam PassZero whether \c operation passes trough zeros
120 */
121 template <typename Derived, typename T, typename operation, bool PassZero>
122 class rowwise_reduction
123 : public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
124 T> {
125 public:
126 using T_no_ref = std::remove_reference_t<T>;
127 using Scalar = typename T_no_ref::Scalar;
128 using base = operation_cl<Derived, Scalar, T>;
129 using base::var_name_;
130
131 protected:
132 std::string init_;
133
134 public:
135 using base::rows;
136 /**
137 * Constructor
138 * @param a the expression to reduce
139 * @param init OpenCL source code of initialization value for reduction
140 */
rowwise_reduction(T && a,const std::string & init)141 explicit rowwise_reduction(T&& a, const std::string& init)
142 : base(std::forward<T>(a)), init_(init) {}
143
144 /**
145 * Generates kernel code for this and nested expressions.
146 * @param[in,out] generated map from (pointer to) already generated local
147 * operations to variable names
148 * @param[in,out] generated_all map from (pointer to) already generated all
149 * operations to variable names
150 * @param name_gen name generator for this kernel
151 * @param row_index_name row index variable name
152 * @param col_index_name column index variable name
153 * @param view_handled whether caller already handled matrix view
154 * @return part of kernel with code for this and nested expressions
155 */
get_kernel_parts(std::map<const void *,const char * > & generated,std::map<const void *,const char * > & generated_all,name_generator & name_gen,const std::string & row_index_name,const std::string & col_index_name,bool view_handled) const156 inline kernel_parts get_kernel_parts(
157 std::map<const void*, const char*>& generated,
158 std::map<const void*, const char*>& generated_all,
159 name_generator& name_gen, const std::string& row_index_name,
160 const std::string& col_index_name, bool view_handled) const {
161 kernel_parts res{};
162 if (generated.count(this) == 0) {
163 this->var_name_ = name_gen.generate();
164 generated[this] = "";
165
166 std::map<const void*, const char*> generated2;
167 if (PassZero && internal::matvec_mul_opt<T_no_ref>::is_possible) {
168 res = internal::matvec_mul_opt<T_no_ref>::get_kernel_parts(
169 this->template get_arg<0>(), generated2, generated_all, name_gen,
170 row_index_name, var_name_ + "_j");
171 } else {
172 res = this->template get_arg<0>().get_kernel_parts(
173 generated2, generated_all, name_gen, row_index_name,
174 var_name_ + "_j", view_handled || PassZero);
175 }
176 kernel_parts my_part
177 = generate(row_index_name, col_index_name, view_handled,
178 this->template get_arg<0>().var_name_);
179 res += my_part;
180 res.body = res.body_prefix + res.body;
181 res.body_prefix = "";
182 }
183 return res;
184 }
185
186 /**
187 * Generates kernel code for this expression.
188 * @param row_index_name row index variable name
189 * @param col_index_name column index variable name
190 * @param view_handled whether whether caller already handled matrix view
191 * @param var_name_arg name of the variable in kernel that holds argument to
192 * this expression
193 * @return part of kernel with code for this expression
194 */
generate(const std::string & row_index_name,const std::string & col_index_name,const bool view_handled,const std::string & var_name_arg) const195 inline kernel_parts generate(const std::string& row_index_name,
196 const std::string& col_index_name,
197 const bool view_handled,
198 const std::string& var_name_arg) const {
199 kernel_parts res;
200 res.body_prefix
201 = type_str<Scalar>() + " " + var_name_ + " = " + init_ + ";\n";
202 if (PassZero) {
203 res.body_prefix += "int " + var_name_ + "_start = contains_nonzero("
204 + var_name_ + "_view, LOWER) ? 0 : " + row_index_name
205 + ";\n";
206 if (internal::matvec_mul_opt<T_no_ref>::is_possible) {
207 res.body_prefix += "int " + var_name_ + "_end_temp = contains_nonzero("
208 + var_name_ + "_view, UPPER) ? " + var_name_
209 + "_cols : min(" + var_name_ + "_cols, "
210 + row_index_name + " + 1);\n";
211 res.body_prefix += "int " + var_name_ + "_end = contains_nonzero("
212 + var_name_ + "_vec_view, UPPER) ? " + var_name_
213 + "_end_temp : min(1, " + var_name_
214 + "_end_temp);\n";
215 } else {
216 res.body_prefix += "int " + var_name_ + "_end = contains_nonzero("
217 + var_name_ + "_view, UPPER) ? " + var_name_
218 + "_cols : min(" + var_name_ + "_cols, "
219 + row_index_name + " + 1);\n";
220 }
221 res.body_prefix += "for(int " + var_name_ + "_j = " + var_name_
222 + "_start; " + var_name_ + "_j < " + var_name_
223 + "_end; " + var_name_ + "_j++){\n";
224 } else {
225 res.body_prefix += "for(int " + var_name_ + "_j = 0; " + var_name_
226 + "_j < " + var_name_ + "_cols; " + var_name_
227 + "_j++){\n";
228 }
229 res.body += var_name_ + " = " + operation::generate(var_name_, var_name_arg)
230 + ";\n}\n";
231 res.args = "int " + var_name_ + "_view, int " + var_name_ + "_cols, ";
232 if (PassZero && internal::matvec_mul_opt<T_no_ref>::is_possible) {
233 res.args += "int " + var_name_ + "_vec_view, ";
234 }
235 return res;
236 }
237
238 /**
239 * Sets kernel arguments for this and nested expressions.
240 * @param[in,out] generated map from (pointer to) already generated local
241 * operations to variable names
242 * @param[in,out] generated_all map from (pointer to) already generated all
243 * operations to variable names
244 * @param kernel kernel to set arguments on
245 * @param[in,out] arg_num consecutive number of the first argument to set.
246 * This is incremented for each argument set by this function.
247 */
set_args(std::map<const void *,const char * > & generated,std::map<const void *,const char * > & generated_all,cl::Kernel & kernel,int & arg_num) const248 inline void set_args(std::map<const void*, const char*>& generated,
249 std::map<const void*, const char*>& generated_all,
250 cl::Kernel& kernel, int& arg_num) const {
251 if (generated.count(this) == 0) {
252 generated[this] = "";
253 std::map<const void*, const char*> generated2;
254 this->template get_arg<0>().set_args(generated2, generated_all, kernel,
255 arg_num);
256 kernel.setArg(arg_num++, this->template get_arg<0>().view());
257 kernel.setArg(arg_num++, this->template get_arg<0>().cols());
258 if (PassZero && internal::matvec_mul_opt<T>::is_possible) {
259 kernel.setArg(arg_num++, internal::matvec_mul_opt<T_no_ref>::view(
260 this->template get_arg<0>()));
261 }
262 }
263 }
264
265 /**
266 * Number of columns of a matrix that would be the result of evaluating this
267 * expression.
268 * @return number of columns
269 */
cols() const270 inline int cols() const { return 1; }
271
272 /**
273 * Determine indices of extreme sub- and superdiagonals written.
274 * @return pair of indices - bottom and top diagonal
275 */
extreme_diagonals() const276 inline std::pair<int, int> extreme_diagonals() const {
277 return {-rows() + 1, cols() - 1};
278 }
279 };
280
/**
 * Reduction operation that adds values together.
 */
struct sum_op {
  /**
   * Builds the OpenCL expression that adds two values.
   * @param a first operand (the running reduction value in rowwise_reduction)
   * @param b second operand
   * @return code of the form "a + b"
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    std::string code = a;
    code += " + ";
    code += b;
    return code;
  }
};
296
297 /**
298 * Represents rowwise sum reduction in kernel generator expressions.
299 * @tparam T type of expression
300 */
301 template <typename T>
302 class rowwise_sum_
303 : public rowwise_reduction<rowwise_sum_<T>, T, sum_op, true> {
304 using base = rowwise_reduction<rowwise_sum_<T>, T, sum_op, true>;
305 using base::arguments_;
306
307 public:
rowwise_sum_(T && a)308 explicit rowwise_sum_(T&& a) : base(std::forward<T>(a), "0") {}
309
310 /**
311 * Creates a deep copy of this expression.
312 * @return copy of \c *this
313 */
deep_copy() const314 inline auto deep_copy() const {
315 auto&& arg_copy = this->template get_arg<0>().deep_copy();
316 return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
317 std::move(arg_copy));
318 }
319 };
320
321 /**
322 * Rowwise sum reduction of a kernel generator expression.
323 * @tparam T type of input expression
324 * @param a the expression to reduce
325 * @return sum
326 */
327 template <typename T,
328 typename = require_all_kernel_expressions_and_none_scalar_t<T>>
rowwise_sum(T && a)329 inline auto rowwise_sum(T&& a) {
330 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
331 return rowwise_sum_<std::remove_reference_t<decltype(arg_copy)>>(
332 std::move(arg_copy));
333 }
334
/**
 * Reduction operation that multiplies values together.
 */
struct prod_op {
  /**
   * Builds the OpenCL expression that multiplies two values.
   * @param a first operand (the running reduction value in rowwise_reduction)
   * @param b second operand
   * @return code of the form "a * b"
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    const std::string op = " * ";
    return a + op + b;
  }
};
350
351 /**
352 * Represents rowwise product reduction in kernel generator expressions.
353 * @tparam T type of expression
354 */
355 template <typename T>
356 class rowwise_prod_
357 : public rowwise_reduction<rowwise_prod_<T>, T, prod_op, false> {
358 using base = rowwise_reduction<rowwise_prod_<T>, T, prod_op, false>;
359 using base::arguments_;
360
361 public:
rowwise_prod_(T && a)362 explicit rowwise_prod_(T&& a) : base(std::forward<T>(a), "1") {}
363
364 /**
365 * Creates a deep copy of this expression.
366 * @return copy of \c *this
367 */
deep_copy() const368 inline auto deep_copy() const {
369 auto&& arg_copy = this->template get_arg<0>().deep_copy();
370 return rowwise_prod_<std::remove_reference_t<decltype(arg_copy)>>(
371 std::move(arg_copy));
372 }
373 };
374
375 /**
376 * Rowwise product reduction of a kernel generator expression.
377 * @tparam T type of input expression
378 * @param a the expression to reduce
379 * @return prod
380 */
381 template <typename T,
382 typename = require_all_kernel_expressions_and_none_scalar_t<T>>
rowwise_prod(T && a)383 inline auto rowwise_prod(T&& a) {
384 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
385 return rowwise_prod_<std::remove_reference_t<decltype(arg_copy)>>(
386 std::move(arg_copy));
387 }
388
/**
 * Reduction operation computing the maximum.
 * @tparam T scalar type being reduced
 */
template <typename T>
struct max_op {
  /**
   * Builds the OpenCL expression for the maximum of two values: \c fmax for
   * floating point types, \c max otherwise.
   * @param a first operand
   * @param b second operand
   * @return reduction code
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    const std::string fn = std::is_floating_point<T>::value ? "fmax" : "max";
    return fn + "(" + a + ", " + b + ")";
  }

  /**
   * Initial value of the reduction.
   * @return "-INFINITY" for floating point types, "INT_MIN" otherwise
   */
  inline static std::string init() {
    return std::is_floating_point<T>::value ? "-INFINITY" : "INT_MIN";
  }
};
416
417 /**
418 * Represents rowwise max reduction in kernel generator expressions.
419 * @tparam T type of expression
420 */
421 template <typename T>
422 class rowwise_max_
423 : public rowwise_reduction<
424 rowwise_max_<T>, T,
425 max_op<typename std::remove_reference_t<T>::Scalar>, false> {
426 using op = max_op<typename std::remove_reference_t<T>::Scalar>;
427 using base = rowwise_reduction<rowwise_max_<T>, T, op, false>;
428 using base::arguments_;
429
430 public:
rowwise_max_(T && a)431 explicit rowwise_max_(T&& a) : base(std::forward<T>(a), op::init()) {}
432 /**
433 * Creates a deep copy of this expression.
434 * @return copy of \c *this
435 */
deep_copy() const436 inline auto deep_copy() const {
437 auto&& arg_copy = this->template get_arg<0>().deep_copy();
438 return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
439 std::move(arg_copy));
440 }
441 };
442
443 /**
444 * Rowwise max reduction of a kernel generator expression.
445 * @tparam T type of input expression
446 * @param a the expression to reduce
447 * @return max
448 */
449 template <typename T,
450 typename = require_all_kernel_expressions_and_none_scalar_t<T>>
rowwise_max(T && a)451 inline auto rowwise_max(T&& a) {
452 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
453 return rowwise_max_<std::remove_reference_t<decltype(arg_copy)>>(
454 std::move(arg_copy));
455 }
/**
 * Reduction operation computing the minimum.
 * @tparam T scalar type being reduced
 */
template <typename T>
struct min_op {
  /**
   * Builds the OpenCL expression for the minimum of two values: \c fmin for
   * floating point types, \c min otherwise.
   * @param a first operand
   * @param b second operand
   * @return reduction code
   */
  inline static std::string generate(const std::string& a,
                                     const std::string& b) {
    const bool is_fp = std::is_floating_point<T>::value;
    return std::string(is_fp ? "fmin(" : "min(") + a + ", " + b + ")";
  }

  /**
   * Initial value of the reduction.
   * @return "INFINITY" for floating point types, "INT_MAX" otherwise
   */
  inline static std::string init() {
    if (!std::is_floating_point<T>::value) {
      return "INT_MAX";
    }
    return "INFINITY";
  }
};
483
484 /**
485 * Represents rowwise min reduction in kernel generator expressions.
486 * @tparam T type of expression
487 */
488 template <typename T>
489 class rowwise_min_
490 : public rowwise_reduction<
491 rowwise_min_<T>, T,
492 min_op<typename std::remove_reference_t<T>::Scalar>, false> {
493 using op = min_op<typename std::remove_reference_t<T>::Scalar>;
494 using base = rowwise_reduction<rowwise_min_<T>, T, op, false>;
495 using base::arguments_;
496
497 public:
rowwise_min_(T && a)498 explicit rowwise_min_(T&& a) : base(std::forward<T>(a), op::init()) {}
499 /**
500 * Creates a deep copy of this expression.
501 * @return copy of \c *this
502 */
deep_copy() const503 inline auto deep_copy() const {
504 auto&& arg_copy = this->template get_arg<0>().deep_copy();
505 return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
506 std::move(arg_copy));
507 }
508 };
509
510 /**
511 * Min reduction of a kernel generator expression.
512 * @tparam T type of input expression
513 * @param a the expression to reduce
514 * @return min
515 */
516 template <typename T,
517 typename = require_all_kernel_expressions_and_none_scalar_t<T>>
rowwise_min(T && a)518 inline auto rowwise_min(T&& a) {
519 auto&& arg_copy = as_operation_cl(std::forward<T>(a)).deep_copy();
520 return rowwise_min_<std::remove_reference_t<decltype(arg_copy)>>(
521 std::move(arg_copy));
522 }
523 /** @}*/
524 } // namespace math
525 } // namespace stan
526
527 #endif
528 #endif
529