1 #ifndef STAN_MATH_FWD_SCAL_FUN_LOG_MIX_HPP
2 #define STAN_MATH_FWD_SCAL_FUN_LOG_MIX_HPP
3 
4 #include <stan/math/fwd/core.hpp>
5 #include <stan/math/prim/scal/fun/value_of.hpp>
6 #include <stan/math/prim/scal/fun/log_mix.hpp>
7 #include <boost/math/tools/promotion.hpp>
8 #include <cmath>
9 #include <type_traits>
10 
11 namespace stan {
12 namespace math {
13 
14 /* Returns an array of size N with partials of log_mix wrt to its
15  * parameters instantiated as fvar<T>
16  *
17  * @tparam T_theta theta scalar type
18  * @tparam T_lambda1 lambda_1 scalar type
19  * @tparam T_lambda2 lambda_2 scalar type
20  *
21  * @param[in] N output array size
22  * @param[in] theta_d mixing proportion theta
23  * @param[in] lambda1_d log_density with mixing proportion theta
24  * @param[in] lambda2_d log_density with mixing proportion 1.0 - theta
25  * @param[out] partials_array array of partials derivatives
26  */
27 template <typename T_theta, typename T_lambda1, typename T_lambda2, int N>
log_mix_partial_helper(const T_theta & theta,const T_lambda1 & lambda1,const T_lambda2 & lambda2,typename boost::math::tools::promote_args<T_theta,T_lambda1,T_lambda2>::type (& partials_array)[N])28 inline void log_mix_partial_helper(
29     const T_theta& theta, const T_lambda1& lambda1, const T_lambda2& lambda2,
30     typename boost::math::tools::promote_args<
31         T_theta, T_lambda1, T_lambda2>::type (&partials_array)[N]) {
32   using boost::math::tools::promote_args;
33   using std::exp;
34   typedef typename promote_args<T_theta, T_lambda1, T_lambda2>::type
35       partial_return_type;
36 
37   typename promote_args<T_lambda1, T_lambda2>::type lam2_m_lam1
38       = lambda2 - lambda1;
39   typename promote_args<T_lambda1, T_lambda2>::type exp_lam2_m_lam1
40       = exp(lam2_m_lam1);
41   typename promote_args<T_lambda1, T_lambda2>::type one_m_exp_lam2_m_lam1
42       = 1.0 - exp_lam2_m_lam1;
43   typename promote_args<double, T_theta>::type one_m_t = 1.0 - theta;
44   partial_return_type one_m_t_prod_exp_lam2_m_lam1 = one_m_t * exp_lam2_m_lam1;
45   partial_return_type t_plus_one_m_t_prod_exp_lam2_m_lam1
46       = theta + one_m_t_prod_exp_lam2_m_lam1;
47   partial_return_type one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1
48       = 1.0 / t_plus_one_m_t_prod_exp_lam2_m_lam1;
49 
50   unsigned int offset = 0;
51   if (std::is_same<T_theta, partial_return_type>::value) {
52     partials_array[offset]
53         = one_m_exp_lam2_m_lam1 * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
54     ++offset;
55   }
56   if (std::is_same<T_lambda1, partial_return_type>::value) {
57     partials_array[offset] = theta * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
58     ++offset;
59   }
60   if (std::is_same<T_lambda2, partial_return_type>::value) {
61     partials_array[offset] = one_m_t_prod_exp_lam2_m_lam1
62                              * one_d_t_plus_one_m_t_prod_exp_lam2_m_lam1;
63   }
64 }
65 
66 /**
67  * Return the log mixture density with specified mixing proportion
68  * and log densities and its derivative at each.
69  *
70  * \f[
71  * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
72  * = \log \left( \theta \exp(\lambda_1)
73    + (1 - \theta) \exp(\lambda_2) \right).
74  * \f]
75  *
76  * \f[
77  * \frac{\partial}{\partial \theta}
78  * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
79  * = \dfrac{\exp(\lambda_1) - \exp(\lambda_2)}
80  * {\left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right)}
81  * \f]
82  *
83  * \f[
84  * \frac{\partial}{\partial \lambda_1}
85  * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
86  * = \dfrac{\theta \exp(\lambda_1)}
87  * {\left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right)}
88  * \f]
89  *
90  * \f[
91  * \frac{\partial}{\partial \lambda_2}
92  * \mbox{log\_mix}(\theta, \lambda_1, \lambda_2)
93  * = \dfrac{\theta \exp(\lambda_2)}
94  * {\left( \theta \exp(\lambda_1) + (1 - \theta) \exp(\lambda_2) \right)}
95  * \f]
96  *
97  * @tparam T scalar type.
98  *
99  * @param[in] theta mixing proportion in [0, 1].
100  * @param[in] lambda1 first log density.
101  * @param[in] lambda2 second log density.
102  *
103  * @return log mixture of densities in specified proportion
104  */
105 template <typename T>
log_mix(const fvar<T> & theta,const fvar<T> & lambda1,const fvar<T> & lambda2)106 inline fvar<T> log_mix(const fvar<T>& theta, const fvar<T>& lambda1,
107                        const fvar<T>& lambda2) {
108   if (lambda1.val_ > lambda2.val_) {
109     fvar<T> partial_deriv_array[3];
110     log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
111     return fvar<T>(log_mix(theta.val_, lambda1.val_, lambda2.val_),
112                    theta.d_ * value_of(partial_deriv_array[0])
113                        + lambda1.d_ * value_of(partial_deriv_array[1])
114                        + lambda2.d_ * value_of(partial_deriv_array[2]));
115   } else {
116     fvar<T> partial_deriv_array[3];
117     log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
118     return fvar<T>(log_mix(theta.val_, lambda1.val_, lambda2.val_),
119                    -theta.d_ * value_of(partial_deriv_array[0])
120                        + lambda1.d_ * value_of(partial_deriv_array[2])
121                        + lambda2.d_ * value_of(partial_deriv_array[1]));
122   }
123 }
124 
125 template <typename T>
log_mix(const fvar<T> & theta,const fvar<T> & lambda1,double lambda2)126 inline fvar<T> log_mix(const fvar<T>& theta, const fvar<T>& lambda1,
127                        double lambda2) {
128   if (lambda1.val_ > lambda2) {
129     fvar<T> partial_deriv_array[2];
130     log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
131     return fvar<T>(log_mix(theta.val_, lambda1.val_, lambda2),
132                    theta.d_ * value_of(partial_deriv_array[0])
133                        + lambda1.d_ * value_of(partial_deriv_array[1]));
134   } else {
135     fvar<T> partial_deriv_array[2];
136     log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
137     return fvar<T>(log_mix(theta.val_, lambda1.val_, lambda2),
138                    -theta.d_ * value_of(partial_deriv_array[0])
139                        + lambda1.d_ * value_of(partial_deriv_array[1]));
140   }
141 }
142 
143 template <typename T>
log_mix(const fvar<T> & theta,double lambda1,const fvar<T> & lambda2)144 inline fvar<T> log_mix(const fvar<T>& theta, double lambda1,
145                        const fvar<T>& lambda2) {
146   if (lambda1 > lambda2.val_) {
147     fvar<T> partial_deriv_array[2];
148     log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
149     return fvar<T>(log_mix(theta.val_, lambda1, lambda2.val_),
150                    theta.d_ * value_of(partial_deriv_array[0])
151                        + lambda2.d_ * value_of(partial_deriv_array[1]));
152   } else {
153     fvar<T> partial_deriv_array[2];
154     log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
155     return fvar<T>(log_mix(theta.val_, lambda1, lambda2.val_),
156                    -theta.d_ * value_of(partial_deriv_array[0])
157                        + lambda2.d_ * value_of(partial_deriv_array[1]));
158   }
159 }
160 
161 template <typename T>
log_mix(double theta,const fvar<T> & lambda1,const fvar<T> & lambda2)162 inline fvar<T> log_mix(double theta, const fvar<T>& lambda1,
163                        const fvar<T>& lambda2) {
164   if (lambda1.val_ > lambda2.val_) {
165     fvar<T> partial_deriv_array[2];
166     log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
167     return fvar<T>(log_mix(theta, lambda1.val_, lambda2.val_),
168                    lambda1.d_ * value_of(partial_deriv_array[0])
169                        + lambda2.d_ * value_of(partial_deriv_array[1]));
170   } else {
171     fvar<T> partial_deriv_array[2];
172     log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
173     return fvar<T>(log_mix(theta, lambda1.val_, lambda2.val_),
174                    lambda1.d_ * value_of(partial_deriv_array[1])
175                        + lambda2.d_ * value_of(partial_deriv_array[0]));
176   }
177 }
178 
179 template <typename T>
log_mix(const fvar<T> & theta,double lambda1,double lambda2)180 inline fvar<T> log_mix(const fvar<T>& theta, double lambda1, double lambda2) {
181   if (lambda1 > lambda2) {
182     fvar<T> partial_deriv_array[1];
183     log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
184     return fvar<T>(log_mix(theta.val_, lambda1, lambda2),
185                    theta.d_ * value_of(partial_deriv_array[0]));
186   } else {
187     fvar<T> partial_deriv_array[1];
188     log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
189     return fvar<T>(log_mix(theta.val_, lambda1, lambda2),
190                    -theta.d_ * value_of(partial_deriv_array[0]));
191   }
192 }
193 
194 template <typename T>
log_mix(double theta,const fvar<T> & lambda1,double lambda2)195 inline fvar<T> log_mix(double theta, const fvar<T>& lambda1, double lambda2) {
196   if (lambda1.val_ > lambda2) {
197     fvar<T> partial_deriv_array[1];
198     log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
199     return fvar<T>(log_mix(theta, lambda1.val_, lambda2),
200                    lambda1.d_ * value_of(partial_deriv_array[0]));
201   } else {
202     fvar<T> partial_deriv_array[1];
203     log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
204     return fvar<T>(log_mix(theta, lambda1.val_, lambda2),
205                    lambda1.d_ * value_of(partial_deriv_array[0]));
206   }
207 }
208 
209 template <typename T>
log_mix(double theta,double lambda1,const fvar<T> & lambda2)210 inline fvar<T> log_mix(double theta, double lambda1, const fvar<T>& lambda2) {
211   if (lambda1 > lambda2.val_) {
212     fvar<T> partial_deriv_array[1];
213     log_mix_partial_helper(theta, lambda1, lambda2, partial_deriv_array);
214     return fvar<T>(log_mix(theta, lambda1, lambda2.val_),
215                    lambda2.d_ * value_of(partial_deriv_array[0]));
216   } else {
217     fvar<T> partial_deriv_array[1];
218     log_mix_partial_helper(1.0 - theta, lambda2, lambda1, partial_deriv_array);
219     return fvar<T>(log_mix(theta, lambda1, lambda2.val_),
220                    lambda2.d_ * value_of(partial_deriv_array[0]));
221   }
222 }
223 }  // namespace math
224 }  // namespace stan
225 #endif
226