1 #ifndef STAN_MATH_PRIM_FUN_GRAD_F32_HPP
2 #define STAN_MATH_PRIM_FUN_GRAD_F32_HPP
3
4 #include <stan/math/prim/meta.hpp>
5 #include <stan/math/prim/err.hpp>
6 #include <stan/math/prim/fun/constants.hpp>
7 #include <stan/math/prim/fun/exp.hpp>
8 #include <stan/math/prim/fun/fabs.hpp>
9 #include <stan/math/prim/fun/inv.hpp>
10 #include <stan/math/prim/fun/log.hpp>
11 #include <cmath>
12
13 namespace stan {
14 namespace math {
15
16 /**
17 * Gradients of the hypergeometric function, 3F2.
18 *
19 * Calculate the gradients of the hypergeometric function (3F2)
20 * as the power series stopping when the series converges
21 * to within <code>precision</code> or throwing when the
22 * function takes <code>max_steps</code> steps.
23 *
24 * This power-series representation converges for all gradients
25 * under the same conditions as the 3F2 function itself.
26 *
27 * @tparam T type of arguments and result
28 * @param[out] g g pointer to array of six values of type T, result.
29 * @param[in] a1 a1 see generalized hypergeometric function definition.
30 * @param[in] a2 a2 see generalized hypergeometric function definition.
31 * @param[in] a3 a3 see generalized hypergeometric function definition.
32 * @param[in] b1 b1 see generalized hypergeometric function definition.
33 * @param[in] b2 b2 see generalized hypergeometric function definition.
34 * @param[in] z z see generalized hypergeometric function definition.
35 * @param[in] precision precision of the infinite sum
36 * @param[in] max_steps number of steps to take
37 */
38 template <typename T>
grad_F32(T * g,const T & a1,const T & a2,const T & a3,const T & b1,const T & b2,const T & z,const T & precision=1e-6,int max_steps=1e5)39 void grad_F32(T* g, const T& a1, const T& a2, const T& a3, const T& b1,
40 const T& b2, const T& z, const T& precision = 1e-6,
41 int max_steps = 1e5) {
42 check_3F2_converges("grad_F32", a1, a2, a3, b1, b2, z);
43
44 using std::exp;
45 using std::fabs;
46 using std::log;
47
48 for (int i = 0; i < 6; ++i) {
49 g[i] = 0.0;
50 }
51
52 T log_g_old[6];
53 for (auto& x : log_g_old) {
54 x = NEGATIVE_INFTY;
55 }
56
57 T log_t_old = 0.0;
58 T log_t_new = 0.0;
59
60 T log_z = log(z);
61
62 double log_t_new_sign = 1.0;
63 double log_t_old_sign = 1.0;
64 double log_g_old_sign[6];
65 for (int i = 0; i < 6; ++i) {
66 log_g_old_sign[i] = 1.0;
67 }
68
69 for (int k = 0; k <= max_steps; ++k) {
70 T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (1 + k));
71 if (p == 0) {
72 return;
73 }
74
75 log_t_new += log(fabs(p)) + log_z;
76 log_t_new_sign = p >= 0.0 ? log_t_new_sign : -log_t_new_sign;
77
78 // g_old[0] = t_new * (g_old[0] / t_old + 1.0 / (a1 + k));
79 T term = log_g_old_sign[0] * log_t_old_sign * exp(log_g_old[0] - log_t_old)
80 + inv(a1 + k);
81 log_g_old[0] = log_t_new + log(fabs(term));
82 log_g_old_sign[0] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
83
84 // g_old[1] = t_new * (g_old[1] / t_old + 1.0 / (a2 + k));
85 term = log_g_old_sign[1] * log_t_old_sign * exp(log_g_old[1] - log_t_old)
86 + inv(a2 + k);
87 log_g_old[1] = log_t_new + log(fabs(term));
88 log_g_old_sign[1] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
89
90 // g_old[2] = t_new * (g_old[2] / t_old + 1.0 / (a3 + k));
91 term = log_g_old_sign[2] * log_t_old_sign * exp(log_g_old[2] - log_t_old)
92 + inv(a3 + k);
93 log_g_old[2] = log_t_new + log(fabs(term));
94 log_g_old_sign[2] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
95
96 // g_old[3] = t_new * (g_old[3] / t_old - 1.0 / (b1 + k));
97 term = log_g_old_sign[3] * log_t_old_sign * exp(log_g_old[3] - log_t_old)
98 - inv(b1 + k);
99 log_g_old[3] = log_t_new + log(fabs(term));
100 log_g_old_sign[3] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
101
102 // g_old[4] = t_new * (g_old[4] / t_old - 1.0 / (b2 + k));
103 term = log_g_old_sign[4] * log_t_old_sign * exp(log_g_old[4] - log_t_old)
104 - inv(b2 + k);
105 log_g_old[4] = log_t_new + log(fabs(term));
106 log_g_old_sign[4] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
107
108 // g_old[5] = t_new * (g_old[5] / t_old + 1.0 / z);
109 term = log_g_old_sign[5] * log_t_old_sign * exp(log_g_old[5] - log_t_old)
110 + inv(z);
111 log_g_old[5] = log_t_new + log(fabs(term));
112 log_g_old_sign[5] = term >= 0.0 ? log_t_new_sign : -log_t_new_sign;
113
114 for (int i = 0; i < 6; ++i) {
115 g[i] += log_g_old_sign[i] * exp(log_g_old[i]);
116 }
117
118 if (log_t_new <= log(precision)) {
119 return; // implicit abs
120 }
121
122 log_t_old = log_t_new;
123 log_t_old_sign = log_t_new_sign;
124 }
125 throw_domain_error("grad_F32", "k (internal counter)", max_steps, "exceeded ",
126 " iterations, hypergeometric function gradient "
127 "did not converge.");
128 return;
129 }
130
131 } // namespace math
132 } // namespace stan
133 #endif
134