1 #ifndef STAN_MATH_REV_FUN_LOG_MIX_HPP
2 #define STAN_MATH_REV_FUN_LOG_MIX_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/fun/value_of.hpp>
6 #include <stan/math/prim/fun/log_mix.hpp>
7 #include <stan/math/prim/fun/value_of.hpp>
8 #include <stan/math/prim/functor/operands_and_partials.hpp>
9 #include <cmath>
10 
11 namespace stan {
12 namespace math {
13 
14 /* Computes shared terms in log_mix partial derivative calculations
15  *
16  * @param[in] theta_val value of mixing proportion theta.
17  * @param[in] lambda1_val value of log density multiplied by theta.
18  * @param[in] lambda2_val value of log density multiplied by 1 - theta.
19  * @param[out] one_m_exp_lam2_m_lam1 shared term in deriv calculation.
20  * @param[out] one_m_t_prod_exp_lam2_m_lam1 shared term in deriv calculation.
21  * @param[out] one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1 shared term in deriv
22  * calculation.
23  */
log_mix_partial_helper(double theta_val,double lambda1_val,double lambda2_val,double & one_m_exp_lam2_m_lam1,double & one_m_t_prod_exp_lam2_m_lam1,double & one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1)24 inline void log_mix_partial_helper(
25     double theta_val, double lambda1_val, double lambda2_val,
26     double& one_m_exp_lam2_m_lam1, double& one_m_t_prod_exp_lam2_m_lam1,
27     double& one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1) {
28   using std::exp;
29   double lam2_m_lam1 = lambda2_val - lambda1_val;
30   double exp_lam2_m_lam1 = exp(lam2_m_lam1);
31   one_m_exp_lam2_m_lam1 = 1 - exp_lam2_m_lam1;
32   double one_m_t = 1 - theta_val;
33   one_m_t_prod_exp_lam2_m_lam1 = one_m_t * exp_lam2_m_lam1;
34   one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1
35       = 1 / (theta_val + one_m_t_prod_exp_lam2_m_lam1);
36 }
37 
38 /**
39  * Return the log mixture density with specified mixing proportion
40  * and log densities and its derivative at each.
41  *
42  * \f[
43  * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
44  * = \log \left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right).
45  * \f]
46  *
47  * \f[
48  * \frac{\partial}{\partial \theta}
49  * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
50  * = \dfrac{\exp(\lambda_1) - \exp(\lambda_2)}
51  * {\left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right)}
52  * \f]
53  *
54  * \f[
55  * \frac{\partial}{\partial \lambda_1}
56  * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
57  * = \dfrac{\theta \exp(\lambda_1)}
58  * {\left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right)}
59  * \f]
60  *
61  * \f[
62  * \frac{\partial}{\partial \lambda_2}
63  * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
64  * = \dfrac{\theta \exp(\lambda_2)}
65  * {\left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right)}
66  * \f]
67  *
68  * @tparam T_theta theta scalar type.
69  * @tparam T_lambda1 lambda1 scalar type.
70  * @tparam T_lambda2 lambda2 scalar type.
71  *
72  * @param[in] theta mixing proportion in [0, 1].
73  * @param[in] lambda1 first log density.
74  * @param[in] lambda2 second log density.
75  * @return log mixture of densities in specified proportion
76  */
77 template <typename T_theta, typename T_lambda1, typename T_lambda2,
78           require_any_var_t<T_theta, T_lambda1, T_lambda2>* = nullptr>
log_mix(const T_theta & theta,const T_lambda1 & lambda1,const T_lambda2 & lambda2)79 inline return_type_t<T_theta, T_lambda1, T_lambda2> log_mix(
80     const T_theta& theta, const T_lambda1& lambda1, const T_lambda2& lambda2) {
81   using std::log;
82 
83   operands_and_partials<T_theta, T_lambda1, T_lambda2> ops_partials(
84       theta, lambda1, lambda2);
85 
86   double theta_double = value_of(theta);
87   const double lambda1_double = value_of(lambda1);
88   const double lambda2_double = value_of(lambda2);
89 
90   double log_mix_function_value
91       = log_mix(theta_double, lambda1_double, lambda2_double);
92 
93   double one_m_exp_lam2_m_lam1(0.0);
94   double one_m_t_prod_exp_lam2_m_lam1(0.0);
95   double one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1(0.0);
96 
97   if (lambda1 > lambda2) {
98     log_mix_partial_helper(theta_double, lambda1_double, lambda2_double,
99                            one_m_exp_lam2_m_lam1, one_m_t_prod_exp_lam2_m_lam1,
100                            one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1);
101   } else {
102     log_mix_partial_helper(1.0 - theta_double, lambda2_double, lambda1_double,
103                            one_m_exp_lam2_m_lam1, one_m_t_prod_exp_lam2_m_lam1,
104                            one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1);
105     one_m_exp_lam2_m_lam1 = -one_m_exp_lam2_m_lam1;
106     theta_double = one_m_t_prod_exp_lam2_m_lam1;
107     one_m_t_prod_exp_lam2_m_lam1 = 1.0 - value_of(theta);
108   }
109 
110   if (!is_constant_all<T_theta>::value) {
111     ops_partials.edge1_.partials_[0]
112         = one_m_exp_lam2_m_lam1 * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
113   }
114   if (!is_constant_all<T_lambda1>::value) {
115     ops_partials.edge2_.partials_[0]
116         = theta_double * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
117   }
118   if (!is_constant_all<T_lambda2>::value) {
119     ops_partials.edge3_.partials_[0]
120         = one_m_t_prod_exp_lam2_m_lam1
121           * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
122   }
123 
124   return ops_partials.build(log_mix_function_value);
125 }
126 
127 }  // namespace math
128 }  // namespace stan
129 #endif
130