1 #ifndef STAN_MATH_REV_FUN_SQUARE_HPP
2 #define STAN_MATH_REV_FUN_SQUARE_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 
7 namespace stan {
8 namespace math {
9 
10 /**
11  * Return the square of the input variable.
12  *
13  * <p>Using <code>square(x)</code> is more efficient
14  * than using <code>x * x</code>.
15  *
16    \f[
17    \mbox{square}(x) =
18    \begin{cases}
19      x^2 & \mbox{if } -\infty\leq x \leq \infty \\[6pt]
20      \textrm{NaN} & \mbox{if } x = \textrm{NaN}
21    \end{cases}
22    \f]
23 
24    \f[
25    \frac{\partial\, \mbox{square}(x)}{\partial x} =
26    \begin{cases}
27      2x & \mbox{if } -\infty\leq x\leq \infty \\[6pt]
28      \textrm{NaN} & \mbox{if } x = \textrm{NaN}
29    \end{cases}
30    \f]
31  *
32  * @param x Variable to square.
33  * @return Square of variable.
34  */
square(const var & x)35 inline var square(const var& x) {
36   return make_callback_var(square(x.val()), [x](auto& vi) mutable {
37     x.adj() += vi.adj() * 2.0 * x.val();
38   });
39 }
40 
41 /**
42  * Return the elementwise square of x
43  *
44  * @tparam T type of x
45  * @param x argument
46  * @return elementwise square of x
47  */
48 template <typename T, require_var_matrix_t<T>* = nullptr>
square(const T & x)49 inline auto square(const T& x) {
50   return make_callback_var(
51       (x.val().array().square()).matrix(), [x](const auto& vi) mutable {
52         x.adj() += (2.0 * x.val().array() * vi.adj().array()).matrix();
53       });
54 }
55 
56 }  // namespace math
57 }  // namespace stan
58 #endif
59