1 #ifndef STAN_MATH_PRIM_FUN_GRAD_F32_HPP
2 #define STAN_MATH_PRIM_FUN_GRAD_F32_HPP
3 
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/constants.hpp>
7 #include <stan/math/prim/fun/exp.hpp>
8 #include <stan/math/prim/fun/fabs.hpp>
9 #include <stan/math/prim/fun/inv.hpp>
10 #include <stan/math/prim/fun/log.hpp>
11 #include <cmath>
12 
13 namespace stan {
14 namespace math {
15 
16 /**
17  * Gradients of the hypergeometric function, 3F2.
18  *
19  * Calculate the gradients of the hypergeometric function (3F2)
20  * as the power series stopping when the series converges
21  * to within <code>precision</code> or throwing when the
22  * function takes <code>max_steps</code> steps.
23  *
24  * This power-series representation converges for all gradients
25  * under the same conditions as the 3F2 function itself.
26  *
27  * @tparam T type of arguments and result
28  * @param[out] g g pointer to array of six values of type T, result.
29  * @param[in] a1 a1 see generalized hypergeometric function definition.
30  * @param[in] a2 a2 see generalized hypergeometric function definition.
31  * @param[in] a3 a3 see generalized hypergeometric function definition.
32  * @param[in] b1 b1 see generalized hypergeometric function definition.
33  * @param[in] b2 b2 see generalized hypergeometric function definition.
34  * @param[in] z z see generalized hypergeometric function definition.
35  * @param[in] precision precision of the infinite sum
36  * @param[in] max_steps number of steps to take
37  */
38 template <typename T>
grad_F32(T * g,const T & a1,const T & a2,const T & a3,const T & b1,const T & b2,const T & z,const T & precision=1e-6,int max_steps=1e5)39 void grad_F32(T* g, const T& a1, const T& a2, const T& a3, const T& b1,
40               const T& b2, const T& z, const T& precision = 1e-6,
41               int max_steps = 1e5) {
42   check_3F2_converges("grad_F32", a1, a2, a3, b1, b2, z);
43 
44   using std::exp;
45   using std::fabs;
46   using std::log;
47 
48   for (int i = 0; i < 6; ++i) {
49     g[i] = 0.0;
50   }
51 
52   T log_g_old[6];
53   for (auto& x : log_g_old) {
54     x = NEGATIVE_INFTY;
55   }
56 
57   T log_t_old = 0.0;
58   T log_t_new = 0.0;
59 
60   T log_z = log(z);
61 
62   double log_t_new_sign = 1.0;
63   double log_t_old_sign = 1.0;
64   double log_g_old_sign[6];
65   for (int i = 0; i < 6; ++i) {
66     log_g_old_sign[i] = 1.0;
67   }
68 
69   for (int k = 0; k <= max_steps; ++k) {
70     T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
71     if (p == 0) {
72       return;
73     }
74 
75     log_t_new += log(fabs(p)) + log_z;
76     log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
77 
78     //        g_old[0] = t_new * (g_old[0] / t_old + 1.0 / (a1 + k));
79     T term = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
80              + inv(a1 + k);
81     log_g_old[0] = log_t_new + log(fabs(term));
82     log_g_old_sign[0] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
83 
84     //        g_old[1] = t_new * (g_old[1] / t_old + 1.0 / (a2 + k));
85     term = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
86            + inv(a2 + k);
87     log_g_old[1] = log_t_new + log(fabs(term));
88     log_g_old_sign[1] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
89 
90     //        g_old[2] = t_new * (g_old[2] / t_old + 1.0 / (a3 + k));
91     term = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
92            + inv(a3 + k);
93     log_g_old[2] = log_t_new + log(fabs(term));
94     log_g_old_sign[2] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
95 
96     //        g_old[3] = t_new * (g_old[3] / t_old - 1.0 / (b1 + k));
97     term = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
98            - inv(b1 + k);
99     log_g_old[3] = log_t_new + log(fabs(term));
100     log_g_old_sign[3] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
101 
102     //        g_old[4] = t_new * (g_old[4] / t_old - 1.0 / (b2 + k));
103     term = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
104            - inv(b2 + k);
105     log_g_old[4] = log_t_new + log(fabs(term));
106     log_g_old_sign[4] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
107 
108     //        g_old[5] = t_new * (g_old[5] / t_old + 1.0 / z);
109     term = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
110            + inv(z);
111     log_g_old[5] = log_t_new + log(fabs(term));
112     log_g_old_sign[5] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
113 
114     for (int i = 0; i < 6; ++i) {
115       g[i] += log_g_old_sign[i] * exp(log_g_old[i]);
116     }
117 
118     if (log_t_new <= log(precision)) {
119       return;  // implicit abs
120     }
121 
122     log_t_old = log_t_new;
123     log_t_old_sign = log_t_new_sign;
124   }
125   throw_domain_error("grad_F32", "k (internal counter)", max_steps, "exceeded ",
126                      " iterations, hypergeometric function gradient "
127                      "did not converge.");
128   return;
129 }
130 
131 }  // namespace math
132 }  // namespace stan
133 #endif
134