1 #ifndef STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
2 #define STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/adjoint_of.hpp>
7 #include <stan/math/rev/fun/value_of.hpp>
8 #include <stan/math/prim/err.hpp>
9 #include <stan/math/prim/fun/Eigen.hpp>
10 #include <stan/math/prim/fun/exp.hpp>
11 #include <stan/math/prim/fun/squared_distance.hpp>
12 #include <cmath>
13 #include <type_traits>
14 #include <vector>
15 
16 namespace stan {
17 namespace math {
18 
19 /**
20  * Returns a squared exponential kernel.
21  *
22  * @tparam T_x type of elements in the vector
23  * @param x std::vector input that can be used in square distance
24  *    Assumes each element of x is the same size
25  * @param sigma standard deviation
26  * @param length_scale length scale
27  * @return squared distance
28  * @throw std::domain_error if sigma <= 0, l <= 0, or
29  *   x is nan or infinite
30  */
31 template <typename T_x, typename T_sigma, require_st_arithmetic<T_x>* = nullptr,
32           require_stan_scalar_t<T_sigma>* = nullptr>
gp_exp_quad_cov(const std::vector<T_x> & x,const T_sigma sigma,const var length_scale)33 inline Eigen::Matrix<var, -1, -1> gp_exp_quad_cov(const std::vector<T_x>& x,
34                                                   const T_sigma sigma,
35                                                   const var length_scale) {
36   check_positive("gp_exp_quad_cov", "sigma", sigma);
37   check_positive("gp_exp_quad_cov", "length_scale", length_scale);
38   size_t x_size = x.size();
39   for (size_t i = 0; i < x_size; ++i) {
40     check_not_nan("gp_exp_quad_cov", "x", x[i]);
41   }
42 
43   Eigen::Matrix<var, -1, -1> cov(x_size, x_size);
44   if (x_size == 0) {
45     return cov;
46   }
47   size_t l_tri_size = x_size * (x_size - 1) / 2;
48   arena_matrix<Eigen::VectorXd> sq_dists_lin(l_tri_size);
49   arena_matrix<Eigen::Matrix<var, -1, 1>> cov_l_tri_lin(l_tri_size);
50   arena_matrix<Eigen::Matrix<var, -1, 1>> cov_diag(
51       is_constant<T_sigma>::value ? 0 : x_size);
52 
53   double l_val = value_of(length_scale);
54   double sigma_sq = square(value_of(sigma));
55   double neg_half_inv_l_sq = -0.5 / square(l_val);
56 
57   size_t block_size = 10;
58   size_t pos = 0;
59   for (size_t jb = 0; jb < x_size; jb += block_size) {
60     size_t j_end = std::min(x_size, jb + block_size);
61     size_t j_size = j_end - jb;
62     cov.diagonal().segment(jb, j_size)
63         = Eigen::VectorXd::Constant(j_size, sigma_sq);
64     if (!is_constant<T_sigma>::value) {
65       cov_diag.segment(jb, j_size) = cov.diagonal().segment(jb, j_size);
66     }
67     for (size_t ib = jb; ib < x_size; ib += block_size) {
68       size_t i_end = std::min(x_size, ib + block_size);
69       for (size_t j = jb; j < j_end; ++j) {
70         for (size_t i = std::max(ib, j + 1); i < i_end; ++i) {
71           sq_dists_lin.coeffRef(pos) = squared_distance(x[i], x[j]);
72           cov_l_tri_lin.coeffRef(pos) = cov.coeffRef(j, i) = cov.coeffRef(i, j)
73               = sigma_sq * exp(sq_dists_lin.coeff(pos) * neg_half_inv_l_sq);
74           pos++;
75         }
76       }
77     }
78   }
79 
80   reverse_pass_callback(
81       [cov_l_tri_lin, cov_diag, sq_dists_lin, sigma, length_scale, x_size]() {
82         size_t l_tri_size = x_size * (x_size - 1) / 2;
83         double adjl = 0;
84         double adjsigma = 0;
85         for (Eigen::Index pos = 0; pos < l_tri_size; pos++) {
86           double prod_add
87               = cov_l_tri_lin.coeff(pos).val() * cov_l_tri_lin.coeff(pos).adj();
88           adjl += prod_add * sq_dists_lin.coeff(pos);
89           if (!is_constant<T_sigma>::value) {
90             adjsigma += prod_add;
91           }
92         }
93         if (!is_constant<T_sigma>::value) {
94           adjsigma += (cov_diag.val().array() * cov_diag.adj().array()).sum();
95           adjoint_of(sigma) += adjsigma * 2 / value_of(sigma);
96         }
97         double l_val = value_of(length_scale);
98         length_scale.adj() += adjl / (l_val * l_val * l_val);
99       });
100 
101   return cov;
102 }
103 
104 }  // namespace math
105 }  // namespace stan
106 #endif
107