1 #ifndef STAN_MATH_PRIM_FUN_LOG_SUM_EXP_HPP
2 #define STAN_MATH_PRIM_FUN_LOG_SUM_EXP_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/fun/constants.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/log1p_exp.hpp>
8 #include <cmath>
9 #include <vector>
10 
11 namespace stan {
12 namespace math {
13 
14 /**
15  * Calculates the log sum of exponentials without overflow.
16  *
17  * \f$\log (\exp(a) + \exp(b)) = m + \log(\exp(a-m) + \exp(b-m))\f$,
18  *
19  * where \f$m = max(a, b)\f$.
20  *
21    \f[
22    \mbox{log\_sum\_exp}(x, y) =
23    \begin{cases}
24      \ln(\exp(x)+\exp(y)) & \mbox{if } -\infty\leq x, y \leq \infty \\[6pt]
25      \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
26    \end{cases}
27    \f]
28 
29    \f[
30    \frac{\partial\, \mbox{log\_sum\_exp}(x, y)}{\partial x} =
31    \begin{cases}
32      \frac{\exp(x)}{\exp(x)+\exp(y)} & \mbox{if } -\infty\leq x, y \leq \infty
33  \\[6pt] \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
34    \end{cases}
35    \f]
36 
37    \f[
38    \frac{\partial\, \mbox{log\_sum\_exp}(x, y)}{\partial y} =
39    \begin{cases}
40      \frac{\exp(y)}{\exp(x)+\exp(y)} & \mbox{if } -\infty\leq x, y \leq \infty
41  \\[6pt] \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
42    \end{cases}
43    \f]
44  *
45  * @tparam T1 type of the first variable
46  * @tparam T2 type of the second variable
47  * @param a the first variable
48  * @param b the second variable
49  */
50 template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr>
log_sum_exp(const T2 & a,const T1 & b)51 inline return_type_t<T1, T2> log_sum_exp(const T2& a, const T1& b) {
52   if (a == NEGATIVE_INFTY) {
53     return b;
54   }
55   if (a == INFTY && b == INFTY) {
56     return INFTY;
57   }
58   if (a > b) {
59     return a + log1p_exp(b - a);
60   }
61   return b + log1p_exp(a - b);
62 }
63 
64 /**
65  * Return the log of the sum of the exponentiated values of the specified
66  * matrix of values.  The matrix may be a full matrix, a vector,
67  * a row vector, or a container of these.
68  *
69  * The function is defined as follows to prevent overflow in exponential
70  * calculations.
71  *
72  * \f$\log \sum_{n=1}^N \exp(x_n) = \max(x) + \log \sum_{n=1}^N \exp(x_n -
73  * \max(x))\f$.
74  *
75  * @tparam T type of input vector or matrix
76  * @param[in] x matrix of specified values
77  * @return The log of the sum of the exponentiated vector values.
78  */
79 template <typename T, require_container_st<std::is_arithmetic, T>* = nullptr>
log_sum_exp(const T & x)80 inline auto log_sum_exp(const T& x) {
81   return apply_vector_unary<T>::reduce(x, [&](const auto& v) {
82     if (v.size() == 0) {
83       return NEGATIVE_INFTY;
84     }
85     const auto& v_ref = to_ref(v);
86     const double max = v_ref.maxCoeff();
87     if (!std::isfinite(max)) {
88       return max;
89     }
90     return max + std::log((v_ref.array() - max).exp().sum());
91   });
92 }
93 
94 }  // namespace math
95 }  // namespace stan
96 
97 #endif
98