#ifndef STAN_MATH_REV_FUN_LOG_SOFTMAX_HPP
#define STAN_MATH_REV_FUN_LOG_SOFTMAX_HPP

#include <stan/math/rev/core.hpp>
#include <stan/math/rev/core/typedefs.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log_softmax.hpp>
#include <stan/math/prim/fun/softmax.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/typedefs.hpp>
#include <cmath>
#include <vector>

namespace stan {
namespace math {

namespace internal {

class log_softmax_elt_vari : public vari {
 private:
  vari** alpha_;
  const double* softmax_alpha_;
  const int size_;  // array sizes
  const int idx_;   // index in the softmax output

 public:
  log_softmax_elt_vari(double val, vari** alpha, const double* softmax_alpha,
                       int size, int idx)
      : vari(val),
        alpha_(alpha),
        softmax_alpha_(softmax_alpha),
        size_(size),
        idx_(idx) {}
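  // Note on the reverse pass below: for y = log_softmax(alpha),
  // d y[idx_] / d alpha[m] = 1[m == idx_] - softmax(alpha)[m], so chain()
  // adds adj_ * (1[m == idx_] - softmax_alpha_[m]) to each input adjoint.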
  void chain() {
    for (int m = 0; m < size_; ++m) {
      if (m == idx_) {
        alpha_[m]->adj_ += adj_ * (1 - softmax_alpha_[m]);
      } else {
        alpha_[m]->adj_ -= adj_ * softmax_alpha_[m];
      }
    }
  }
};
}  // namespace internal

/**
 * Return the log softmax of the specified vector.
 *
 * @tparam T type of input
 * @param x input vector
 * @return log softmax of the input
 * @throw std::domain_error if the input size is 0
 */
template <typename T, require_eigen_st<is_var, T>* = nullptr>
auto log_softmax(const T& x) {
  const int a_size = x.size();

  check_nonzero_size("log_softmax", "x", x);

  const auto& x_ref = to_ref(x);

  // copy the input varis onto the autodiff arena for use in the reverse pass
  vari** x_vi_array
      = ChainableStack::instance_->memalloc_.alloc_array<vari*>(a_size);
  Eigen::Map<vector_vi>(x_vi_array, a_size) = x_ref.vi();

  vector_d x_d = x_ref.val();

  // fold the logic of math::softmax() and math::log_softmax()
  // together to save computations

  vector_d diff = (x_d.array() - x_d.maxCoeff());
  vector_d softmax_x_d = diff.array().exp();  // unnormalized until divided by sum
  double sum = softmax_x_d.sum();
  vector_d log_softmax_x_d = diff.array() - std::log(sum);

  // end fold

  // store the normalized softmax values on the arena for the reverse pass
  double* softmax_x_d_array
      = ChainableStack::instance_->memalloc_.alloc_array<double>(a_size);
  Eigen::Map<vector_d>(softmax_x_d_array, a_size) = softmax_x_d.array() / sum;

  plain_type_t<T> log_softmax_x(a_size);
  for (int k = 0; k < a_size; ++k) {
    log_softmax_x(k) = var(new internal::log_softmax_elt_vari(
        log_softmax_x_d[k], x_vi_array, softmax_x_d_array, a_size, k));
  }
  return log_softmax_x;
}
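
// Example (illustrative sketch, not part of the library API): reverse-mode
// log_softmax of an Eigen vector of var, propagating adjoints back from one
// output element. Assumes the usual <stan/math.hpp> user-facing include.
//
//   Eigen::Matrix<stan::math::var, Eigen::Dynamic, 1> a(3);
//   a << 1, 2, 3;
//   auto y = stan::math::log_softmax(a);
//   y(2).grad();  // now a(j).adj() == 1[j == 2] - softmax(value_of(a))[j]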

/**
 * Return the log softmax of the specified vector.
 *
 * @tparam T type of input
 * @param x input vector
 * @return log softmax of the input
 * @throw std::domain_error if the input size is 0
 */
template <typename T, require_var_matrix_t<T>* = nullptr>
inline auto log_softmax(const T& x) {
  check_nonzero_size("log_softmax", "x", x);

  // shift by the max for numerical stability before exponentiating
  const auto& theta = (x.val().array() - x.val().maxCoeff()).eval();

  return make_callback_var(
      (theta.array() - log(theta.exp().sum())).matrix(),
      [x](const auto& res) mutable {
        // adjoint of log_softmax: x.adj += res.adj - sum(res.adj) * softmax(x),
        // where softmax(x) is recovered as exp(res.val())
        x.adj().noalias()
            += res.adj() - (res.adj().sum() * res.val().array().exp()).matrix();
      });
}
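
// Example (illustrative sketch, not part of the library API): the same
// computation with a `var_value` matrix ("varmat") input, seeding one output
// adjoint by hand and then sweeping the tape with stan::math::grad().
//
//   Eigen::VectorXd a_val(3);
//   a_val << 1, 2, 3;
//   stan::math::var_value<Eigen::VectorXd> a(a_val);
//   auto y = stan::math::log_softmax(a);
//   y.adj()(2) = 1.0;    // seed d/dy[2]
//   stan::math::grad();  // a.adj()[j] == 1[j == 2] - softmax(a_val)[j]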

/**
 * Return the log softmax of the specified `std::vector` or
 * `std::vector` of containers.
 *
 * @tparam T type of input
 * @param x input
 * @return log softmax of the input
 * @throw std::domain_error if the input size is 0
 */
template <typename T, require_std_vector_st<is_var, T>* = nullptr>
inline auto log_softmax(const T& x) {
  return apply_vector_unary<T>::apply(
      x, [](const auto& alpha) { return log_softmax(alpha); });
}
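
// Example (illustrative sketch, not part of the library API): the std::vector
// overload applies log_softmax to each inner container independently.
//
//   using vec_v = Eigen::Matrix<stan::math::var, Eigen::Dynamic, 1>;
//   std::vector<vec_v> vs(2, vec_v::Ones(3));
//   auto ys = stan::math::log_softmax(vs);  // std::vector of per-container results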

}  // namespace math
}  // namespace stan
#endif