1 #ifndef STAN_MATH_REV_FUN_LOG_SOFTMAX_HPP
2 #define STAN_MATH_REV_FUN_LOG_SOFTMAX_HPP
3
4 #include <stan/math/rev/core.hpp>
5 #include <stan/math/rev/core/typedefs.hpp>
6 #include <stan/math/prim/meta.hpp>
7 #include <stan/math/prim/err.hpp>
8 #include <stan/math/prim/fun/Eigen.hpp>
9 #include <stan/math/prim/fun/log_softmax.hpp>
10 #include <stan/math/prim/fun/softmax.hpp>
11 #include <stan/math/prim/fun/to_ref.hpp>
12 #include <stan/math/prim/fun/typedefs.hpp>
13 #include <cmath>
14 #include <vector>
15
16 namespace stan {
17 namespace math {
18
19 namespace internal {
20
21 class log_softmax_elt_vari : public vari {
22 private:
23 vari** alpha_;
24 const double* softmax_alpha_;
25 const int size_; // array sizes
26 const int idx_; // in in softmax output
27
28 public:
log_softmax_elt_vari(double val,vari ** alpha,const double * softmax_alpha,int size,int idx)29 log_softmax_elt_vari(double val, vari** alpha, const double* softmax_alpha,
30 int size, int idx)
31 : vari(val),
32 alpha_(alpha),
33 softmax_alpha_(softmax_alpha),
34 size_(size),
35 idx_(idx) {}
chain()36 void chain() {
37 for (int m = 0; m < size_; ++m) {
38 if (m == idx_) {
39 alpha_[m]->adj_ += adj_ * (1 - softmax_alpha_[m]);
40 } else {
41 alpha_[m]->adj_ -= adj_ * softmax_alpha_[m];
42 }
43 }
44 }
45 };
46 } // namespace internal
47
48 /**
49 * Return the log softmax of the specified vector
50 *
51 * @tparam T type of input
52 * @param x input
53 * @return softmax of the input
54 * @throw std::domain_error if the input size is 0
55 */
56 template <typename T, require_eigen_st<is_var, T>* = nullptr>
log_softmax(const T & x)57 auto log_softmax(const T& x) {
58 const int a_size = x.size();
59
60 check_nonzero_size("log_softmax", "x", x);
61
62 const auto& x_ref = to_ref(x);
63
64 vari** x_vi_array
65 = ChainableStack::instance_->memalloc_.alloc_array<vari*>(a_size);
66 Eigen::Map<vector_vi>(x_vi_array, a_size) = x_ref.vi();
67
68 vector_d x_d = x_ref.val();
69
70 // fold logic of math::softmax() and math::log_softmax()
71 // to save computations
72
73 vector_d diff = (x_d.array() - x_d.maxCoeff());
74 vector_d softmax_x_d = diff.array().exp();
75 double sum = softmax_x_d.sum();
76 vector_d log_softmax_x_d = diff.array() - std::log(sum);
77
78 // end fold
79 double* softmax_x_d_array
80 = ChainableStack::instance_->memalloc_.alloc_array<double>(a_size);
81 Eigen::Map<vector_d>(softmax_x_d_array, a_size) = softmax_x_d.array() / sum;
82
83 plain_type_t<T> log_softmax_x(a_size);
84 for (int k = 0; k < a_size; ++k) {
85 log_softmax_x(k) = var(new internal::log_softmax_elt_vari(
86 log_softmax_x_d[k], x_vi_array, softmax_x_d_array, a_size, k));
87 }
88 return log_softmax_x;
89 }
90
91 /**
92 * Return the log softmax of the specified vector
93 *
94 * @tparam T type of input
95 * @param x input
96 * @return softmax of the input
97 * @throw std::domain_error if the input size is 0
98 */
99 template <typename T, require_var_matrix_t<T>* = nullptr>
log_softmax(const T & x)100 inline auto log_softmax(const T& x) {
101 check_nonzero_size("log_softmax", "x", x);
102
103 const auto& theta = (x.val().array() - x.val().maxCoeff()).eval();
104
105 return make_callback_var(
106 (theta.array() - log(theta.exp().sum())).matrix(),
107 [x](const auto& res) mutable {
108 x.adj().noalias()
109 += res.adj() - (res.adj().sum() * res.val().array().exp()).matrix();
110 });
111 }
112
113 /**
114 * Return the log softmax of the specified `std::vector` or
115 * `std::vector` of containers.
116 *
117 * @tparam T type of input
118 * @param x input
119 * @return softmax of the input
120 * @throw std::domain_error if the input size is 0
121 */
122 template <typename T, require_std_vector_st<is_var, T>* = nullptr>
log_softmax(const T & x)123 inline auto log_softmax(const T& x) {
124 return apply_vector_unary<T>::apply(
125 x, [](const auto& alpha) { return log_softmax(alpha); });
126 }
127
128 } // namespace math
129 } // namespace stan
130 #endif
131