1 #ifndef STAN_MATH_OPENCL_PRIM_UB_CONSTRAIN_HPP
2 #define STAN_MATH_OPENCL_PRIM_UB_CONSTRAIN_HPP
3 #ifdef STAN_OPENCL
4 
5 #include <stan/math/opencl/prim/sum.hpp>
6 #include <stan/math/opencl/matrix_cl.hpp>
7 #include <stan/math/opencl/kernel_generator.hpp>
8 #include <stan/math/prim/fun/constants.hpp>
9 
10 namespace stan {
11 namespace math {
12 
13 /**
14  * Return the upper-bounded value for the specified unconstrained
15  * matrix and upper bound.
16  *
17  * <p>The transform is
18  *
19  * <p>\f$f(x) = U - \exp(x)\f$
20  *
21  * <p>where \f$U\f$ is the upper bound.
22  *
23  * @tparam T type of Matrix
24  * @tparam U type of upper bound
25  * @param[in] x free Matrix.
26  * @param[in] ub upper bound
27  * @return matrix constrained to have upper bound
28  */
29 template <typename T, typename U,
30           require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr,
31           require_all_kernel_expressions_t<U>* = nullptr>
ub_constrain(T && x,U && ub)32 inline auto ub_constrain(T&& x, U&& ub) {
33   return make_holder_cl(
34       [](auto& x_, auto& ub_) {
35         return select(ub_ == INFTY, x_, ub_ - exp(x_));
36       },
37       std::forward<T>(x), std::forward<U>(ub));
38 }
39 
40 /**
41  * Return the upper-bounded value for the specified unconstrained
42  * matrix and upper bound.
43  *
44  * <p>The transform is
45  *
46  * <p>\f$f(x) = U - \exp(x)\f$
47  *
48  * <p>where \f$U\f$ is the upper bound.
49  *
50  * @tparam T type of Matrix
51  * @tparam U type of upper bound
52  * @param[in] x free Matrix.
53  * @param[in] ub upper bound
54  * @param[in,out] lp reference to log probability to increment
55  * @return matrix constrained to have upper bound
56  */
57 template <typename T, typename U,
58           require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr,
59           require_all_kernel_expressions_t<U>* = nullptr>
ub_constrain(const T & x,const U & ub,return_type_t<T,U> & lp)60 inline auto ub_constrain(const T& x, const U& ub, return_type_t<T, U>& lp) {
61   matrix_cl<double> lp_inc;
62   matrix_cl<double> res;
63   auto ub_inf = ub == INFTY;
64   auto lp_inc_expr = sum_2d(select(ub_inf, 0.0, x));
65   auto res_expr = select(ub_inf, x, ub - exp(x));
66   results(lp_inc, res) = expressions(lp_inc_expr, res_expr);
67   lp += sum(from_matrix_cl(lp_inc));
68   return res;
69 }
70 
71 }  // namespace math
72 }  // namespace stan
73 #endif
74 #endif
75