1 #ifndef STAN_MATH_PRIM_FUN_LOG_SUM_EXP_HPP
2 #define STAN_MATH_PRIM_FUN_LOG_SUM_EXP_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/fun/constants.hpp>
6 #include <stan/math/prim/fun/Eigen.hpp>
7 #include <stan/math/prim/fun/log1p_exp.hpp>
8 #include <cmath>
9 #include <vector>
10
11 namespace stan {
12 namespace math {
13
14 /**
15 * Calculates the log sum of exponentials without overflow.
16 *
17 * \f$\log (\exp(a) + \exp(b)) = m + \log(\exp(a-m) + \exp(b-m))\f$,
18 *
19 * where \f$m = max(a, b)\f$.
20 *
21 \f[
22 \mbox{log\_sum\_exp}(x, y) =
23 \begin{cases}
24 \ln(\exp(x)+\exp(y)) & \mbox{if } -\infty\leq x, y \leq \infty \\[6pt]
25 \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
26 \end{cases}
27 \f]
28
29 \f[
30 \frac{\partial\, \mbox{log\_sum\_exp}(x, y)}{\partial x} =
31 \begin{cases}
32 \frac{\exp(x)}{\exp(x)+\exp(y)} & \mbox{if } -\infty\leq x, y \leq \infty
33 \\[6pt] \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
34 \end{cases}
35 \f]
36
37 \f[
38 \frac{\partial\, \mbox{log\_sum\_exp}(x, y)}{\partial y} =
39 \begin{cases}
40 \frac{\exp(y)}{\exp(x)+\exp(y)} & \mbox{if } -\infty\leq x, y \leq \infty
41 \\[6pt] \textrm{NaN} & \mbox{if } x = \textrm{NaN or } y = \textrm{NaN}
42 \end{cases}
43 \f]
44 *
45 * @tparam T1 type of the first variable
46 * @tparam T2 type of the second variable
47 * @param a the first variable
48 * @param b the second variable
49 */
50 template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr>
log_sum_exp(const T2 & a,const T1 & b)51 inline return_type_t<T1, T2> log_sum_exp(const T2& a, const T1& b) {
52 if (a == NEGATIVE_INFTY) {
53 return b;
54 }
55 if (a == INFTY && b == INFTY) {
56 return INFTY;
57 }
58 if (a > b) {
59 return a + log1p_exp(b - a);
60 }
61 return b + log1p_exp(a - b);
62 }
63
64 /**
65 * Return the log of the sum of the exponentiated values of the specified
66 * matrix of values. The matrix may be a full matrix, a vector,
67 * a row vector, or a container of these.
68 *
69 * The function is defined as follows to prevent overflow in exponential
70 * calculations.
71 *
72 * \f$\log \sum_{n=1}^N \exp(x_n) = \max(x) + \log \sum_{n=1}^N \exp(x_n -
73 * \max(x))\f$.
74 *
75 * @tparam T type of input vector or matrix
76 * @param[in] x matrix of specified values
77 * @return The log of the sum of the exponentiated vector values.
78 */
79 template <typename T, require_container_st<std::is_arithmetic, T>* = nullptr>
log_sum_exp(const T & x)80 inline auto log_sum_exp(const T& x) {
81 return apply_vector_unary<T>::reduce(x, [&](const auto& v) {
82 if (v.size() == 0) {
83 return NEGATIVE_INFTY;
84 }
85 const auto& v_ref = to_ref(v);
86 const double max = v_ref.maxCoeff();
87 if (!std::isfinite(max)) {
88 return max;
89 }
90 return max + std::log((v_ref.array() - max).exp().sum());
91 });
92 }
93
94 } // namespace math
95 } // namespace stan
96
97 #endif
98