1 #ifndef STAN_MATH_REV_FUN_LOG_DIFF_EXP_HPP
2 #define STAN_MATH_REV_FUN_LOG_DIFF_EXP_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/prim/fun/constants.hpp>
7 #include <stan/math/prim/fun/expm1.hpp>
8 #include <stan/math/prim/fun/log_diff_exp.hpp>
9
10 namespace stan {
11 namespace math {
12
13 namespace internal {
14 class log_diff_exp_vv_vari : public op_vv_vari {
15 public:
log_diff_exp_vv_vari(vari * avi,vari * bvi)16 log_diff_exp_vv_vari(vari* avi, vari* bvi)
17 : op_vv_vari(log_diff_exp(avi->val_, bvi->val_), avi, bvi) {}
chain()18 void chain() {
19 avi_->adj_ -= adj_ / expm1(bvi_->val_ - avi_->val_);
20 bvi_->adj_ -= adj_ / expm1(avi_->val_ - bvi_->val_);
21 }
22 };
23 class log_diff_exp_vd_vari : public op_vd_vari {
24 public:
log_diff_exp_vd_vari(vari * avi,double b)25 log_diff_exp_vd_vari(vari* avi, double b)
26 : op_vd_vari(log_diff_exp(avi->val_, b), avi, b) {}
chain()27 void chain() {
28 if (val_ == NEGATIVE_INFTY) {
29 avi_->adj_ += (bd_ == NEGATIVE_INFTY) ? adj_ : adj_ * INFTY;
30 } else {
31 avi_->adj_ -= adj_ / expm1(bd_ - avi_->val_);
32 }
33 }
34 };
35 class log_diff_exp_dv_vari : public op_dv_vari {
36 public:
log_diff_exp_dv_vari(double a,vari * bvi)37 log_diff_exp_dv_vari(double a, vari* bvi)
38 : op_dv_vari(log_diff_exp(a, bvi->val_), a, bvi) {}
chain()39 void chain() {
40 if (val_ == NEGATIVE_INFTY) {
41 bvi_->adj_ -= adj_ * INFTY;
42 } else {
43 bvi_->adj_ -= adj_ / expm1(ad_ - bvi_->val_);
44 }
45 }
46 };
47 } // namespace internal
48
49 /**
50 * Returns the log difference of the exponentiated arguments.
51 *
52 * @param[in] a First argument.
53 * @param[in] b Second argument.
54 * @return Log difference of the exponentiated arguments.
55 */
log_diff_exp(const var & a,const var & b)56 inline var log_diff_exp(const var& a, const var& b) {
57 return var(new internal::log_diff_exp_vv_vari(a.vi_, b.vi_));
58 }
59
60 /**
61 * Returns the log difference of the exponentiated arguments.
62 *
63 * @param[in] a First argument.
64 * @param[in] b Second argument.
65 * @return Log difference of the exponentiated arguments.
66 */
log_diff_exp(const var & a,double b)67 inline var log_diff_exp(const var& a, double b) {
68 return var(new internal::log_diff_exp_vd_vari(a.vi_, b));
69 }
70
71 /**
72 * Returns the log difference of the exponentiated arguments.
73 *
74 * @param[in] a First argument.
75 * @param[in] b Second argument.
76 * @return Log difference of the exponentiated arguments.
77 */
log_diff_exp(double a,const var & b)78 inline var log_diff_exp(double a, const var& b) {
79 return var(new internal::log_diff_exp_dv_vari(a, b.vi_));
80 }
81
82 } // namespace math
83 } // namespace stan
84 #endif
85