1 #ifndef STAN_MATH_PRIM_FUN_INV_SQRT_HPP
2 #define STAN_MATH_PRIM_FUN_INV_SQRT_HPP
3 
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/sqrt.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
#include <cmath>
#include <type_traits>
11 
12 namespace stan {
13 namespace math {
14 
15 template <typename T, require_stan_scalar_t<T>* = nullptr>
inv_sqrt(T x)16 inline auto inv_sqrt(T x) {
17   using std::sqrt;
18   return inv(sqrt(x));
19 }
/**
 * Structure to wrap `1 / sqrt(x)` so that it can be vectorized with
 * `apply_scalar_unary`.
 */
struct inv_sqrt_fun {
  /**
   * Return the inverse square root of the specified scalar.
   *
   * The return type is deduced from `inv_sqrt()` instead of being forced
   * to `T`: converting the result back to `T` would truncate it whenever
   * `T` is an integral type (e.g. `fun(4)` would yield 0, not 0.5).
   *
   * @tparam T type of variable
   * @param x variable
   * @return inverse square root of x.
   */
  template <typename T>
  static inline auto fun(const T& x) {
    return inv_sqrt(x);
  }
};
33 
34 /**
35  * Return the elementwise `1 / sqrt(x)}` of the specified argument,
36  * which may be a scalar or any Stan container of numeric scalars.
37  *
38  * @tparam Container type of container
39  * @param x container
40  * @return inverse square root of each value in x.
41  */
42 template <typename Container,
43           require_not_container_st<std::is_arithmetic, Container>* = nullptr,
44           require_not_var_matrix_t<Container>* = nullptr,
45           require_not_stan_scalar_t<Container>* = nullptr,
46           require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
47               Container>* = nullptr>
inv_sqrt(const Container & x)48 inline auto inv_sqrt(const Container& x) {
49   return apply_scalar_unary<inv_sqrt_fun, Container>::apply(x);
50 }
51 
52 /**
53  * Version of `inv_sqrt()` that accepts std::vectors, Eigen Matrix/Array objects
54  *  or expressions, and containers of these.
55  *
56  * @tparam Container Type of x
57  * @param x Container
58  * @return inverse square root each variable in the container.
59  */
60 template <typename Container, require_not_var_matrix_t<Container>* = nullptr,
61           require_container_st<std::is_arithmetic, Container>* = nullptr>
inv_sqrt(const Container & x)62 inline auto inv_sqrt(const Container& x) {
63   return apply_vector_unary<Container>::apply(
64       x, [](const auto& v) { return v.array().rsqrt(); });
65 }
66 
67 }  // namespace math
68 }  // namespace stan
69 
70 #endif
71