1 #ifndef STAN_MATH_REV_FUN_SQRT_HPP
2 #define STAN_MATH_REV_FUN_SQRT_HPP
3 
4 #include <stan/math/prim/fun/sqrt.hpp>
5 #include <stan/math/rev/meta.hpp>
6 #include <stan/math/rev/core.hpp>
7 #include <stan/math/rev/fun/atan2.hpp>
8 #include <stan/math/rev/fun/cos.hpp>
9 #include <stan/math/rev/fun/hypot.hpp>
10 #include <cmath>
11 #include <complex>
12 
13 namespace stan {
14 namespace math {
15 
16 /**
17  * Return the square root of the specified variable (cmath).
18  *
19  * The derivative is defined by
20  *
21  * \f$\frac{d}{dx} \sqrt{x} = \frac{1}{2 \sqrt{x}}\f$.
22  *
23    \f[
24    \mbox{sqrt}(x) =
25    \begin{cases}
26      \textrm{NaN} & x < 0 \\
27      \sqrt{x} & \mbox{if } x\geq 0\\[6pt]
28      \textrm{NaN} & \mbox{if } x = \textrm{NaN}
29    \end{cases}
30    \f]
31 
32    \f[
33    \frac{\partial\, \mbox{sqrt}(x)}{\partial x} =
34    \begin{cases}
35      \textrm{NaN} & x < 0 \\
36      \frac{1}{2\sqrt{x}} & x\geq 0\\[6pt]
37      \textrm{NaN} & \mbox{if } x = \textrm{NaN}
38    \end{cases}
39    \f]
40  *
41  * @param a Variable whose square root is taken.
42  * @return Square root of variable.
43  */
sqrt(const var & a)44 inline var sqrt(const var& a) {
45   return make_callback_var(std::sqrt(a.val()), [a](auto& vi) mutable {
46     a.adj() += vi.adj() / (2.0 * vi.val());
47   });
48 }
49 
50 /**
51  * Return elementwise square root of vector
52  *
53  * @tparam T a `var_value` with inner Eigen type
54  * @param a input
55  * @return elementwise square root of vector
56  */
57 template <typename T, require_var_matrix_t<T>* = nullptr>
sqrt(const T & a)58 inline auto sqrt(const T& a) {
59   return make_callback_var(
60       a.val().array().sqrt().matrix(), [a](auto& vi) mutable {
61         a.adj().array() += vi.adj().array() / (2.0 * vi.val_op().array());
62       });
63 }
64 
65 /**
66  * Return the square root of the complex argument.
67  *
68  * @param[in] z argument
69  * @return square root of the argument
70  */
sqrt(const std::complex<var> & z)71 inline std::complex<var> sqrt(const std::complex<var>& z) {
72   return internal::complex_sqrt(z);
73 }
74 
75 }  // namespace math
76 }  // namespace stan
77 #endif
78