1 #ifndef STAN_MATH_REV_FUN_INV_LOGIT_HPP
2 #define STAN_MATH_REV_FUN_INV_LOGIT_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/prim/fun/inv_logit.hpp>
7 
8 namespace stan {
9 namespace math {
10 
11 /**
12  * The inverse logit function for variables (stan).
13  *
14  * See inv_logit() for the double-based version.
15  *
16  * The derivative of inverse logit is
17  *
18  * \f$\frac{d}{dx} \mbox{logit}^{-1}(x) = \mbox{logit}^{-1}(x) (1 -
19  * \mbox{logit}^{-1}(x))\f$.
20  *
21  * @tparam T Arithmetic or a type inheriting from `EigenBase`.
22  * @param a Argument variable.
23  * @return Inverse logit of argument.
24  */
25 template <typename T, require_stan_scalar_or_eigen_t<T>* = nullptr>
inv_logit(const var_value<T> & a)26 inline auto inv_logit(const var_value<T>& a) {
27   return make_callback_var(inv_logit(a.val()), [a](auto& vi) mutable {
28     as_array_or_scalar(a).adj() += as_array_or_scalar(vi.adj())
29                                    * as_array_or_scalar(vi.val())
30                                    * (1.0 - as_array_or_scalar(vi.val()));
31   });
32 }
33 
34 }  // namespace math
35 }  // namespace stan
36 #endif
37