1 #ifndef STAN_MATH_REV_FUN_FMA_HPP
2 #define STAN_MATH_REV_FUN_FMA_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/prim/fun/constants.hpp>
7 #include <stan/math/prim/fun/fma.hpp>
8 #include <stan/math/prim/fun/is_any_nan.hpp>
9
10 namespace stan {
11 namespace math {
12
13 /**
14 * The fused multiply-add function for three variables (C99).
15 * This function returns the product of the first two arguments
16 * plus the third argument.
17 *
18 * The partial derivatives are
19 *
20 * \f$\frac{\partial}{\partial x} (x * y) + z = y\f$, and
21 *
22 * \f$\frac{\partial}{\partial y} (x * y) + z = x\f$, and
23 *
24 * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
25 *
26 * @param x First multiplicand.
27 * @param y Second multiplicand.
28 * @param z Summand.
29 * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
30 */
fma(const var & x,const var & y,const var & z)31 inline var fma(const var& x, const var& y, const var& z) {
32 return make_callback_var(fma(x.val(), y.val(), z.val()), [x, y, z](auto& vi) {
33 x.adj() += vi.adj() * y.val();
34 y.adj() += vi.adj() * x.val();
35 z.adj() += vi.adj();
36 });
37 }
38
39 /**
40 * The fused multiply-add function for two variables and a value
41 * (C99). This function returns the product of the first two
42 * arguments plus the third argument.
43 *
44 * The partial derivatives are
45 *
46 * \f$\frac{\partial}{\partial x} (x * y) + z = y\f$, and
47 *
48 * \f$\frac{\partial}{\partial y} (x * y) + z = x\f$.
49 *
50 * @tparam Tc type of the summand
51 * @param x First multiplicand.
52 * @param y Second multiplicand.
53 * @param z Summand.
54 * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
55 */
56 template <typename Tc, require_arithmetic_t<Tc>* = nullptr>
fma(const var & x,const var & y,Tc && z)57 inline var fma(const var& x, const var& y, Tc&& z) {
58 return make_callback_var(fma(x.val(), y.val(), z), [x, y](auto& vi) {
59 x.adj() += vi.adj() * y.val();
60 y.adj() += vi.adj() * x.val();
61 });
62 }
63
64 /**
65 * The fused multiply-add function for a variable, value, and
66 * variable (C99). This function returns the product of the first
67 * two arguments plus the third argument.
68 *
69 * The partial derivatives are
70 *
71 * \f$\frac{\partial}{\partial x} (x * y) + z = y\f$, and
72 *
73 * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
74 *
75 * @tparam Ta type of the first multiplicand
76 * @tparam Tb type of the second multiplicand
77 * @tparam Tc type of the summand
78 *
79 * @param x First multiplicand.
80 * @param y Second multiplicand.
81 * @param z Summand.
82 * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
83 */
84 template <typename Ta, typename Tb, typename Tc,
85 require_arithmetic_t<Tb>* = nullptr,
86 require_all_var_t<Ta, Tc>* = nullptr>
fma(Ta && x,Tb && y,Tc && z)87 inline var fma(Ta&& x, Tb&& y, Tc&& z) {
88 return make_callback_var(fma(x.val(), y, z.val()), [x, y, z](auto& vi) {
89 x.adj() += vi.adj() * y;
90 z.adj() += vi.adj();
91 });
92 }
93
94 /**
95 * The fused multiply-add function for a variable and two values
96 * (C99). This function returns the product of the first two
97 * arguments plus the third argument.
98 *
99 * The double-based version
100 * <code>::%fma(double, double, double)</code> is defined in
101 * <code><cmath></code>.
102 *
103 * The derivative is
104 *
105 * \f$\frac{d}{d x} (x * y) + z = y\f$.
106 *
107 * @tparam Tb type of the second multiplicand
108 * @tparam Tc type of the summand
109 *
110 * @param x First multiplicand.
111 * @param y Second multiplicand.
112 * @param z Summand.
113 * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
114 */
115 template <typename Tb, typename Tc, require_all_arithmetic_t<Tb, Tc>* = nullptr>
fma(const var & x,Tb && y,Tc && z)116 inline var fma(const var& x, Tb&& y, Tc&& z) {
117 return make_callback_var(fma(x.val(), y, z),
118 [x, y](auto& vi) { x.adj() += vi.adj() * y; });
119 }
120
121 /**
122 * The fused multiply-add function for a value, variable, and
123 * value (C99). This function returns the product of the first
124 * two arguments plus the third argument.
125 *
126 * The derivative is
127 *
128 * \f$\frac{d}{d y} (x * y) + z = x\f$, and
129 *
130 * @tparam Ta type of the first multiplicand
131 * @tparam Tc type of the summand
132 *
133 * @param x First multiplicand.
134 * @param y Second multiplicand.
135 * @param z Summand.
136 * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
137 */
138 template <typename Ta, typename Tc, require_all_arithmetic_t<Ta, Tc>* = nullptr>
fma(Ta && x,const var & y,Tc && z)139 inline var fma(Ta&& x, const var& y, Tc&& z) {
140 return make_callback_var(fma(x, y.val(), z),
141 [x, y](auto& vi) { y.adj() += vi.adj() * x; });
142 }
143
144 /**
145 * The fused multiply-add function for two values and a variable,
146 * and value (C99). This function returns the product of the
147 * first two arguments plus the third argument.
148 *
149 * The derivative is
150 *
151 * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
152 *
153 * @tparam Ta type of the first multiplicand
154 * @tparam Tb type of the second multiplicand
155 *
156 * @param x First multiplicand.
157 * @param y Second multiplicand.
158 * @param z Summand.
159 * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
160 */
161 template <typename Ta, typename Tb, require_all_arithmetic_t<Ta, Tb>* = nullptr>
fma(Ta && x,Tb && y,const var & z)162 inline var fma(Ta&& x, Tb&& y, const var& z) {
163 return make_callback_var(fma(x, y, z.val()),
164 [z](auto& vi) { z.adj() += vi.adj(); });
165 }
166
167 /**
168 * The fused multiply-add function for a value and two variables
169 * (C99). This function returns the product of the first two
170 * arguments plus the third argument.
171 *
172 * The partial derivatives are
173 *
174 * \f$\frac{\partial}{\partial y} (x * y) + z = x\f$, and
175 *
176 * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
177 *
178 * @tparam Ta type of the first multiplicand
179 * @param x First multiplicand.
180 * @param y Second multiplicand.
181 * @param z Summand.
182 * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
183 */
184 template <typename Ta, require_arithmetic_t<Ta>* = nullptr>
fma(Ta && x,const var & y,const var & z)185 inline var fma(Ta&& x, const var& y, const var& z) {
186 return make_callback_var(fma(x, y.val(), z.val()), [x, y, z](auto& vi) {
187 y.adj() += vi.adj() * x;
188 z.adj() += vi.adj();
189 });
190 }
191
192 namespace internal {
193 /**
194 * Overload for matrix, matrix, matrix
195 */
196 template <typename T1, typename T2, typename T3, typename T4,
197 require_all_matrix_t<T1, T2, T3>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)198 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
199 return [arena_x, arena_y, arena_z, ret]() mutable {
200 using T1_var = arena_t<plain_type_t<promote_scalar_t<var, T1>>>;
201 using T2_var = arena_t<plain_type_t<promote_scalar_t<var, T2>>>;
202 using T3_var = arena_t<plain_type_t<promote_scalar_t<var, T3>>>;
203 if (!is_constant<T1>::value) {
204 forward_as<T1_var>(arena_x).adj().array()
205 += ret.adj().array() * value_of(arena_y).array();
206 }
207 if (!is_constant<T2>::value) {
208 forward_as<T2_var>(arena_y).adj().array()
209 += ret.adj().array() * value_of(arena_x).array();
210 }
211 if (!is_constant<T3>::value) {
212 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
213 }
214 };
215 }
216
217 /**
218 * Overload for scalar, matrix, matrix
219 */
220 template <typename T1, typename T2, typename T3, typename T4,
221 require_all_matrix_t<T2, T3>* = nullptr,
222 require_stan_scalar_t<T1>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)223 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
224 return [arena_x, arena_y, arena_z, ret]() mutable {
225 using T1_var = arena_t<promote_scalar_t<var, T1>>;
226 using T2_var = arena_t<promote_scalar_t<var, T2>>;
227 using T3_var = arena_t<promote_scalar_t<var, T3>>;
228 if (!is_constant<T1>::value) {
229 forward_as<T1_var>(arena_x).adj()
230 += (ret.adj().array() * value_of(arena_y).array()).sum();
231 }
232 if (!is_constant<T2>::value) {
233 forward_as<T2_var>(arena_y).adj().array()
234 += ret.adj().array() * value_of(arena_x);
235 }
236 if (!is_constant<T3>::value) {
237 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
238 }
239 };
240 }
241
242 /**
243 * Overload for matrix, scalar, matrix
244 */
245 template <typename T1, typename T2, typename T3, typename T4,
246 require_all_matrix_t<T1, T3>* = nullptr,
247 require_stan_scalar_t<T2>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)248 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
249 return [arena_x, arena_y, arena_z, ret]() mutable {
250 using T1_var = arena_t<promote_scalar_t<var, T1>>;
251 using T2_var = arena_t<promote_scalar_t<var, T2>>;
252 using T3_var = arena_t<promote_scalar_t<var, T3>>;
253 if (!is_constant<T1>::value) {
254 forward_as<T1_var>(arena_x).adj().array()
255 += ret.adj().array() * value_of(arena_y);
256 }
257 if (!is_constant<T2>::value) {
258 forward_as<T2_var>(arena_y).adj()
259 += (ret.adj().array() * value_of(arena_x).array()).sum();
260 }
261 if (!is_constant<T3>::value) {
262 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
263 }
264 };
265 }
266
267 /**
268 * Overload for scalar, scalar, matrix
269 */
270 template <typename T1, typename T2, typename T3, typename T4,
271 require_matrix_t<T3>* = nullptr,
272 require_all_stan_scalar_t<T1, T2>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)273 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
274 return [arena_x, arena_y, arena_z, ret]() mutable {
275 using T1_var = arena_t<promote_scalar_t<var, T1>>;
276 using T2_var = arena_t<promote_scalar_t<var, T2>>;
277 using T3_var = arena_t<promote_scalar_t<var, T3>>;
278 if (!is_constant<T1>::value) {
279 forward_as<T1_var>(arena_x).adj()
280 += (ret.adj().array() * value_of(arena_y)).sum();
281 }
282 if (!is_constant<T2>::value) {
283 forward_as<T2_var>(arena_y).adj()
284 += (ret.adj().array() * value_of(arena_x)).sum();
285 }
286 if (!is_constant<T3>::value) {
287 forward_as<T3_var>(arena_z).adj().array() += ret.adj().array();
288 }
289 };
290 }
291
292 /**
293 * Overload for matrix, matrix, scalar
294 */
295 template <typename T1, typename T2, typename T3, typename T4,
296 require_all_matrix_t<T1, T2>* = nullptr,
297 require_stan_scalar_t<T3>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)298 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
299 return [arena_x, arena_y, arena_z, ret]() mutable {
300 using T1_var = arena_t<promote_scalar_t<var, T1>>;
301 using T2_var = arena_t<promote_scalar_t<var, T2>>;
302 using T3_var = arena_t<promote_scalar_t<var, T3>>;
303 if (!is_constant<T1>::value) {
304 forward_as<T1_var>(arena_x).adj().array()
305 += ret.adj().array() * value_of(arena_y).array();
306 }
307 if (!is_constant<T2>::value) {
308 forward_as<T2_var>(arena_y).adj().array()
309 += ret.adj().array() * value_of(arena_x).array();
310 }
311 if (!is_constant<T3>::value) {
312 forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
313 }
314 };
315 }
316
317 /**
318 * Overload for scalar, matrix, scalar
319 */
320 template <typename T1, typename T2, typename T3, typename T4,
321 require_matrix_t<T2>* = nullptr,
322 require_all_stan_scalar_t<T1, T3>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)323 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
324 return [arena_x, arena_y, arena_z, ret]() mutable {
325 using T1_var = arena_t<promote_scalar_t<var, T1>>;
326 using T2_var = arena_t<promote_scalar_t<var, T2>>;
327 using T3_var = arena_t<promote_scalar_t<var, T3>>;
328 if (!is_constant<T1>::value) {
329 forward_as<T1_var>(arena_x).adj()
330 += (ret.adj().array() * value_of(arena_y).array()).sum();
331 }
332 if (!is_constant<T2>::value) {
333 forward_as<T2_var>(arena_y).adj().array()
334 += ret.adj().array() * value_of(arena_x);
335 }
336 if (!is_constant<T3>::value) {
337 forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
338 }
339 };
340 }
341
342 /**
343 * Overload for matrix, scalar, scalar
344 */
345 template <typename T1, typename T2, typename T3, typename T4,
346 require_matrix_t<T1>* = nullptr,
347 require_all_stan_scalar_t<T2, T3>* = nullptr>
fma_reverse_pass(T1 & arena_x,T2 & arena_y,T3 & arena_z,T4 & ret)348 inline auto fma_reverse_pass(T1& arena_x, T2& arena_y, T3& arena_z, T4& ret) {
349 return [arena_x, arena_y, arena_z, ret]() mutable {
350 using T1_var = arena_t<promote_scalar_t<var, T1>>;
351 using T2_var = arena_t<promote_scalar_t<var, T2>>;
352 using T3_var = arena_t<promote_scalar_t<var, T3>>;
353 if (!is_constant<T1>::value) {
354 forward_as<T1_var>(arena_x).adj().array()
355 += ret.adj().array() * value_of(arena_y);
356 }
357 if (!is_constant<T2>::value) {
358 forward_as<T2_var>(arena_y).adj()
359 += (ret.adj().array() * value_of(arena_x).array()).sum();
360 }
361 if (!is_constant<T3>::value) {
362 forward_as<T3_var>(arena_z).adj() += ret.adj().sum();
363 }
364 };
365 }
366
367 } // namespace internal
368
369 /**
370 * The fused multiply-add function for three variables (C99).
371 * This function returns the product of the first two arguments
372 * plus the third argument.
373 *
374 * The partial derivatives are
375 *
376 * \f$\frac{\partial}{\partial x} (x * y) + z = y\f$, and
377 *
378 * \f$\frac{\partial}{\partial y} (x * y) + z = x\f$, and
379 *
380 * \f$\frac{\partial}{\partial z} (x * y) + z = 1\f$.
381 *
382 * @param x First multiplicand.
383 * @param y Second multiplicand.
384 * @param z Summand.
385 * @return Product of the multiplicands plus the summand, ($a * $b) + $c.
386 */
387 template <typename T1, typename T2, typename T3,
388 require_any_matrix_t<T1, T2, T3>* = nullptr,
389 require_var_t<return_type_t<T1, T2, T3>>* = nullptr>
fma(const T1 & x,const T2 & y,const T3 & z)390 inline auto fma(const T1& x, const T2& y, const T3& z) {
391 arena_t<T1> arena_x = x;
392 arena_t<T2> arena_y = y;
393 arena_t<T3> arena_z = z;
394 if (is_matrix<T1>::value && is_matrix<T2>::value) {
395 check_matching_dims("fma", "x", arena_x, "y", arena_y);
396 }
397 if (is_matrix<T1>::value && is_matrix<T3>::value) {
398 check_matching_dims("fma", "x", arena_x, "z", arena_z);
399 }
400 if (is_matrix<T2>::value && is_matrix<T3>::value) {
401 check_matching_dims("fma", "y", arena_y, "z", arena_z);
402 }
403 using inner_ret_type
404 = decltype(fma(value_of(arena_x), value_of(arena_y), value_of(arena_z)));
405 using ret_type = return_var_matrix_t<inner_ret_type, T1, T2, T3>;
406 arena_t<ret_type> ret
407 = fma(value_of(arena_x), value_of(arena_y), value_of(arena_z));
408 reverse_pass_callback(
409 internal::fma_reverse_pass(arena_x, arena_y, arena_z, ret));
410 return ret_type(ret);
411 }
412
413 } // namespace math
414 } // namespace stan
415 #endif
416