1 #ifndef STAN_MATH_REV_FUN_LOG_MIX_HPP
2 #define STAN_MATH_REV_FUN_LOG_MIX_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/fun/value_of.hpp>
6 #include <stan/math/prim/fun/log_mix.hpp>
7 #include <stan/math/prim/fun/value_of.hpp>
8 #include <stan/math/prim/functor/operands_and_partials.hpp>
9 #include <cmath>
10
11 namespace stan {
12 namespace math {
13
14 /* Computes shared terms in log_mix partial derivative calculations
15 *
16 * @param[in] theta_val value of mixing proportion theta.
17 * @param[in] lambda1_val value of log density multiplied by theta.
18 * @param[in] lambda2_val value of log density multiplied by 1 - theta.
19 * @param[out] one_m_exp_lam2_m_lam1 shared term in deriv calculation.
20 * @param[out] one_m_t_prod_exp_lam2_m_lam1 shared term in deriv calculation.
21 * @param[out] one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1 shared term in deriv
22 * calculation.
23 */
log_mix_partial_helper(double theta_val,double lambda1_val,double lambda2_val,double & one_m_exp_lam2_m_lam1,double & one_m_t_prod_exp_lam2_m_lam1,double & one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1)24 inline void log_mix_partial_helper(
25 double theta_val, double lambda1_val, double lambda2_val,
26 double& one_m_exp_lam2_m_lam1, double& one_m_t_prod_exp_lam2_m_lam1,
27 double& one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1) {
28 using std::exp;
29 double lam2_m_lam1 = lambda2_val - lambda1_val;
30 double exp_lam2_m_lam1 = exp(lam2_m_lam1);
31 one_m_exp_lam2_m_lam1 = 1 - exp_lam2_m_lam1;
32 double one_m_t = 1 - theta_val;
33 one_m_t_prod_exp_lam2_m_lam1 = one_m_t * exp_lam2_m_lam1;
34 one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1
35 = 1 / (theta_val + one_m_t_prod_exp_lam2_m_lam1);
36 }
37
38 /**
39 * Return the log mixture density with specified mixing proportion
40 * and log densities and its derivative at each.
41 *
42 * \f[
43 * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
44 * = \log \left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right).
45 * \f]
46 *
47 * \f[
48 * \frac{\partial}{\partial \theta}
49 * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
50 * = \dfrac{\exp(\lambda_1) - \exp(\lambda_2)}
51 * {\left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right)}
52 * \f]
53 *
54 * \f[
55 * \frac{\partial}{\partial \lambda_1}
56 * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
57 * = \dfrac{\theta \exp(\lambda_1)}
58 * {\left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right)}
59 * \f]
60 *
61 * \f[
62 * \frac{\partial}{\partial \lambda_2}
63 * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
64 * = \dfrac{\theta \exp(\lambda_2)}
65 * {\left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right)}
66 * \f]
67 *
68 * @tparam T_theta theta scalar type.
69 * @tparam T_lambda1 lambda1 scalar type.
70 * @tparam T_lambda2 lambda2 scalar type.
71 *
72 * @param[in] theta mixing proportion in [0, 1].
73 * @param[in] lambda1 first log density.
74 * @param[in] lambda2 second log density.
75 * @return log mixture of densities in specified proportion
76 */
77 template <typename T_theta, typename T_lambda1, typename T_lambda2,
78 require_any_var_t<T_theta, T_lambda1, T_lambda2>* = nullptr>
log_mix(const T_theta & theta,const T_lambda1 & lambda1,const T_lambda2 & lambda2)79 inline return_type_t<T_theta, T_lambda1, T_lambda2> log_mix(
80 const T_theta& theta, const T_lambda1& lambda1, const T_lambda2& lambda2) {
81 using std::log;
82
83 operands_and_partials<T_theta, T_lambda1, T_lambda2> ops_partials(
84 theta, lambda1, lambda2);
85
86 double theta_double = value_of(theta);
87 const double lambda1_double = value_of(lambda1);
88 const double lambda2_double = value_of(lambda2);
89
90 double log_mix_function_value
91 = log_mix(theta_double, lambda1_double, lambda2_double);
92
93 double one_m_exp_lam2_m_lam1(0.0);
94 double one_m_t_prod_exp_lam2_m_lam1(0.0);
95 double one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1(0.0);
96
97 if (lambda1 > lambda2) {
98 log_mix_partial_helper(theta_double, lambda1_double, lambda2_double,
99 one_m_exp_lam2_m_lam1, one_m_t_prod_exp_lam2_m_lam1,
100 one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1);
101 } else {
102 log_mix_partial_helper(1.0 - theta_double, lambda2_double, lambda1_double,
103 one_m_exp_lam2_m_lam1, one_m_t_prod_exp_lam2_m_lam1,
104 one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1);
105 one_m_exp_lam2_m_lam1 = -one_m_exp_lam2_m_lam1;
106 theta_double = one_m_t_prod_exp_lam2_m_lam1;
107 one_m_t_prod_exp_lam2_m_lam1 = 1.0 - value_of(theta);
108 }
109
110 if (!is_constant_all<T_theta>::value) {
111 ops_partials.edge1_.partials_[0]
112 = one_m_exp_lam2_m_lam1 * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
113 }
114 if (!is_constant_all<T_lambda1>::value) {
115 ops_partials.edge2_.partials_[0]
116 = theta_double * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
117 }
118 if (!is_constant_all<T_lambda2>::value) {
119 ops_partials.edge3_.partials_[0]
120 = one_m_t_prod_exp_lam2_m_lam1
121 * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
122 }
123
124 return ops_partials.build(log_mix_function_value);
125 }
126
127 } // namespace math
128 } // namespace stan
129 #endif
130