1 #ifndef STAN_MATH_REV_FUN_INC_BETA_HPP
2 #define STAN_MATH_REV_FUN_INC_BETA_HPP
3 
4 #include <stan/math/rev/meta.hpp>
5 #include <stan/math/rev/core.hpp>
6 #include <stan/math/rev/fun/pow.hpp>
7 #include <stan/math/prim/fun/beta.hpp>
8 #include <stan/math/prim/fun/digamma.hpp>
9 #include <stan/math/prim/fun/grad_reg_inc_beta.hpp>
10 #include <cmath>
11 
12 namespace stan {
13 namespace math {
14 
15 namespace internal {
16 
17 class inc_beta_vvv_vari : public op_vvv_vari {
18  public:
inc_beta_vvv_vari(vari * avi,vari * bvi,vari * cvi)19   inc_beta_vvv_vari(vari* avi, vari* bvi, vari* cvi)
20       : op_vvv_vari(inc_beta(avi->val_, bvi->val_, cvi->val_), avi, bvi, cvi) {}
chain()21   void chain() {
22     double d_a;
23     double d_b;
24     const double beta_ab = beta(avi_->val_, bvi_->val_);
25     grad_reg_inc_beta(d_a, d_b, avi_->val_, bvi_->val_, cvi_->val_,
26                       digamma(avi_->val_), digamma(bvi_->val_),
27                       digamma(avi_->val_ + bvi_->val_), beta_ab);
28 
29     avi_->adj_ += adj_ * d_a;
30     bvi_->adj_ += adj_ * d_b;
31     cvi_->adj_ += adj_ * std::pow(1 - cvi_->val_, bvi_->val_ - 1)
32                   * std::pow(cvi_->val_, avi_->val_ - 1) / beta_ab;
33   }
34 };
35 
36 }  // namespace internal
37 
inc_beta(const var & a,const var & b,const var & c)38 inline var inc_beta(const var& a, const var& b, const var& c) {
39   return var(new internal::inc_beta_vvv_vari(a.vi_, b.vi_, c.vi_));
40 }
41 
42 }  // namespace math
43 }  // namespace stan
44 #endif
45