1 #ifndef STAN_MATH_REV_FUN_SQUARE_HPP 2 #define STAN_MATH_REV_FUN_SQUARE_HPP 3 4 #include <stan/math/rev/meta.hpp> 5 #include <stan/math/rev/core.hpp> 6 7 namespace stan { 8 namespace math { 9 10 /** 11 * Return the square of the input variable. 12 * 13 * <p>Using <code>square(x)</code> is more efficient 14 * than using <code>x * x</code>. 15 * 16 \f[ 17 \mbox{square}(x) = 18 \begin{cases} 19 x^2 & \mbox{if } -\infty\leq x \leq \infty \\[6pt] 20 \textrm{NaN} & \mbox{if } x = \textrm{NaN} 21 \end{cases} 22 \f] 23 24 \f[ 25 \frac{\partial\, \mbox{square}(x)}{\partial x} = 26 \begin{cases} 27 2x & \mbox{if } -\infty\leq x\leq \infty \\[6pt] 28 \textrm{NaN} & \mbox{if } x = \textrm{NaN} 29 \end{cases} 30 \f] 31 * 32 * @param x Variable to square. 33 * @return Square of variable. 34 */ square(const var & x)35inline var square(const var& x) { 36 return make_callback_var(square(x.val()), [x](auto& vi) mutable { 37 x.adj() += vi.adj() * 2.0 * x.val(); 38 }); 39 } 40 41 /** 42 * Return the elementwise square of x 43 * 44 * @tparam T type of x 45 * @param x argument 46 * @return elementwise square of x 47 */ 48 template <typename T, require_var_matrix_t<T>* = nullptr> square(const T & x)49inline auto square(const T& x) { 50 return make_callback_var( 51 (x.val().array().square()).matrix(), [x](const auto& vi) mutable { 52 x.adj() += (2.0 * x.val().array() * vi.adj().array()).matrix(); 53 }); 54 } 55 56 } // namespace math 57 } // namespace stan 58 #endif 59