1 #ifndef STAN_MATH_REV_FUN_LOG_DIFF_EXP_HPP
2 #define STAN_MATH_REV_FUN_LOG_DIFF_EXP_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/prim/fun/constants.hpp>
7 #include <stan/math/prim/fun/expm1.hpp>
8 #include <stan/math/prim/fun/log_diff_exp.hpp>
9 
10 namespace stan {
11 namespace math {
12 
13 namespace internal {
14 class log_diff_exp_vv_vari : public op_vv_vari {
15  public:
log_diff_exp_vv_vari(vari * avi,vari * bvi)16   log_diff_exp_vv_vari(vari* avi, vari* bvi)
17       : op_vv_vari(log_diff_exp(avi->val_, bvi->val_), avi, bvi) {}
chain()18   void chain() {
19     avi_->adj_ -= adj_ / expm1(bvi_->val_ - avi_->val_);
20     bvi_->adj_ -= adj_ / expm1(avi_->val_ - bvi_->val_);
21   }
22 };
23 class log_diff_exp_vd_vari : public op_vd_vari {
24  public:
log_diff_exp_vd_vari(vari * avi,double b)25   log_diff_exp_vd_vari(vari* avi, double b)
26       : op_vd_vari(log_diff_exp(avi->val_, b), avi, b) {}
chain()27   void chain() {
28     if (val_ == NEGATIVE_INFTY) {
29       avi_->adj_ += (bd_ == NEGATIVE_INFTY) ? adj_ : adj_ * INFTY;
30     } else {
31       avi_->adj_ -= adj_ / expm1(bd_ - avi_->val_);
32     }
33   }
34 };
35 class log_diff_exp_dv_vari : public op_dv_vari {
36  public:
log_diff_exp_dv_vari(double a,vari * bvi)37   log_diff_exp_dv_vari(double a, vari* bvi)
38       : op_dv_vari(log_diff_exp(a, bvi->val_), a, bvi) {}
chain()39   void chain() {
40     if (val_ == NEGATIVE_INFTY) {
41       bvi_->adj_ -= adj_ * INFTY;
42     } else {
43       bvi_->adj_ -= adj_ / expm1(ad_ - bvi_->val_);
44     }
45   }
46 };
47 }  // namespace internal
48 
49 /**
50  * Returns the log difference of the exponentiated arguments.
51  *
52  * @param[in] a First argument.
53  * @param[in] b Second argument.
54  * @return Log difference of the exponentiated arguments.
55  */
log_diff_exp(const var & a,const var & b)56 inline var log_diff_exp(const var& a, const var& b) {
57   return var(new internal::log_diff_exp_vv_vari(a.vi_, b.vi_));
58 }
59 
60 /**
61  * Returns the log difference of the exponentiated arguments.
62  *
63  * @param[in] a First argument.
64  * @param[in] b Second argument.
65  * @return Log difference of the exponentiated arguments.
66  */
log_diff_exp(const var & a,double b)67 inline var log_diff_exp(const var& a, double b) {
68   return var(new internal::log_diff_exp_vd_vari(a.vi_, b));
69 }
70 
71 /**
72  * Returns the log difference of the exponentiated arguments.
73  *
74  * @param[in] a First argument.
75  * @param[in] b Second argument.
76  * @return Log difference of the exponentiated arguments.
77  */
log_diff_exp(double a,const var & b)78 inline var log_diff_exp(double a, const var& b) {
79   return var(new internal::log_diff_exp_dv_vari(a, b.vi_));
80 }
81 
82 }  // namespace math
83 }  // namespace stan
84 #endif
85