1 #ifndef STAN_MATH_PRIM_FUN_FMA_HPP
2 #define STAN_MATH_PRIM_FUN_FMA_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/fun/as_array_or_scalar.hpp>
6 #include <cmath>
7
8 namespace stan {
9 namespace math {
10
11 /**
12 * Return the product of the first two arguments plus the third
13 * argument.
14 *
15 * <p><i>Warning:</i> This does not delegate to the high-precision
16 * platform-specific <code>fma()</code> implementation.
17 *
18 * @param x First argument.
19 * @param y Second argument.
20 * @param z Third argument.
21 * @return The product of the first two arguments plus the third
22 * argument.
23 */
24 template <typename T1, typename T2, typename T3,
25 require_all_arithmetic_t<T1, T2, T3>* = nullptr>
fma(T1 x,T2 y,T3 z)26 inline double fma(T1 x, T2 y, T3 z) {
27 using std::fma;
28 return fma(x, y, z);
29 }
30
31 template <typename T1, typename T2, typename T3,
32 require_any_matrix_t<T1, T2, T3>* = nullptr,
33 require_not_var_t<return_type_t<T1, T2, T3>>* = nullptr>
fma(T1 && x,T2 && y,T3 && z)34 inline auto fma(T1&& x, T2&& y, T3&& z) {
35 if (is_matrix<T1>::value && is_matrix<T2>::value) {
36 check_matching_dims("fma", "x", x, "y", y);
37 }
38 if (is_matrix<T1>::value && is_matrix<T3>::value) {
39 check_matching_dims("fma", "x", x, "z", z);
40 } else if (is_matrix<T2>::value && is_matrix<T3>::value) {
41 check_matching_dims("fma", "y", y, "z", z);
42 }
43 return make_holder(
44 [](auto&& x, auto&& y, auto&& z) {
45 return ((as_array_or_scalar(x) * as_array_or_scalar(y))
46 + as_array_or_scalar(z))
47 .matrix();
48 },
49 std::forward<T1>(x), std::forward<T2>(y), std::forward<T3>(z));
50 }
51
52 } // namespace math
53 } // namespace stan
54 #endif
55