1 #ifndef STAN_MATH_REV_FUN_INC_BETA_HPP
2 #define STAN_MATH_REV_FUN_INC_BETA_HPP
3
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/pow.hpp>
7 #include <stan/math/prim/fun/beta.hpp>
8 #include <stan/math/prim/fun/digamma.hpp>
9 #include <stan/math/prim/fun/grad_reg_inc_beta.hpp>
10 #include <cmath>
11
12 namespace stan {
13 namespace math {
14
15 namespace internal {
16
17 class inc_beta_vvv_vari : public op_vvv_vari {
18 public:
inc_beta_vvv_vari(vari * avi,vari * bvi,vari * cvi)19 inc_beta_vvv_vari(vari* avi, vari* bvi, vari* cvi)
20 : op_vvv_vari(inc_beta(avi->val_, bvi->val_, cvi->val_), avi, bvi, cvi) {}
chain()21 void chain() {
22 double d_a;
23 double d_b;
24 const double beta_ab = beta(avi_->val_, bvi_->val_);
25 grad_reg_inc_beta(d_a, d_b, avi_->val_, bvi_->val_, cvi_->val_,
26 digamma(avi_->val_), digamma(bvi_->val_),
27 digamma(avi_->val_ + bvi_->val_), beta_ab);
28
29 avi_->adj_ += adj_ * d_a;
30 bvi_->adj_ += adj_ * d_b;
31 cvi_->adj_ += adj_ * std::pow(1 - cvi_->val_, bvi_->val_ - 1)
32 * std::pow(cvi_->val_, avi_->val_ - 1) / beta_ab;
33 }
34 };
35
36 } // namespace internal
37
inc_beta(const var & a,const var & b,const var & c)38 inline var inc_beta(const var& a, const var& b, const var& c) {
39 return var(new internal::inc_beta_vvv_vari(a.vi_, b.vi_, c.vi_));
40 }
41
42 } // namespace math
43 } // namespace stan
44 #endif
45