1 #ifndef STAN_MATH_REV_FUN_FMA_HPP
2 #define STAN_MATH_REV_FUN_FMA_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/prim/fun/constants.hpp>
7 #include <stan/math/prim/fun/fma.hpp>
8 #include <stan/math/prim/fun/is_any_nan.hpp>
9 
10 namespace stan {
11 namespace math {
12 
13 /**
14  * The fused multiply-add function for three variables (C99).
15  * This function returns the product of the first two arguments
16  * plus the third argument.
17  *
18  * The partial derivatives are
19  *
20  * \f$\frac{\partial}{\partial x} (x * y) + z = y\f$, and
21  *
22  * \f$\frac{\partial}{\partial y} (x * y) + z = x\f$, and
23  *
24  * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
25  *
26  * @param x First multiplicand.
27  * @param y Second multiplicand.
28  * @param z Summand.
29  * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
30  */
fma(const var & x,const var & y,const var & z)31 inline var fma(const var& x, const var& y, const var& z) {
32   return make_callback_var(fma(x.val(), y.val(), z.val()), [x, y, z](auto& vi) {
33     x.adj() += vi.adj() * y.val();
34     y.adj() += vi.adj() * x.val();
35     z.adj() += vi.adj();
36   });
37 }
38 
39 /**
40  * The fused multiply-add function for two variables and a value
41  * (C99).  This function returns the product of the first two
42  * arguments plus the third argument.
43  *
44  * The partial derivatives are
45  *
46  * \f$\frac{\partial}{\partial x} (x * y) + z = y\f$, and
47  *
48  * \f$\frac{\partial}{\partial y} (x * y) + z = x\f$.
49  *
50  * @tparam Tc type of the summand
51  * @param x First multiplicand.
52  * @param y Second multiplicand.
53  * @param z Summand.
54  * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
55  */
56 template <typename Tc, require_arithmetic_t<Tc>* = nullptr>
fma(const var & x,const var & y,Tc && z)57 inline var fma(const var& x, const var& y, Tc&& z) {
58   return make_callback_var(fma(x.val(), y.val(), z), [x, y](auto& vi) {
59     x.adj() += vi.adj() * y.val();
60     y.adj() += vi.adj() * x.val();
61   });
62 }
63 
64 /**
65  * The fused multiply-add function for a variable, value, and
66  * variable (C99).  This function returns the product of the first
67  * two arguments plus the third argument.
68  *
69  * The partial derivatives are
70  *
71  * \f$\frac{\partial}{\partial x} (x * y) + z = y\f$, and
72  *
73  * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
74  *
75  * @tparam Ta type of the first multiplicand
76  * @tparam Tb type of the second multiplicand
77  * @tparam Tc type of the summand
78  *
79  * @param x First multiplicand.
80  * @param y Second multiplicand.
81  * @param z Summand.
82  * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
83  */
84 template <typename Ta, typename Tb, typename Tc,
85           require_arithmetic_t<Tb>* = nullptr,
86           require_all_var_t<Ta, Tc>* = nullptr>
fma(Ta && x,Tb && y,Tc && z)87 inline var fma(Ta&& x, Tb&& y, Tc&& z) {
88   return make_callback_var(fma(x.val(), y, z.val()), [x, y, z](auto& vi) {
89     x.adj() += vi.adj() * y;
90     z.adj() += vi.adj();
91   });
92 }
93 
94 /**
95  * The fused multiply-add function for a variable and two values
96  * (C99).  This function returns the product of the first two
97  * arguments plus the third argument.
98  *
99  * The double-based version
100  * <code>::%fma(double, double, double)</code> is defined in
101  * <code>&lt;cmath&gt;</code>.
102  *
103  * The derivative is
104  *
105  * \f$\frac{d}{d x} (x * y) + z = y\f$.
106  *
107  * @tparam Tb type of the second multiplicand
108  * @tparam Tc type of the summand
109  *
110  * @param x First multiplicand.
111  * @param y Second multiplicand.
112  * @param z Summand.
113  * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
114  */
115 template <typename Tb, typename Tc, require_all_arithmetic_t<Tb, Tc>* = nullptr>
fma(const var & x,Tb && y,Tc && z)116 inline var fma(const var& x, Tb&& y, Tc&& z) {
117   return make_callback_var(fma(x.val(), y, z),
118                            [x, y](auto& vi) { x.adj() += vi.adj() * y; });
119 }
120 
121 /**
122  * The fused multiply-add function for a value, variable, and
123  * value (C99).  This function returns the product of the first
124  * two arguments plus the third argument.
125  *
126  * The derivative is
127  *
128  * \f$\frac{d}{d y} (x * y) + z = x\f$, and
129  *
130  * @tparam Ta type of the first multiplicand
131  * @tparam Tc type of the summand
132  *
133  * @param x First multiplicand.
134  * @param y Second multiplicand.
135  * @param z Summand.
136  * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
137  */
138 template <typename Ta, typename Tc, require_all_arithmetic_t<Ta, Tc>* = nullptr>
fma(Ta && x,const var & y,Tc && z)139 inline var fma(Ta&& x, const var& y, Tc&& z) {
140   return make_callback_var(fma(x, y.val(), z),
141                            [x, y](auto& vi) { y.adj() += vi.adj() * x; });
142 }
143 
144 /**
145  * The fused multiply-add function for two values and a variable,
146  * and value (C99).  This function returns the product of the
147  * first two arguments plus the third argument.
148  *
149  * The derivative is
150  *
151  * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
152  *
153  * @tparam Ta type of the first multiplicand
154  * @tparam Tb type of the second multiplicand
155  *
156  * @param x First multiplicand.
157  * @param y Second multiplicand.
158  * @param z Summand.
159  * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
160  */
161 template <typename Ta, typename Tb, require_all_arithmetic_t<Ta, Tb>* = nullptr>
fma(Ta && x,Tb && y,const var & z)162 inline var fma(Ta&& x, Tb&& y, const var& z) {
163   return make_callback_var(fma(x, y, z.val()),
164                            [z](auto& vi) { z.adj() += vi.adj(); });
165 }
166 
167 /**
168  * The fused multiply-add function for a value and two variables
169  * (C99).  This function returns the product of the first two
170  * arguments plus the third argument.
171  *
172  * The partial derivatives are
173  *
174  * \f$\frac{\partial}{\partial y} (x * y) + z = x\f$, and
175  *
176  * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
177  *
178  * @tparam Ta type of the first multiplicand
179  * @param x First multiplicand.
180  * @param y Second multiplicand.
181  * @param z Summand.
182  * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
183  */
184 template <typename Ta, require_arithmetic_t<Ta>* = nullptr>
fma(Ta && x,const var & y,const var & z)185 inline var fma(Ta&& x, const var& y, const var& z) {
186   return make_callback_var(fma(x, y.val(), z.val()), [x, y, z](auto& vi) {
187     y.adj() += vi.adj() * x;
188     z.adj() += vi.adj();
189   });
190 }
191 
192 namespace internal {
193 /**
194  * Overload for matrix, matrix, matrix
195  */
196 template <typename T1, typename T2, typename T3, typename T4,
197           require_all_matrix_t<T1, T2, T3>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)198 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
199   return [arena_x, arena_y, arena_z, ret]() mutable {
200     using T1_var = arena_t<plain_type_t<promote_scalar_t<var, T1>>>;
201     using T2_var = arena_t<plain_type_t<promote_scalar_t<var, T2>>>;
202     using T3_var = arena_t<plain_type_t<promote_scalar_t<var, T3>>>;
203     if (!is_constant<T1>::value) {
204       forward_as<T1_var>(arena_x).adj().array()
205           += ret.adj().array() * value_of(arena_y).array();
206     }
207     if (!is_constant<T2>::value) {
208       forward_as<T2_var>(arena_y).adj().array()
209           += ret.adj().array() * value_of(arena_x).array();
210     }
211     if (!is_constant<T3>::value) {
212       forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
213     }
214   };
215 }
216 
217 /**
218  * Overload for scalar, matrix, matrix
219  */
220 template <typename T1, typename T2, typename T3, typename T4,
221           require_all_matrix_t<T2, T3>* = nullptr,
222           require_stan_scalar_t<T1>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)223 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
224   return [arena_x, arena_y, arena_z, ret]() mutable {
225     using T1_var = arena_t<promote_scalar_t<var, T1>>;
226     using T2_var = arena_t<promote_scalar_t<var, T2>>;
227     using T3_var = arena_t<promote_scalar_t<var, T3>>;
228     if (!is_constant<T1>::value) {
229       forward_as<T1_var>(arena_x).adj()
230           += (ret.adj().array() * value_of(arena_y).array()).sum();
231     }
232     if (!is_constant<T2>::value) {
233       forward_as<T2_var>(arena_y).adj().array()
234           += ret.adj().array() * value_of(arena_x);
235     }
236     if (!is_constant<T3>::value) {
237       forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
238     }
239   };
240 }
241 
242 /**
243  * Overload for matrix, scalar, matrix
244  */
245 template <typename T1, typename T2, typename T3, typename T4,
246           require_all_matrix_t<T1, T3>* = nullptr,
247           require_stan_scalar_t<T2>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)248 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
249   return [arena_x, arena_y, arena_z, ret]() mutable {
250     using T1_var = arena_t<promote_scalar_t<var, T1>>;
251     using T2_var = arena_t<promote_scalar_t<var, T2>>;
252     using T3_var = arena_t<promote_scalar_t<var, T3>>;
253     if (!is_constant<T1>::value) {
254       forward_as<T1_var>(arena_x).adj().array()
255           += ret.adj().array() * value_of(arena_y);
256     }
257     if (!is_constant<T2>::value) {
258       forward_as<T2_var>(arena_y).adj()
259           += (ret.adj().array() * value_of(arena_x).array()).sum();
260     }
261     if (!is_constant<T3>::value) {
262       forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
263     }
264   };
265 }
266 
267 /**
268  * Overload for scalar, scalar, matrix
269  */
270 template <typename T1, typename T2, typename T3, typename T4,
271           require_matrix_t<T3>* = nullptr,
272           require_all_stan_scalar_t<T1, T2>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)273 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
274   return [arena_x, arena_y, arena_z, ret]() mutable {
275     using T1_var = arena_t<promote_scalar_t<var, T1>>;
276     using T2_var = arena_t<promote_scalar_t<var, T2>>;
277     using T3_var = arena_t<promote_scalar_t<var, T3>>;
278     if (!is_constant<T1>::value) {
279       forward_as<T1_var>(arena_x).adj()
280           += (ret.adj().array() * value_of(arena_y)).sum();
281     }
282     if (!is_constant<T2>::value) {
283       forward_as<T2_var>(arena_y).adj()
284           += (ret.adj().array() * value_of(arena_x)).sum();
285     }
286     if (!is_constant<T3>::value) {
287       forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
288     }
289   };
290 }
291 
292 /**
293  * Overload for matrix, matrix, scalar
294  */
295 template <typename T1, typename T2, typename T3, typename T4,
296           require_all_matrix_t<T1, T2>* = nullptr,
297           require_stan_scalar_t<T3>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)298 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
299   return [arena_x, arena_y, arena_z, ret]() mutable {
300     using T1_var = arena_t<promote_scalar_t<var, T1>>;
301     using T2_var = arena_t<promote_scalar_t<var, T2>>;
302     using T3_var = arena_t<promote_scalar_t<var, T3>>;
303     if (!is_constant<T1>::value) {
304       forward_as<T1_var>(arena_x).adj().array()
305           += ret.adj().array() * value_of(arena_y).array();
306     }
307     if (!is_constant<T2>::value) {
308       forward_as<T2_var>(arena_y).adj().array()
309           += ret.adj().array() * value_of(arena_x).array();
310     }
311     if (!is_constant<T3>::value) {
312       forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
313     }
314   };
315 }
316 
317 /**
318  * Overload for scalar, matrix, scalar
319  */
320 template <typename T1, typename T2, typename T3, typename T4,
321           require_matrix_t<T2>* = nullptr,
322           require_all_stan_scalar_t<T1, T3>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)323 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
324   return [arena_x, arena_y, arena_z, ret]() mutable {
325     using T1_var = arena_t<promote_scalar_t<var, T1>>;
326     using T2_var = arena_t<promote_scalar_t<var, T2>>;
327     using T3_var = arena_t<promote_scalar_t<var, T3>>;
328     if (!is_constant<T1>::value) {
329       forward_as<T1_var>(arena_x).adj()
330           += (ret.adj().array() * value_of(arena_y).array()).sum();
331     }
332     if (!is_constant<T2>::value) {
333       forward_as<T2_var>(arena_y).adj().array()
334           += ret.adj().array() * value_of(arena_x);
335     }
336     if (!is_constant<T3>::value) {
337       forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
338     }
339   };
340 }
341 
342 /**
343  * Overload for matrix, scalar, scalar
344  */
345 template <typename T1, typename T2, typename T3, typename T4,
346           require_matrix_t<T1>* = nullptr,
347           require_all_stan_scalar_t<T2, T3>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)348 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
349   return [arena_x, arena_y, arena_z, ret]() mutable {
350     using T1_var = arena_t<promote_scalar_t<var, T1>>;
351     using T2_var = arena_t<promote_scalar_t<var, T2>>;
352     using T3_var = arena_t<promote_scalar_t<var, T3>>;
353     if (!is_constant<T1>::value) {
354       forward_as<T1_var>(arena_x).adj().array()
355           += ret.adj().array() * value_of(arena_y);
356     }
357     if (!is_constant<T2>::value) {
358       forward_as<T2_var>(arena_y).adj()
359           += (ret.adj().array() * value_of(arena_x).array()).sum();
360     }
361     if (!is_constant<T3>::value) {
362       forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
363     }
364   };
365 }
366 
367 }  // namespace internal
368 
369 /**
370  * The fused multiply-add function for three variables (C99).
371  * This function returns the product of the first two arguments
372  * plus the third argument.
373  *
374  * The partial derivatives are
375  *
376  * \f$\frac{\partial}{\partial x} (x * y) + z = y\f$, and
377  *
378  * \f$\frac{\partial}{\partial y} (x * y) + z = x\f$, and
379  *
380  * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
381  *
382  * @param x First multiplicand.
383  * @param y Second multiplicand.
384  * @param z Summand.
385  * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
386  */
387 template <typename T1, typename T2, typename T3,
388           require_any_matrix_t<T1, T2, T3>* = nullptr,
389           require_var_t<return_type_t<T1, T2, T3>>* = nullptr>
fma(const T1 & x,const T2 & y,const T3 & z)390 inline auto fma(const T1& x, const T2& y, const T3& z) {
391   arena_t<T1> arena_x = x;
392   arena_t<T2> arena_y = y;
393   arena_t<T3> arena_z = z;
394   if (is_matrix<T1>::value && is_matrix<T2>::value) {
395     check_matching_dims("fma", "x", arena_x, "y", arena_y);
396   }
397   if (is_matrix<T1>::value && is_matrix<T3>::value) {
398     check_matching_dims("fma", "x", arena_x, "z", arena_z);
399   }
400   if (is_matrix<T2>::value && is_matrix<T3>::value) {
401     check_matching_dims("fma", "y", arena_y, "z", arena_z);
402   }
403   using inner_ret_type
404       = decltype(fma(value_of(arena_x), value_of(arena_y), value_of(arena_z)));
405   using ret_type = return_var_matrix_t<inner_ret_type, T1, T2, T3>;
406   arena_t<ret_type> ret
407       = fma(value_of(arena_x), value_of(arena_y), value_of(arena_z));
408   reverse_pass_callback(
409       internal::fma_reverse_pass(arena_x, arena_y, arena_z, ret));
410   return ret_type(ret);
411 }
412 
413 }  // namespace math
414 }  // namespace stan
415 #endif
416