1 #ifndef STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
2 #define STAN_MATH_REV_FUN_GP_EXP_QUAD_COV_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/adjoint_of.hpp>
7 #include <stan/math/rev/fun/value_of.hpp>
8 #include <stan/math/prim/err.hpp>
9 #include <stan/math/prim/fun/Eigen.hpp>
10 #include <stan/math/prim/fun/exp.hpp>
11 #include <stan/math/prim/fun/squared_distance.hpp>
12 #include <cmath>
13 #include <type_traits>
14 #include <vector>
15
16 namespace stan {
17 namespace math {
18
19 /**
20 * Returns a squared exponential kernel.
21 *
22 * @tparam T_x type of elements in the vector
23 * @param x std::vector input that can be used in square distance
24 * Assumes each element of x is the same size
25 * @param sigma standard deviation
26 * @param length_scale length scale
27 * @return squared distance
28 * @throw std::domain_error if sigma <= 0, l <= 0, or
29 * x is nan or infinite
30 */
31 template <typename T_x, typename T_sigma, require_st_arithmetic<T_x>* = nullptr,
32 require_stan_scalar_t<T_sigma>* = nullptr>
gp_exp_quad_cov(const std::vector<T_x> & x,const T_sigma sigma,const var length_scale)33 inline Eigen::Matrix<var, -1, -1> gp_exp_quad_cov(const std::vector<T_x>& x,
34 const T_sigma sigma,
35 const var length_scale) {
36 check_positive("gp_exp_quad_cov", "sigma", sigma);
37 check_positive("gp_exp_quad_cov", "length_scale", length_scale);
38 size_t x_size = x.size();
39 for (size_t i = 0; i < x_size; ++i) {
40 check_not_nan("gp_exp_quad_cov", "x", x[i]);
41 }
42
43 Eigen::Matrix<var, -1, -1> cov(x_size, x_size);
44 if (x_size == 0) {
45 return cov;
46 }
47 size_t l_tri_size = x_size * (x_size - 1) / 2;
48 arena_matrix<Eigen::VectorXd> sq_dists_lin(l_tri_size);
49 arena_matrix<Eigen::Matrix<var, -1, 1>> cov_l_tri_lin(l_tri_size);
50 arena_matrix<Eigen::Matrix<var, -1, 1>> cov_diag(
51 is_constant<T_sigma>::value ? 0 : x_size);
52
53 double l_val = value_of(length_scale);
54 double sigma_sq = square(value_of(sigma));
55 double neg_half_inv_l_sq = -0.5 / square(l_val);
56
57 size_t block_size = 10;
58 size_t pos = 0;
59 for (size_t jb = 0; jb < x_size; jb += block_size) {
60 size_t j_end = std::min(x_size, jb + block_size);
61 size_t j_size = j_end - jb;
62 cov.diagonal().segment(jb, j_size)
63 = Eigen::VectorXd::Constant(j_size, sigma_sq);
64 if (!is_constant<T_sigma>::value) {
65 cov_diag.segment(jb, j_size) = cov.diagonal().segment(jb, j_size);
66 }
67 for (size_t ib = jb; ib < x_size; ib += block_size) {
68 size_t i_end = std::min(x_size, ib + block_size);
69 for (size_t j = jb; j < j_end; ++j) {
70 for (size_t i = std::max(ib, j + 1); i < i_end; ++i) {
71 sq_dists_lin.coeffRef(pos) = squared_distance(x[i], x[j]);
72 cov_l_tri_lin.coeffRef(pos) = cov.coeffRef(j, i) = cov.coeffRef(i, j)
73 = sigma_sq * exp(sq_dists_lin.coeff(pos) * neg_half_inv_l_sq);
74 pos++;
75 }
76 }
77 }
78 }
79
80 reverse_pass_callback(
81 [cov_l_tri_lin, cov_diag, sq_dists_lin, sigma, length_scale, x_size]() {
82 size_t l_tri_size = x_size * (x_size - 1) / 2;
83 double adjl = 0;
84 double adjsigma = 0;
85 for (Eigen::Index pos = 0; pos < l_tri_size; pos++) {
86 double prod_add
87 = cov_l_tri_lin.coeff(pos).val() * cov_l_tri_lin.coeff(pos).adj();
88 adjl += prod_add * sq_dists_lin.coeff(pos);
89 if (!is_constant<T_sigma>::value) {
90 adjsigma += prod_add;
91 }
92 }
93 if (!is_constant<T_sigma>::value) {
94 adjsigma += (cov_diag.val().array() * cov_diag.adj().array()).sum();
95 adjoint_of(sigma) += adjsigma * 2 / value_of(sigma);
96 }
97 double l_val = value_of(length_scale);
98 length_scale.adj() += adjl / (l_val * l_val * l_val);
99 });
100
101 return cov;
102 }
103
104 } // namespace math
105 } // namespace stan
106 #endif
107