1 #ifndef STAN_MATH_REV_FUN_DOT_PRODUCT_HPP
2 #define STAN_MATH_REV_FUN_DOT_PRODUCT_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/value_of.hpp>
7 #include <stan/math/rev/fun/to_arena.hpp>
8 #include <stan/math/rev/core/arena_matrix.hpp>
9 #include <stan/math/rev/core/reverse_pass_callback.hpp>
10 #include <stan/math/prim/meta.hpp>
11 #include <stan/math/prim/err.hpp>
12 #include <stan/math/prim/fun/Eigen.hpp>
13 #include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
14 #include <stan/math/prim/fun/dot_product.hpp>
15 #include <stan/math/prim/fun/typedefs.hpp>
16 #include <stan/math/prim/fun/value_of.hpp>
17 #include <type_traits>
18 #include <vector>
19 
20 namespace stan {
21 namespace math {
22 
23 /**
24  * Returns the dot product.
25  *
26  * @tparam T1 type of elements in the first vector
27  * @tparam T2 type of elements in the second vector
28  *
29  * @param[in] v1 First vector.
30  * @param[in] v2 Second vector.
31  * @return Dot product of the vectors.
32  * @throw std::domain_error if sizes of v1 and v2 do not match.
33  */
34 template <typename T1, typename T2, require_all_vector_t<T1, T2>* = nullptr,
35           require_not_complex_t<return_type_t<T1, T2>>* = nullptr,
36           require_all_not_std_vector_t<T1, T2>* = nullptr,
37           require_any_st_var<T1, T2>* = nullptr>
dot_product(const T1 & v1,const T2 & v2)38 inline var dot_product(const T1& v1, const T2& v2) {
39   check_matching_sizes("dot_product", "v1", v1, "v2", v2);
40 
41   if (v1.size() == 0) {
42     return 0.0;
43   }
44 
45   if (!is_constant<T1>::value && !is_constant<T2>::value) {
46     arena_t<promote_scalar_t<var, T1>> v1_arena = v1;
47     arena_t<promote_scalar_t<var, T2>> v2_arena = v2;
48     return make_callback_var(
49         v1_arena.val().dot(v2_arena.val()),
50         [v1_arena, v2_arena](const auto& vi) mutable {
51           const auto res_adj = vi.adj();
52           for (Eigen::Index i = 0; i < v1_arena.size(); ++i) {
53             v1_arena.adj().coeffRef(i) += res_adj * v2_arena.val().coeff(i);
54             v2_arena.adj().coeffRef(i) += res_adj * v1_arena.val().coeff(i);
55           }
56         });
57   } else if (!is_constant<T2>::value) {
58     arena_t<promote_scalar_t<var, T2>> v2_arena = v2;
59     arena_t<promote_scalar_t<double, T1>> v1_val_arena = value_of(v1);
60     return make_callback_var(v1_val_arena.dot(v2_arena.val()),
61                              [v1_val_arena, v2_arena](const auto& vi) mutable {
62                                v2_arena.adj().array()
63                                    += vi.adj() * v1_val_arena.array();
64                              });
65   } else {
66     arena_t<promote_scalar_t<var, T1>> v1_arena = v1;
67     arena_t<promote_scalar_t<double, T2>> v2_val_arena = value_of(v2);
68     return make_callback_var(v1_arena.val().dot(v2_val_arena),
69                              [v1_arena, v2_val_arena](const auto& vi) mutable {
70                                v1_arena.adj().array()
71                                    += vi.adj() * v2_val_arena.array();
72                              });
73   }
74 }
75 
76 }  // namespace math
77 }  // namespace stan
78 #endif
79