1 #ifndef STAN_MATH_REV_FUN_DOT_PRODUCT_HPP
2 #define STAN_MATH_REV_FUN_DOT_PRODUCT_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/value_of.hpp>
7 #include <stan/math/rev/fun/to_arena.hpp>
8 #include <stan/math/rev/core/arena_matrix.hpp>
9 #include <stan/math/rev/core/reverse_pass_callback.hpp>
10 #include <stan/math/prim/meta.hpp>
11 #include <stan/math/prim/err.hpp>
12 #include <stan/math/prim/fun/Eigen.hpp>
13 #include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
14 #include <stan/math/prim/fun/dot_product.hpp>
15 #include <stan/math/prim/fun/typedefs.hpp>
16 #include <stan/math/prim/fun/value_of.hpp>
17 #include <type_traits>
18 #include <vector>
19
20 namespace stan {
21 namespace math {
22
23 /**
24 * Returns the dot product.
25 *
26 * @tparam T1 type of elements in the first vector
27 * @tparam T2 type of elements in the second vector
28 *
29 * @param[in] v1 First vector.
30 * @param[in] v2 Second vector.
31 * @return Dot product of the vectors.
32 * @throw std::domain_error if sizes of v1 and v2 do not match.
33 */
34 template <typename T1, typename T2, require_all_vector_t<T1, T2>* = nullptr,
35 require_not_complex_t<return_type_t<T1, T2>>* = nullptr,
36 require_all_not_std_vector_t<T1, T2>* = nullptr,
37 require_any_st_var<T1, T2>* = nullptr>
dot_product(const T1 & v1,const T2 & v2)38 inline var dot_product(const T1& v1, const T2& v2) {
39 check_matching_sizes("dot_product", "v1", v1, "v2", v2);
40
41 if (v1.size() == 0) {
42 return 0.0;
43 }
44
45 if (!is_constant<T1>::value && !is_constant<T2>::value) {
46 arena_t<promote_scalar_t<var, T1>> v1_arena = v1;
47 arena_t<promote_scalar_t<var, T2>> v2_arena = v2;
48 return make_callback_var(
49 v1_arena.val().dot(v2_arena.val()),
50 [v1_arena, v2_arena](const auto& vi) mutable {
51 const auto res_adj = vi.adj();
52 for (Eigen::Index i = 0; i < v1_arena.size(); ++i) {
53 v1_arena.adj().coeffRef(i) += res_adj * v2_arena.val().coeff(i);
54 v2_arena.adj().coeffRef(i) += res_adj * v1_arena.val().coeff(i);
55 }
56 });
57 } else if (!is_constant<T2>::value) {
58 arena_t<promote_scalar_t<var, T2>> v2_arena = v2;
59 arena_t<promote_scalar_t<double, T1>> v1_val_arena = value_of(v1);
60 return make_callback_var(v1_val_arena.dot(v2_arena.val()),
61 [v1_val_arena, v2_arena](const auto& vi) mutable {
62 v2_arena.adj().array()
63 += vi.adj() * v1_val_arena.array();
64 });
65 } else {
66 arena_t<promote_scalar_t<var, T1>> v1_arena = v1;
67 arena_t<promote_scalar_t<double, T2>> v2_val_arena = value_of(v2);
68 return make_callback_var(v1_arena.val().dot(v2_val_arena),
69 [v1_arena, v2_val_arena](const auto& vi) mutable {
70 v1_arena.adj().array()
71 += vi.adj() * v2_val_arena.array();
72 });
73 }
74 }
75
76 } // namespace math
77 } // namespace stan
78 #endif
79