1 #ifndef STAN_MATH_PRIM_FUN_FMA_HPP
2 #define STAN_MATH_PRIM_FUN_FMA_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/fun/as_array_or_scalar.hpp>
6 #include <cmath>
7 
8 namespace stan {
9 namespace math {
10 
11 /**
12  * Return the product of the first two arguments plus the third
13  * argument.
14  *
 * <p>This delegates to the standard library's <code>std::fma()</code>
 * and converts its result to <code>double</code>.
17  *
18  * @param x First argument.
19  * @param y Second argument.
20  * @param z Third argument.
21  * @return The product of the first two arguments plus the third
22  * argument.
23  */
24 template <typename T1, typename T2, typename T3,
25           require_all_arithmetic_t<T1, T2, T3>* = nullptr>
fma(T1 x,T2 y,T3 z)26 inline double fma(T1 x, T2 y, T3 z) {
27   using std::fma;
28   return fma(x, y, z);
29 }
30 
31 template <typename T1, typename T2, typename T3,
32           require_any_matrix_t<T1, T2, T3>* = nullptr,
33           require_not_var_t<return_type_t<T1, T2, T3>>* = nullptr>
fma(T1 && x,T2 && y,T3 && z)34 inline auto fma(T1&& x, T2&& y, T3&& z) {
35   if (is_matrix<T1>::value && is_matrix<T2>::value) {
36     check_matching_dims("fma", "x", x, "y", y);
37   }
38   if (is_matrix<T1>::value && is_matrix<T3>::value) {
39     check_matching_dims("fma", "x", x, "z", z);
40   } else if (is_matrix<T2>::value && is_matrix<T3>::value) {
41     check_matching_dims("fma", "y", y, "z", z);
42   }
43   return make_holder(
44       [](auto&& x, auto&& y, auto&& z) {
45         return ((as_array_or_scalar(x) * as_array_or_scalar(y))
46                 + as_array_or_scalar(z))
47             .matrix();
48       },
49       std::forward<T1>(x), std::forward<T2>(y), std::forward<T3>(z));
50 }
51 
52 }  // namespace math
53 }  // namespace stan
54 #endif
55