1 #ifndef STAN_MATH_OPENCL_PRIM_UB_CONSTRAIN_HPP
2 #define STAN_MATH_OPENCL_PRIM_UB_CONSTRAIN_HPP
3 #ifdef STAN_OPENCL
4
5 #include <stan/math/opencl/prim/sum.hpp>
6 #include <stan/math/opencl/matrix_cl.hpp>
7 #include <stan/math/opencl/kernel_generator.hpp>
8 #include <stan/math/prim/fun/constants.hpp>
9
10 namespace stan {
11 namespace math {
12
13 /**
14 * Return the upper-bounded value for the specified unconstrained
15 * matrix and upper bound.
16 *
17 * <p>The transform is
18 *
19 * <p>\f$f(x) = U - \exp(x)\f$
20 *
21 * <p>where \f$U\f$ is the upper bound.
22 *
23 * @tparam T type of Matrix
24 * @tparam U type of upper bound
25 * @param[in] x free Matrix.
26 * @param[in] ub upper bound
27 * @return matrix constrained to have upper bound
28 */
29 template <typename T, typename U,
30 require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr,
31 require_all_kernel_expressions_t<U>* = nullptr>
ub_constrain(T && x,U && ub)32 inline auto ub_constrain(T&& x, U&& ub) {
33 return make_holder_cl(
34 [](auto& x_, auto& ub_) {
35 return select(ub_ == INFTY, x_, ub_ - exp(x_));
36 },
37 std::forward<T>(x), std::forward<U>(ub));
38 }
39
40 /**
41 * Return the upper-bounded value for the specified unconstrained
42 * matrix and upper bound.
43 *
44 * <p>The transform is
45 *
46 * <p>\f$f(x) = U - \exp(x)\f$
47 *
48 * <p>where \f$U\f$ is the upper bound.
49 *
50 * @tparam T type of Matrix
51 * @tparam U type of upper bound
52 * @param[in] x free Matrix.
53 * @param[in] ub upper bound
54 * @param[in,out] lp reference to log probability to increment
55 * @return matrix constrained to have upper bound
56 */
57 template <typename T, typename U,
58 require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr,
59 require_all_kernel_expressions_t<U>* = nullptr>
ub_constrain(const T & x,const U & ub,return_type_t<T,U> & lp)60 inline auto ub_constrain(const T& x, const U& ub, return_type_t<T, U>& lp) {
61 matrix_cl<double> lp_inc;
62 matrix_cl<double> res;
63 auto ub_inf = ub == INFTY;
64 auto lp_inc_expr = sum_2d(select(ub_inf, 0.0, x));
65 auto res_expr = select(ub_inf, x, ub - exp(x));
66 results(lp_inc, res) = expressions(lp_inc_expr, res_expr);
67 lp += sum(from_matrix_cl(lp_inc));
68 return res;
69 }
70
71 } // namespace math
72 } // namespace stan
73 #endif
74 #endif
75