1 #ifndef STAN_MATH_PRIM_FUN_LOG_MIX_HPP
2 #define STAN_MATH_PRIM_FUN_LOG_MIX_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/as_array_or_scalar.hpp>
7 #include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
8 #include <stan/math/prim/fun/exp.hpp>
9 #include <stan/math/prim/fun/log.hpp>
10 #include <stan/math/prim/fun/log1m.hpp>
11 #include <stan/math/prim/fun/log_sum_exp.hpp>
12 #include <stan/math/prim/fun/size.hpp>
13 #include <stan/math/prim/fun/to_ref.hpp>
14 #include <stan/math/prim/fun/value_of.hpp>
15 #include <stan/math/prim/functor/operands_and_partials.hpp>
16 #include <vector>
17 #include <cmath>
18 
19 namespace stan {
20 namespace math {
21 
22 /**
23  * Return the log mixture density with specified mixing proportion
24  * and log densities.
25  *
26  * \f[
27  * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
28  * = \log \left( \theta \lambda_1 + (1 - \theta) \lambda_2 \right).
29  * \f]
30  *
31  * @tparam T_theta type of mixing proportion - must be an arithmetic type
32  * @tparam T_lambda1 type of first log density - must be an arithmetic type
33  * @tparam T_lambda2 type of second log density - must be an arithmetic type
34  * @param[in] theta mixing proportion in [0, 1].
35  * @param[in] lambda1 first log density.
36  * @param[in] lambda2 second log density.
37  * @return log mixture of densities in specified proportion
38  */
39 template <typename T_theta, typename T_lambda1, typename T_lambda2,
40           require_all_arithmetic_t<T_theta, T_lambda1, T_lambda2>* = nullptr>
log_mix(T_theta theta,T_lambda1 lambda1,T_lambda2 lambda2)41 inline double log_mix(T_theta theta, T_lambda1 lambda1, T_lambda2 lambda2) {
42   using std::log;
43   check_not_nan("log_mix", "lambda1", lambda1);
44   check_not_nan("log_mix", "lambda2", lambda2);
45   check_bounded("log_mix", "theta", theta, 0, 1);
46   return log_sum_exp(log(theta) + lambda1, log1m(theta) + lambda2);
47 }
48 
49 /**
50  * Return the log mixture density with specified mixing proportions
51  * and log densities.
52  *
53  * \f[
54  * \frac{\partial }{\partial p_x}
55  * \log\left(\exp^{\log\left(p_1\right)+d_1}+\cdot\cdot\cdot+
56  * \exp^{\log\left(p_n\right)+d_n}\right)
57  * =\frac{e^{d_x}}{e^{d_1}p_1+\cdot\cdot\cdot+e^{d_m}p_m}
58  * \f]
59  *
60  * \f[
61  * \frac{\partial }{\partial d_x}
62  * \log\left(\exp^{\log\left(p_1\right)+d_1}+\cdot\cdot\cdot+
63  * \exp^{\log\left(p_n\right)+d_n}\right)
64  * =\frac{e^{d_x}p_x}{e^{d_1}p_1+\cdot\cdot\cdot+e^{d_m}p_m}
65  * \f]
66  *
67  * @tparam T_theta Type of theta. This can be a scalar, std vector or row/column
68  * vector.
69  * @tparam T_lam Type of lambda. This can be a scalar, std vector or row/column
70  * vector.
71  * @param theta std/row/col vector of mixing proportions in [0, 1].
72  * @param lambda std/row/col vector of log densities.
73  * @return log mixture of densities in specified proportion
74  */
75 template <typename T_theta, typename T_lam,
76           require_any_vector_t<T_theta, T_lam>* = nullptr>
log_mix(const T_theta & theta,const T_lam & lambda)77 return_type_t<T_theta, T_lam> log_mix(const T_theta& theta,
78                                       const T_lam& lambda) {
79   static const char* function = "log_mix";
80   using T_partials_return = partials_return_t<T_theta, T_lam>;
81   using T_partials_vec =
82       typename Eigen::Matrix<T_partials_return, Eigen::Dynamic, 1>;
83   using T_theta_ref = ref_type_t<T_theta>;
84   using T_lam_ref = ref_type_t<T_lam>;
85 
86   const int N = stan::math::size(theta);
87 
88   check_consistent_sizes(function, "theta", theta, "lambda", lambda);
89   T_theta_ref theta_ref = theta;
90   T_lam_ref lambda_ref = lambda;
91   check_bounded(function, "theta", theta_ref, 0, 1);
92   check_finite(function, "lambda", lambda_ref);
93 
94   const auto& theta_dbl
95       = to_ref(value_of(as_column_vector_or_scalar(theta_ref)));
96   const auto& lam_dbl
97       = to_ref(value_of(as_column_vector_or_scalar(lambda_ref)));
98 
99   T_partials_return logp = log_sum_exp(log(theta_dbl) + lam_dbl);
100 
101   operands_and_partials<T_theta_ref, T_lam_ref> ops_partials(theta_ref,
102                                                              lambda_ref);
103   if (!is_constant_all<T_lam, T_theta>::value) {
104     T_partials_vec theta_deriv = (lam_dbl.array() - logp).exp();
105     if (!is_constant_all<T_lam>::value) {
106       ops_partials.edge2_.partials_ = theta_deriv.cwiseProduct(theta_dbl);
107     }
108     if (!is_constant_all<T_theta>::value) {
109       ops_partials.edge1_.partials_ = std::move(theta_deriv);
110     }
111   }
112   return ops_partials.build(logp);
113 }
114 
115 /**
116  * Return the log mixture density given specified mixing proportions
117  * and array of log density vectors.
118  *
119  * \f[
120  * \frac{\partial }{\partial p_x}\left[
121  * \log\left(\exp^{\log\left(p_1\right)+d_1}+\cdot\cdot\cdot+
122  * \exp^{\log\left(p_n\right)+d_n}\right)+
123  * \log\left(\exp^{\log\left(p_1\right)+f_1}+\cdot\cdot\cdot+
124  * \exp^{\log\left(p_n\right)+f_n}\right)\right]
125  * =\frac{e^{d_x}}{e^{d_1}p_1+\cdot\cdot\cdot+e^{d_m}p_m}+
126  * \frac{e^{f_x}}{e^{f_1}p_1+\cdot\cdot\cdot+e^{f_m}p_m}
127  * \f]
128  *
129  * \f[
130  * \frac{\partial }{\partial d_x}\left[
131  * \log\left(\exp^{\log\left(p_1\right)+d_1}+\cdot\cdot\cdot+
132  * \exp^{\log\left(p_n\right)+d_n}\right)
133  * +\log\left(\exp^{\log\left(p_1\right)+f_1}+\cdot\cdot\cdot+
134  * \exp^{\log\left(p_n\right)+f_n}\right)\right]
135  * =\frac{e^{d_x}p_x}{e^{d_1}p_1+\cdot\cdot\cdot+e^{d_m}p_m}
136  * \f]
137  *
138  * @tparam T_theta Type of theta. This can be a scalar, std vector or row/column
139  * vector
140  * @tparam T_lam Type of vector in std vector lambda. This can be std vector or
141  * row/column vector.
142  * @param theta std/row/col vector of mixing proportions in [0, 1].
143  * @param lambda std vector containing std/row/col vectors of log densities.
144  * @return log mixture of densities in specified proportion
145  */
146 template <typename T_theta, typename T_lam, require_vector_t<T_lam>* = nullptr>
log_mix(const T_theta & theta,const std::vector<T_lam> & lambda)147 return_type_t<T_theta, std::vector<T_lam>> log_mix(
148     const T_theta& theta, const std::vector<T_lam>& lambda) {
149   static const char* function = "log_mix";
150   using T_partials_return = partials_return_t<T_theta, std::vector<T_lam>>;
151   using T_partials_vec =
152       typename Eigen::Matrix<T_partials_return, Eigen::Dynamic, 1>;
153   using T_partials_mat =
154       typename Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>;
155   using T_lamvec_type = typename std::vector<T_lam>;
156   using T_theta_ref = ref_type_t<T_theta>;
157 
158   const int N = stan::math::size(lambda);
159   const int M = theta.size();
160 
161   T_theta_ref theta_ref = theta;
162   check_bounded(function, "theta", theta_ref, 0, 1);
163   for (int n = 0; n < N; ++n) {
164     check_not_nan(function, "lambda", lambda[n]);
165     check_finite(function, "lambda", lambda[n]);
166     check_consistent_sizes(function, "theta", theta, "lambda", lambda[n]);
167   }
168 
169   const auto& theta_dbl
170       = to_ref(value_of(as_column_vector_or_scalar(theta_ref)));
171 
172   T_partials_mat lam_dbl(M, N);
173   for (int n = 0; n < N; ++n) {
174     lam_dbl.col(n) = value_of(as_column_vector_or_scalar(lambda[n]));
175   }
176 
177   T_partials_mat logp_tmp = lam_dbl.colwise() + log(theta_dbl);
178   T_partials_vec logp(N);
179   for (int n = 0; n < N; ++n) {
180     logp[n] = log_sum_exp(logp_tmp.col(n));
181   }
182 
183   operands_and_partials<T_theta_ref, T_lamvec_type> ops_partials(theta_ref,
184                                                                  lambda);
185   if (!is_constant_all<T_theta, T_lam>::value) {
186     T_partials_mat derivs = exp(lam_dbl.rowwise() - logp.transpose());
187     if (!is_constant_all<T_theta>::value) {
188       ops_partials.edge1_.partials_ = derivs.rowwise().sum();
189     }
190     if (!is_constant_all<T_lam>::value) {
191       for (int n = 0; n < N; ++n) {
192         as_column_vector_or_scalar(ops_partials.edge2_.partials_vec_[n])
193             = derivs.col(n).cwiseProduct(theta_dbl);
194       }
195     }
196   }
197   return ops_partials.build(logp.sum());
198 }
199 
200 }  // namespace math
201 }  // namespace stan
202 #endif
203