1 #ifndef STAN_MATH_REV_MAT_FUN_LOG_SUM_EXP_HPP
2 #define STAN_MATH_REV_MAT_FUN_LOG_SUM_EXP_HPP
3 
4 #include <stan/math/rev/core.hpp>
5 #include <stan/math/rev/scal/fun/calculate_chain.hpp>
6 #include <stan/math/prim/scal/fun/log_sum_exp.hpp>
7 #include <stan/math/prim/mat/fun/Eigen.hpp>
8 #include <limits>
9 
10 namespace stan {
11 namespace math {
12 
13 namespace internal {
14 
15 // these function and the following class just translate
16 // log_sum_exp for std::vector for Eigen::Matrix
17 
18 template <int R, int C>
log_sum_exp_as_double(const Eigen::Matrix<var,R,C> & x)19 inline double log_sum_exp_as_double(const Eigen::Matrix<var, R, C>& x) {
20   using std::exp;
21   using std::log;
22   using std::numeric_limits;
23   double max = -numeric_limits<double>::infinity();
24   for (int i = 0; i < x.size(); ++i)
25     if (x(i) > max)
26       max = x(i).val();
27   double sum = 0.0;
28   for (int i = 0; i < x.size(); ++i)
29     if (x(i) != -numeric_limits<double>::infinity())
30       sum += exp(x(i).val() - max);
31   return max + log(sum);
32 }
33 
34 class log_sum_exp_matrix_vari : public op_matrix_vari {
35  public:
36   template <int R, int C>
log_sum_exp_matrix_vari(const Eigen::Matrix<var,R,C> & x)37   explicit log_sum_exp_matrix_vari(const Eigen::Matrix<var, R, C>& x)
38       : op_matrix_vari(log_sum_exp_as_double(x), x) {}
chain()39   void chain() {
40     for (size_t i = 0; i < size_; ++i) {
41       vis_[i]->adj_ += adj_ * calculate_chain(vis_[i]->val_, val_);
42     }
43   }
44 };
45 }  // namespace internal
46 
47 /**
48  * Returns the log sum of exponentials.
49  *
50  * @param x matrix
51  */
52 template <int R, int C>
log_sum_exp(const Eigen::Matrix<var,R,C> & x)53 inline var log_sum_exp(const Eigen::Matrix<var, R, C>& x) {
54   return var(new internal::log_sum_exp_matrix_vari(x));
55 }
56 
57 }  // namespace math
58 }  // namespace stan
59 #endif
60