1 #ifndef STAN_MATH_PRIM_FUN_SQUARED_DISTANCE_HPP
2 #define STAN_MATH_PRIM_FUN_SQUARED_DISTANCE_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
7 #include <stan/math/prim/fun/Eigen.hpp>
8 #include <stan/math/prim/fun/square.hpp>
9
10 namespace stan {
11 namespace math {
12
13 /**
14 * Returns the squared distance.
15 *
16 * @tparam Scal1 Type of the first scalar.
17 * @tparam Scal2 Type of the second scalar.
18 * @param x1 First scalar.
19 * @param x2 Second scalar.
20 * @return Squared distance between scalars
21 * @throw std::domain_error Any scalar is not finite.
22 */
23 template <typename Scal1, typename Scal2,
24 require_all_stan_scalar_t<Scal1, Scal2>* = nullptr,
25 require_all_not_var_t<Scal1, Scal2>* = nullptr>
squared_distance(const Scal1 & x1,const Scal2 & x2)26 inline return_type_t<Scal1, Scal2> squared_distance(const Scal1& x1,
27 const Scal2& x2) {
28 check_finite("squared_distance", "x1", x1);
29 check_finite("squared_distance", "x2", x2);
30 return square(x1 - x2);
31 }
32
33 /**
34 * Returns the squared distance between the specified vectors
35 * of the same dimensions.
36 *
37 * @tparam EigVec1 type of the first vector (must be derived from \c
38 * Eigen::MatrixBase and have one compile time dimension equal to 1)
39 * @tparam EigVec2 type of the second vector (must be derived from \c
40 * Eigen::MatrixBase and have one compile time dimension equal to 1)
41 * @param v1 First vector.
42 * @param v2 Second vector.
43 * @return Square of distance between vectors.
44 * @throw std::domain_error If the vectors are not the same
45 * size.
46 */
47 template <typename EigVec1, typename EigVec2,
48 require_all_eigen_vector_t<EigVec1, EigVec2>* = nullptr,
49 require_all_not_eigen_vt<is_var, EigVec1, EigVec2>* = nullptr>
squared_distance(const EigVec1 & v1,const EigVec2 & v2)50 inline return_type_t<EigVec1, EigVec2> squared_distance(const EigVec1& v1,
51 const EigVec2& v2) {
52 check_matching_sizes("squared_distance", "v1", v1, "v2", v2);
53 return (as_column_vector_or_scalar(v1) - as_column_vector_or_scalar(v2))
54 .squaredNorm();
55 }
56
57 } // namespace math
58 } // namespace stan
59
60 #endif
61