1 #include <stan/math/rev.hpp>
2 #include <test/unit/math/expect_near_rel.hpp>
3 #include <gtest/gtest.h>
4 #include <boost/math/differentiation/finite_difference.hpp>
5 #include <boost/math/special_functions/digamma.hpp>
6 #include <algorithm>
7 #include <limits>
8 #include <vector>
9
10 namespace neg_binomial_2_log_test_internal {
11 struct TestValue {
12 int n;
13 double eta;
14 double phi;
15 double value;
16 double grad_eta;
17 double grad_phi;
18
TestValueneg_binomial_2_log_test_internal::TestValue19 TestValue(int _n, double _eta, double _phi, double _value, double _grad_eta,
20 double _grad_phi)
21 : n(_n),
22 eta(_eta),
23 phi(_phi),
24 value(_value),
25 grad_eta(_grad_eta),
26 grad_phi(_grad_phi) {}
27 };
28
29 // Test data generated in Mathematica (Wolfram Cloud). The code can be re-ran at
30 // https://www.wolframcloud.com/obj/martin.modrak/Published/neg_binomial_2_log_lpmf.nb
31 // but is also presented below for convenience:
32 //
33 // toCString[x_] := ToString[CForm[N[x, 24]]];
34 // nb2log[n_,logmu_,phi_]:= LogGamma[n + phi] - LogGamma[n + 1] -
35 // LogGamma[phi] + n * (logmu - Log[Exp[logmu] + phi]) +
36 // phi * (Log[phi] - Log[Exp[logmu] + phi])
37 // nb2logdmu[n_,logmu_,phi_]= D[nb2log[n, logmu, phi],logmu];
38 // nb2logdphi[n_,logmu_,phi_]= D[nb2log[n, logmu, phi],phi];
39 // logmus= SetPrecision[{-36,-26.1, -4.5,-0.2,2.0,13.4, 84.3}, Infinity];
40 // phis= SetPrecision[{0.0001,0.084,3.76,461.0, 142311.0}, Infinity];
41 // ns = {0,7,11,2338,9611};
42 // out = "std::vector<TestValue> testValues = {\n";
43 // For[k = 1, k <= Length[ns], k++, {
44 // For[i = 1, i <= Length[logmus], i++, {
45 // For[j = 1, j <= Length[phis], j++, {
46 // clogmu = logmus[[i]];
47 // cphi = phis[[j]];
48 // cn=ns[[k]];
49 // val = nb2log[cn,clogmu,cphi];
50 // ddmu= nb2logdmu[cn,clogmu,cphi];
51 // ddphi= nb2logdphi[cn,clogmu,cphi];
52 // out = StringJoin[out," {",ToString[cn],",",toCString[clogmu],",",
53 // toCString[cphi],",", toCString[val],",",toCString[ddmu],",",
54 // toCString[ddphi],"},\n"];
55 // }]
56 // }]
57 // }]
58 // out = StringJoin[out,"};\n"];
59 // out
60
61 std::vector<TestValue> testValues = {
62 {0, -36., 0.000100000000000000004792174, -2.31952283024087929523226e-16,
63 -2.31952283023818920215225e-16, -2.69009308000224930599517e-24},
64 {0, -36., 0.0840000000000000052180482, -2.3195228302435661858205e-16,
65 -2.31952283024356298332874e-16, -3.81249019275872869665102e-30},
66 {0, -36., 3.75999999999999978683718, -2.31952283024356931676723e-16,
67 -2.31952283024356924522221e-16, -1.9027933171192913158222e-33},
68 {0, -36., 461., -2.31952283024356938772873e-16,
69 -2.3195228302435693871452e-16, -1.26580106437037714160219e-37},
70 {0, -36., 142311., -2.31952283024356938831037e-16,
71 -2.31952283024356938830848e-16, -1.32828224194512041655272e-42},
72 {0, -26.1000000000000014210855, 0.000100000000000000004792174,
73 -4.62289481781287824178924e-12, -4.62289471095709740559354e-12,
74 -1.06855780836195688701657e-15},
75 {0, -26.1000000000000014210855, 0.0840000000000000052180482,
76 -4.62289492454145310046865e-12, -4.62289492441424382973675e-12,
77 -1.51439608014165778339631e-21},
78 {0, -26.1000000000000014210855, 3.75999999999999978683718,
79 -4.62289492466582046196526e-12, -4.6228949246629785527253e-12,
80 -7.55826925521030645308565e-25},
81 {0, -26.1000000000000014210855, 461., -4.62289492466863919207562e-12,
82 -4.62289492466861601294603e-12, -5.02801075764912450865117e-29},
83 {0, -26.1000000000000014210855, 142311., -4.62289492466866229611911e-12,
84 -4.62289492466866222103301e-12, -5.27619828240266701257698e-34},
85 {0, -4.5, 0.000100000000000000004792174, -0.000471930181119562098823563,
86 -0.0000991078594800272766579705, -3.728223216395348042993},
87 {0, -4.5, 0.0840000000000000052180482, -0.0104333684327369739264569,
88 -0.0098114347031002672745465, -0.00740397297186555491995328},
89 {0, -4.5, 3.75999999999999978683718, -0.0110926179127638047017042,
90 -0.0110762714687203261403217, -4.34745852220174529502355e-6},
91 {0, -4.5, 461., -0.0111088626902796869123984, -0.0111087288444673011605548,
92 -2.90337987821587513149328e-10},
93 {0, -4.5, 142311., -0.0111089961046503927849115,
94 -0.0111089956710585016382076, -3.04679112048052440606817e-15},
95 {0, -0.20000000000000001110223, 0.000100000000000000004792174,
96 -0.000901046250479348295492963, -0.000099987787464060911106451,
97 -8.010584630152873459984},
98 {0, -0.20000000000000001110223, 0.0840000000000000052180482,
99 -0.199467033424133669556274, -0.0761837159352978552652188,
100 -1.46765854153375960277075},
101 {0, -0.20000000000000001110223, 3.75999999999999978683718,
102 -0.740730806872720638510927, -0.672332093234306416324947,
103 -0.0181911472442591026765043},
104 {0, -0.20000000000000001110223, 461., -0.818004584479470198072977,
105 -0.81727927438862128024214, -1.57334076106055928597981e-6},
106 {0, -0.20000000000000001110223, 142311., -0.818728397963215257316515,
107 -0.818726042857481443695725, -1.65490069904197201165846e-11},
108 {0, 2., 0.000100000000000000004792174, -0.00112103539054129295184384,
109 -0.000099998646665483029681996, -10.2103674387580987323199},
110 {0, 2., 0.0840000000000000052180482, -0.377012371003043676678775,
111 -0.0830558079711177822673545, -3.49948289323721281132084},
112 {0, 2., 3.75999999999999978683718, -4.08687891675795882940335,
113 -2.49194646483517166950394, -0.424184162745422141042474},
114 {0, 2., 461., -7.33046427302122464456484, -7.27249028826061546955612,
115 -0.000125757016834293221277066},
116 {0, 2., 142311., -7.38886427869112561814738, -7.38867246509109332505865,
117 -1.34784802321881715908971e-9},
118 {0, 13.4000000000000003552714, 0.000100000000000000004792174,
119 -0.00226103403721276985381512, -0.000099999999984848563673037,
120 -21.6103403722792118658158},
121 {0, 13.4000000000000003552714, 0.0840000000000000052180482,
122 -1.33366284302251745111466, -0.0839999893091445105829689,
123 -14.876938734683010272656},
124 {0, 13.4000000000000003552714, 3.75999999999999978683718,
125 -45.4042061406096035346291, -3.75997857962063060636234,
126 -11.0755924364332381088221},
127 {0, 13.4000000000000003552714, 461., -3350.22538971724911576917,
128 -460.678224812138394142983, -6.26799818851433996014357},
129 {0, 13.4000000000000003552714, 142311., -246124.815117870661913638,
130 -117068.49513614471120162, -0.906861170125471331885925},
131 {0, 84.2999999999999971578291, 0.000100000000000000004792174,
132 -0.00935103403719761843271571, -0.000100000000000000004792174,
133 -92.5103403719761798459793},
134 {0, 84.2999999999999971578291, 0.0840000000000000052180482,
135 -7.28926283233166137753023, -0.0840000000000000052180482,
136 -85.7769384801388205324283},
137 {0, 84.2999999999999971578291, 3.75999999999999978683718,
138 -311.988184720169192006433, -3.75999999999999978683718,
139 -81.9755810425981940674846},
140 {0, 84.2999999999999971578291, 461., -36034.8035021785437227894, -461.,
141 -77.1666019570033486394563},
142 {0, 84.2999999999999971578291, 142311., -1.03081876937799361823842e7,
143 -142311., -71.4342299174339030881957},
144 {7, -36., 0.000100000000000000004792174, -198.683622924671001775538,
145 6.99999999998376310823605, -59997.5501489646175798213},
146 {7, -36., 0.0840000000000000052180482, -238.883518369596162432722,
147 6.9999999999999804386908, -69.0960424923444586909594},
148 {7, -36., 3.75999999999999978683718, -256.751996432901803890657,
149 6.99999999999999933622166, -0.718640295127228617069051},
150 {7, -36., 461., -260.47982082426917795879, 6.99999999999999976452566,
151 -0.0000978945853445233788324339},
152 {7, -36., 142311., -260.525013799174310121032, 6.99999999999999976803631,
153 -1.03688150994232402700446e-9},
154 {7, -26.1000000000000014210855, 0.000100000000000000004792174,
155 -129.383623248262034973091, 6.99999967639274733829224,
156 -59997.5469131006865090386},
157 {7, -26.1000000000000014210855, 0.0840000000000000052180482,
158 -169.583518369986016957648, 6.9999999996101358613744,
159 -69.0960424877584835201958},
160 {7, -26.1000000000000014210855, 3.75999999999999978683718,
161 -187.451996432915042522719, 6.99999999998677065175815,
162 -0.718640295124939781566091},
163 {7, -26.1000000000000014210855, 461., -191.17982082427388076164,
164 6.99999999999530690927387, -0.000097894585344371117922657},
165 {7, -26.1000000000000014210855, 142311., -191.225013799178942958982,
166 6.99999999999537687768416, -1.03688150994072626043362e-9},
167 {7, -4.5, 0.000100000000000000004792174, -11.2192075332052297804221,
168 0.0623507285386109391248531, 9374.22326367570748540949},
169 {7, -4.5, 0.0840000000000000052180482, -19.2633991074236276185347,
170 6.17256900670521084397013, -59.3698802916059729529247},
171 {7, -4.5, 3.75999999999999978683718, -36.2837402011840741159725,
172 6.96830301037142800168473, -0.713160409032598892878726},
173 {7, -4.5, 461., -39.9910983681716742384461, 6.9887225919756817853728,
174 -0.0000975289772446567319934955},
175 {7, -4.5, 142311., -40.0361233417087705744711, 6.98889045789915253328171,
176 -1.03304486908480829056641e-9},
177 {7, -0.20000000000000001110223, 0.000100000000000000004792174,
178 -11.1577615044596219717133, 0.000754889728272497048508259,
179 9985.89049108549401489409},
180 {7, -0.20000000000000001110223, 0.0840000000000000052180482,
181 -5.10523818836474804913996, 0.575173289456547933675185,
182 5.01538223526623988873864},
183 {7, -0.20000000000000001110223, 3.75999999999999978683718,
184 -8.27174735895246138021499, 5.07598581829756986530093,
185 -0.403937269906561294876212},
186 {7, -0.20000000000000001110223, 461., -10.7102463026561453773833,
187 6.1703108450892087653941, -0.0000725484455802695111298855},
188 {7, -0.20000000000000001110223, 142311., -10.7437824687875434412923,
189 6.18123368560834338804528, -7.70447935296972460390227e-10},
190 {7, 2., 0.000100000000000000004792174, -11.1572212979358090452568,
191 -5.26523047722532179308571e-6, 9991.29214927237179177473},
192 {7, 2., 0.0840000000000000052180482, -4.67822832418615765300115,
193 -0.00437313889759920793061993, 9.8011095064000150019666},
194 {7, 2., 3.75999999999999978683718, -2.44742652447511137215548,
195 -0.131208500432565998079953, 0.0910218281904859865055242},
196 {7, 2., 461., -1.92159366542738408699561, -0.382918557279757136122044,
197 0.0000158891115286490464064998},
198 {7, 2., 142311., -1.91424152166072074784582, -0.38903589945144240315715,
199 1.69074180881189252797788e-10},
200 {7, 13.4000000000000003552714, 0.000100000000000000004792174,
201 -11.1582665635858573490992, -0.0000999989393839702240949489,
202 9980.83949989472447227895},
203 {7, 13.4000000000000003552714, 0.0840000000000000052180482,
204 -5.55575146449510519129741, -0.083999098404519957659751,
205 -0.639658499701806072156294},
206 {7, 13.4000000000000003552714, 3.75999999999999978683718,
207 -40.8853097501782241578175, -3.75993870125478372781648,
208 -9.93254120984925579863814},
209 {7, 13.4000000000000003552714, 461., -3315.77631190381925371792,
210 -460.673338854838977134308, -6.25292229992679818636882},
211 {7, 13.4000000000000003552714, 142311., -246051.646530372853518603,
212 -117067.253506656850199859, -0.906820707877630374849315},
213 {7, 84.2999999999999971578291, 0.000100000000000000004792174,
214 -11.1653565625252413192581, -0.000100000000000000004792174,
215 9909.93951050103628769457},
216 {7, 84.2999999999999971578291, 0.0840000000000000052180482,
217 -11.5113505628995678711385, -0.0840000000000000052180482,
218 -71.5396476391501811780814},
219 {7, 84.2999999999999971578291, 3.75999999999999978683718,
220 -307.469248451258373600901, -3.75999999999999978683718,
221 -80.8325192100658482257719},
222 {7, 84.2999999999999971578291, 461., -36000.3495367018363608841, -461.,
223 -77.151515469809951297333},
224 {7, 84.2999999999999971578291, 142311., -1.03081131584031573940356e7,
225 -142311., -71.4341807304248851244502},
226 {11, -36., 0.000100000000000000004792174, -306.294198663985101325787,
227 10.9999999999744850169151, -99997.0711864556783212622},
228 {11, -36., 0.0840000000000000052180482, -373.387720780818404406357,
229 10.999999999999969393344, -116.240975127892379549188},
230 {11, -36., 3.75999999999999978683718, -405.018198139691527464642,
231 10.9999999999999990894639, -1.45345494122108727163354},
232 {11, -36., 461., -413.383897627979795451303, 10.9999999999999997625131,
233 -0.000254934049399102214755753},
234 {11, -36., 142311., -413.501921377875076957036, 10.9999999999999997680298,
235 -2.71559116633421423577824e-9},
236 {11, -26.1000000000000014210855, 0.000100000000000000004792174,
237 -197.394199172482654828876, 10.9999994914769589000084,
238 -99997.0661015266437808503},
239 {11, -26.1000000000000014210855, 0.0840000000000000052180482,
240 -264.487720781428391423827, 10.9999999993899980078309,
241 -116.240975120685847137987},
242 {11, -26.1000000000000014210855, 3.75999999999999978683718,
243 -296.118198139709689507612, 10.999999999981852678434,
244 -1.45345494121749053012889},
245 {11, -26.1000000000000014210855, 461., -304.483897627984544048369,
246 10.9999999999952667973873, -0.000254934049398862947611818},
247 {11, -26.1000000000000014210855, 142311., -304.601921377879715609259,
248 10.9999999999953767477464, -2.71559116633170345973834e-9},
249 {11, -4.5, 0.000100000000000000004792174, -11.706990517292534287664,
250 0.0980363493375200627828593, 9017.8460181027726944107},
251 {11, -4.5, 0.0840000000000000052180482, -28.2644285868714287029341,
252 9.70535783036710290753852, -100.952775113605121643751},
253 {11, -4.5, 3.75999999999999978683718, -58.5617425653278014914233,
254 10.9565197428515127604419, -1.44484120738179930405144},
255 {11, -4.5, 461., -66.8952715611464156307853, 10.9886262038729098348204,
256 -0.0002543593564776044431017},
257 {11, -4.5, 142311., -67.0130312326551432919643, 10.9888901456535588389502,
258 -2.70956041824890637536164e-9},
259 {11, -0.20000000000000001110223, 0.000100000000000000004792174,
260 -11.6101872629383735675807, 0.00124339116583624445400238,
261 9981.48443912601296955335},
262 {11, -0.20000000000000001110223, 0.0840000000000000052180482,
263 -5.90787076264096269638241, 0.947377292537602670212559,
264 1.05849718208658084906379},
265 {11, -0.20000000000000001110223, 3.75999999999999978683718,
266 -14.1259605624152918587705, 8.36073891060149916908714,
267 -0.948526674591890563801079},
268 {11, -0.20000000000000001110223, 461., -20.4214207600281899855949,
269 10.1632194847908259343291, -0.000214205349334668674488331},
270 {11, -0.20000000000000001110223, 142311., -20.520713059859749302951,
271 10.1812106733031004347544, -2.2874532593255647910321e-9},
272 {11, 2., 0.000100000000000000004792174, -11.6092126588923464290597,
273 0.0000488681502017790827148631, 9991.22977788173817623571},
274 {11, 2., 0.0840000000000000052180482, -5.13540078317237206980856,
275 0.0405883862872685488332284, 9.73996823769877214616492},
276 {11, 2., 3.75999999999999978683718, -3.06137175973074859416703,
277 1.21778462208320867130519, 0.0612622027041340431507753},
278 {11, 2., 461., -2.88927536521627684556178, 3.55397957470930476869743,
279 -4.26994469309227722196226e-6},
280 {11, 2., 142311., -2.89135678253022200706033, 3.61075642377121526650085,
281 -5.03190674865121337520085e-11},
282 {11, 13.4000000000000003552714, 0.000100000000000000004792174,
283 -11.6102037914014672054394, -0.0000999983333263254586217559,
284 9981.31845625030445310433},
285 {11, 13.4000000000000003552714, 0.0840000000000000052180482,
286 -5.96770830536101890173932, -0.0839985893161630702750551,
287 -0.165549576777919563836645},
288 {11, 13.4000000000000003552714, 3.75999999999999978683718,
289 -39.8538584150632717115069, -3.759915913617156940076,
290 -9.60353212925099392334374},
291 {11, 13.4000000000000003552714, 461., -3300.14958948598756276581,
292 -460.670546879239310272208, -6.24440860614920345254555},
293 {11, 13.4000000000000003552714, 142311., -246013.941380067564168339,
294 -117066.544004092358198853, -0.906797587679355764396099},
295 {11, 84.2999999999999971578291, 0.000100000000000000004792174,
296 -11.617293789734793530787, -0.000100000000000000004792174,
297 9910.41847291719271617468},
298 {11, 84.2999999999999971578291, 0.0840000000000000052180482,
299 -11.9233068946770922978236, -0.0840000000000000052180482,
300 -71.065532655650617438992},
301 {11, 84.2999999999999971578291, 3.75999999999999978683718,
302 -306.43777432844088456675, -3.75999999999999978683718,
303 -80.503504068925664332461},
304 {11, 84.2999999999999971578291, 461., -35984.7200213335603843011, -461.,
305 -77.1429957196861533815963},
306 {11, 84.2999999999999971578291, 142311., -1.03080746722304058304262e7,
307 -142311., -71.4341526246487950815918},
308 {2338, -36., 0.000100000000000000004792174, -62651.1907684420448777443,
309 2337.99999999457695539095, -2.33699916660572733525724e7},
310 {2338, -36., 0.0840000000000000052180482, -78386.4573396010374044213,
311 2337.99999999999354376284, -27813.2247691071302770408},
312 {2338, -36., 3.75999999999999978683718, -87244.5784991555256038266,
313 2337.99999999999985553814, -615.235652274809944062375},
314 {2338, -36., 461., -97261.3541997290761812182, 2337.99999999999999859168,
315 -3.26705772262828347042341},
316 {2338, -36., 142311., -99951.6902594622698079577, 2337.99999999999999976424,
317 -0.000133435966326975414974749},
318 {2338, -26.1000000000000014210855, 0.000100000000000000004792174,
319 -39504.990876519910618993, 2337.99989191671703492836,
320 -2.33699905852787203745752e7},
321 {2338, -26.1000000000000014210855, 0.0840000000000000052180482,
322 -55240.2573397297094689766, 2337.99999987132480170888,
323 -27813.2247675754145700052},
324 {2338, -26.1000000000000014210855, 3.75999999999999978683718,
325 -64098.3784991584079601655, 2337.99999999712082169713,
326 -615.235652274045473004386},
327 {2338, -26.1000000000000014210855, 461., -74115.1541997291075706003,
328 2337.99999999997193170739, -3.26705772262823261527954},
329 {2338, -26.1000000000000014210855, 142311., -76805.4902594622778290634,
330 2337.99999999999530115643, -0.000133435966326974881320714},
331 {2338, -4.5, 0.000100000000000000004792174, -37.9188749421642944369079,
332 20.858146249102902750828, -198577.847904345129887008},
333 {2338, -4.5, 0.0840000000000000052180482, -5029.86319434730944140515,
334 2064.90525599567281088845, -24562.2210710608448107488},
335 {2338, -4.5, 3.75999999999999978683718, -13604.4870759968535893564,
336 2331.10160386314082116743, -613.403922615515723001624},
337 {2338, -4.5, 461., -23614.4216481166468801215, 2337.93255242508532760098,
338 -3.26693551284037806579056},
339 {2338, -4.5, 142311., -26304.7015509659310958491, 2337.98870849677942716161,
340 -0.000133434683874347561598764},
341 {2338, -0.20000000000000001110223, 0.000100000000000000004792174,
342 -17.2530056900344775246763, 0.285429102468546297600207,
343 7145.03240130482692470614},
344 {2338, -0.20000000000000001110223, 0.0840000000000000052180482,
345 -238.089236939508565582574, 217.47705608494119565083,
346 -2571.27861575441500369895},
347 {2338, -0.20000000000000001110223, 3.75999999999999978683718,
348 -4005.5119497678293099549, 1919.26585035841237164672,
349 -504.067189818768777258691},
350 {2338, -0.20000000000000001110223, 461., -13565.92078287865980031,
351 2333.03782063120661396226, -3.25806818947358951236853},
352 {2338, -0.20000000000000001110223, 142311., -16252.1224385913391334784,
353 2337.1678232647280123578, -0.000133341466693716057784701},
354 {2338, 2., 0.000100000000000000004792174, -16.9993203274403325176464,
355 0.0315409623602125914052141, 9681.71391098754441853037},
356 {2338, 2., 0.0840000000000000052180482, -36.345344890082340900417,
357 26.196955662584086046202, -296.248198078484197433633},
358 {2338, 2., 3.75999999999999978683718, -945.921470460610089695847,
359 785.994533645635122586109, -203.555176892467693501323},
360 {2338, 2., 461., -8461.86172575984929863877, 2293.84446785934606789742,
361 -3.18717688126682423770834},
362 {2338, 2., 142311., -11115.2005139685862037106, 2330.48994045855231459004,
363 -0.000132584343734508495425318},
364 {2338, 13.4000000000000003552714, 0.000100000000000000004792174,
365 -16.9688195052148641926364, -0.0000996457592914831445917036,
366 9986.7200057158703803923},
367 {2338, 13.4000000000000003552714, 0.0840000000000000052180482,
368 -10.8734665707860518042194, -0.0837024271645438342281827,
369 5.22808308495811843394431},
370 {2338, 13.4000000000000003552714, 3.75999999999999978683718,
371 -25.5045023028520732667892, -3.74665920542777317204743,
372 -4.50627645969873246867534},
373 {2338, 13.4000000000000003552714, 461., -2105.32744445484601087282,
374 -459.046315074133113245574, -4.46701233139860018126931},
375 {2338, 13.4000000000000003552714, 142311., -234622.84254414070827656,
376 -116653.790887199136613496, -0.893479868994494336336804},
377 {2338, 84.2999999999999971578291, 0.000100000000000000004792174,
378 -16.9759091509741556489599, -0.000100000000000000004792174,
379 9915.82354812310706660278},
380 {2338, 84.2999999999999971578291, 0.0840000000000000052180482,
381 -16.8287689979316593747652, -0.0840000000000000052180482,
382 -65.668374254014350440872},
383 {2338, 84.2999999999999971578291, 3.75999999999999978683718,
384 -292.075161470279026145921, -3.75999999999999978683718,
385 -75.4027226791102688967341},
386 {2338, 84.2999999999999971578291, 461., -34788.2730773814556666438, -461.,
387 -75.3620761655318490321712},
388 {2338, 84.2999999999999971578291, 142311., -1.02962292135863589242571e7,
389 -142311., -71.4179345460698058518492},
390 {9611, -36., 0.000100000000000000004792174, -257493.798714111504824811,
391 9610.99999997770706584663, -9.60999902521143419375755e7},
392 {9611, -36., 0.0840000000000000052180482, -322200.978630068417237347,
393 9610.999999999973460561, -114395.144355373590193239},
394 {9611, -36., 3.75999999999999978683718, -358701.177684561134391123,
395 9610.99999999999940687087, -2548.13160475324762913113},
396 {9611, -36., 461., -403079.072217515117383779, 9610.99999999999999493227,
397 -17.7630043020635701626676},
398 {9611, -36., 142311., -424212.312866130535985746, 9610.99999999999999975238,
399 -0.00218253709318427116222024},
400 {9611, -26.1000000000000014210855, 0.000100000000000000004792174,
401 -162344.899158395651111601, 9610.99955569358470701868,
402 -9.60999858092731645759248e7},
403 {9611, -26.1000000000000014210855, 0.0840000000000000052180482,
404 -227052.078630597345206472, 9610.99999947105914950335,
405 -114395.144349077046283779},
406 {9611, -26.1000000000000014210855, 3.75999999999999978683718,
407 -263552.277684572968739345, 9610.99999998817871670057,
408 -2548.13160475010505798557},
409 {9611, -26.1000000000000014210855, 461., -307930.172217515232038494,
410 9610.99999999989899826967, -17.7630043020633611084385},
411 {9611, -26.1000000000000014210855, 142311., -329063.412866130554578654,
412 9610.999999999995064897, -0.00218253709318426896848674},
413 {9611, -4.5, 0.000100000000000000004792174, -104.508893420504154556264,
414 85.743526266719416841998, -847430.234306281205498662},
415 {9611, -4.5, 0.0840000000000000052180482, -20648.2403016158121274787,
416 8488.39853461890805547167, -101030.965402842249733567},
417 {9611, -4.5, 3.75999999999999978683718, -55983.0428066363797886683,
418 9582.67667769505493377765, -2540.60175643222855702327},
419 {9611, -4.5, 461., -100332.814925682181363541, 9610.75729475722022850913,
420 -17.7625019237987339653397},
421 {9611, -4.5, 142311., -121465.824725374770167655, 9610.98814075622869244341,
422 -0.00218253182129617637581483},
423 {9611, -0.20000000000000001110223, 0.000100000000000000004792174,
424 -19.5547487056230065372451, 1.17364684131883001763989,
425 -1735.73121296897564419702},
426 {9611, -0.20000000000000001110223, 0.0840000000000000052180482,
427 -949.731171379774212295606, 894.23698468706897035991,
428 -10626.5306853797828683123},
429 {9611, -0.20000000000000001110223, 3.75999999999999978683718,
430 -16521.513038999314726487, 7891.768160440031828256,
431 -2091.0860971061479068882},
432 {9611, -0.20000000000000001110223, 461., -59023.1441094345908559896,
433 9593.14395476867203137831, -17.7260454286430745149716},
434 {9611, -0.20000000000000001110223, 142311., -80139.3868875039743196636,
435 9610.1259811407200125367, -0.00218214857464869223866328},
436 {9611, 2., 0.000100000000000000004792174, -18.5112200482628704439815,
437 0.129968981779812349901792, 8698.84749102058108657733},
438 {9611, 2., 0.0840000000000000052180482, -119.854424704305068624943,
439 107.948248829969884782069, -1268.06413157597183571962},
440 {9611, 2., 3.75999999999999978683718, -3933.80532649936636751687,
441 3238.80127865994241519564, -854.484838151306300912966},
442 {9611, 2., 461., -38021.2293458402145037005, 9452.10949634845787633542,
443 -17.4342406591600541232224},
444 {9611, 2., 142311., -59002.2007387401537465017, 9603.1123321581496224457,
445 -0.0021790320685329142613448},
446 {9611, 13.4000000000000003552714, 0.000100000000000000004792174,
447 -18.3822916425459253100299, -0.0000985437949788883229584548,
448 9988.12276030177863802081},
449 {9611, 13.4000000000000003552714, 0.0840000000000000052180482,
450 -12.1692487377720201796787, -0.0827767772596333470047606,
451 6.63081050986194646155585},
452 {9611, 13.4000000000000003552714, 3.75999999999999978683718,
453 -21.6460442652836866611064, -3.70522558331286636290617,
454 -3.10473788018826412150356},
455 {9611, 13.4000000000000003552714, 461., -1491.91977767158488399654,
456 -453.969805440038841232377, -3.19739819420502413093782},
457 {9611, 13.4000000000000003552714, 142311., -212175.813403243944037729,
458 -115363.73784931155578397, -0.853487617294708440971187},
459 {9611, 84.2999999999999971578291, 0.000100000000000000004792174,
460 -18.3893801863409040880501, -0.000100000000000000004792174,
461 9917.23732235214127244709},
462 {9611, 84.2999999999999971578291, 0.0840000000000000052180482,
463 -18.1236255149538125593602, -0.0840000000000000052180482,
464 -64.2546271873853975660471},
465 {9611, 84.2999999999999971578291, 3.75999999999999978683718,
466 -288.17526969257348869956, -3.75999999999999978683718,
467 -73.9901645192500912911873},
468 {9611, 84.2999999999999971578291, 461., -34169.7871284528721914203, -461.,
469 -74.0814500768543374677113},
470 {9611, 84.2999999999999971578291, 142311., -1.02723620903825239882665e7,
471 -142311., -71.3688772675071099858592},
472 };
473
474 } // namespace neg_binomial_2_log_test_internal
475
TEST(ProbDistributionsNegBinomial2Log,derivativesPrecomputed)476 TEST(ProbDistributionsNegBinomial2Log, derivativesPrecomputed) {
477 using neg_binomial_2_log_test_internal::TestValue;
478 using neg_binomial_2_log_test_internal::testValues;
479 using stan::math::is_nan;
480 using stan::math::neg_binomial_2_log_lpmf;
481 using stan::math::value_of;
482 using stan::math::var;
483
484 for (TestValue t : testValues) {
485 int n = t.n;
486 var eta(t.eta);
487 var phi(t.phi);
488 var val = neg_binomial_2_log_lpmf(n, eta, phi);
489
490 std::vector<var> x;
491 x.push_back(eta);
492 x.push_back(phi);
493
494 std::vector<double> gradients;
495 val.grad(x, gradients);
496
497 for (int i = 0; i < 2; ++i) {
498 EXPECT_FALSE(is_nan(gradients[i]));
499 }
500
501 auto tolerance = [](double x) { return std::max(fabs(x * 1e-8), 1e-14); };
502
503 EXPECT_NEAR(value_of(val), t.value, tolerance(t.value))
504 << "value n = " << n << ", eta = " << t.eta << ", phi = " << t.phi;
505 EXPECT_NEAR(gradients[0], t.grad_eta, tolerance(t.grad_eta))
506 << "grad_mu n = " << n << ", eta = " << t.eta << ", phi = " << t.phi;
507 EXPECT_NEAR(gradients[1], t.grad_phi, tolerance(t.grad_phi))
508 << "grad_phi n = " << n << ", eta = " << t.eta << ", phi = " << t.phi;
509 }
510 }
511
TEST(ProbDistributionsNegBinomial2Log,derivativesComplexStep)512 TEST(ProbDistributionsNegBinomial2Log, derivativesComplexStep) {
513 using boost::math::differentiation::complex_step_derivative;
514 using stan::math::is_nan;
515 using stan::math::log1p_exp;
516 using stan::math::log_sum_exp;
517 using stan::math::neg_binomial_2_log_lpmf;
518 using stan::math::var;
519
520 std::vector<int> n_to_test = {0, 7, 100, 835, 14238, 385000, 1000000};
521 std::vector<double> eta_to_test = {-124.5, -13, -2, 0, 0.536844, 1.26845, 11};
522
523 auto nb2_log_for_test = [](int n, const std::complex<double>& eta,
524 const std::complex<double>& phi) {
525 // Using first-order Taylor expansion of lgamma(a + b*i) around b = 0
526 // Which happens to work nice in this case, as b is always 0 or the very
527 // small complex step
528 auto lgamma_c_approx = [](const std::complex<double>& x) {
529 return std::complex<double>(lgamma(x.real()),
530 x.imag() * boost::math::digamma(x.real()));
531 };
532
533 const double n_(n);
534 return lgamma_c_approx(n_ + phi) - lgamma(n + 1) - lgamma_c_approx(phi)
535 + phi * (log(phi) - log(exp(eta) + phi)) - n_ * log(exp(eta) + phi)
536 + n_ * eta;
537 };
538
539 double phi_cutoff = 1e10;
540 for (double eta_dbl : eta_to_test) {
541 for (int n : n_to_test) {
542 for (double phi_dbl = 1.5; phi_dbl < 1e22; phi_dbl *= 10) {
543 var eta(eta_dbl);
544 var phi(phi_dbl);
545 var val = neg_binomial_2_log_lpmf(n, eta, phi);
546
547 std::vector<var> x;
548 x.push_back(eta);
549 x.push_back(phi);
550
551 std::vector<double> gradients;
552 val.grad(x, gradients);
553
554 EXPECT_TRUE(value_of(val) < 0)
555 << "for n = " << n << ", eta = " << eta_dbl
556 << ", phi = " << phi_dbl;
557
558 for (int i = 0; i < 2; ++i) {
559 EXPECT_FALSE(is_nan(gradients[i]));
560 }
561
562 auto nb2_log_eta
563 = [n, phi_dbl, nb2_log_for_test](const std::complex<double>& eta) {
564 return nb2_log_for_test(n, eta, phi_dbl);
565 };
566 auto nb2_log_phi
567 = [n, eta_dbl, nb2_log_for_test](const std::complex<double>& phi) {
568 return nb2_log_for_test(n, eta_dbl, phi);
569 };
570 double complex_step_deta
571 = complex_step_derivative(nb2_log_eta, eta_dbl);
572 double complex_step_dphi
573 = complex_step_derivative(nb2_log_phi, phi_dbl);
574
575 std::ostringstream message;
576 message << ", n = " << n << ", eta = " << eta_dbl
577 << ", phi = " << phi_dbl;
578
579 double tolerance_phi;
580 double tolerance_eta;
581 if (phi < phi_cutoff || n < 100000) {
582 tolerance_phi = std::max(1e-10, fabs(gradients[1]) * 1e-8);
583 } else {
584 tolerance_phi = std::max(1e-8, fabs(gradients[1]) * 1e-5);
585 }
586
587 if (phi < phi_cutoff) {
588 tolerance_eta = std::max(1e-10, fabs(gradients[0]) * 1e-8);
589 } else {
590 tolerance_eta = std::max(1e-8, fabs(gradients[0]) * 1e-5);
591 }
592
593 EXPECT_NEAR(gradients[0], complex_step_deta, tolerance_eta)
594 << "grad_mu" << message.str();
595
596 EXPECT_NEAR(gradients[1], complex_step_dphi, tolerance_phi)
597 << "grad_phi" << message.str();
598 }
599 }
600 }
601 }
602
TEST(ProbDistributionsNegBinomial2Log,derivativesZeroOne)603 TEST(ProbDistributionsNegBinomial2Log, derivativesZeroOne) {
604 using stan::math::log1p_exp;
605 using stan::math::log_diff_exp;
606 using stan::math::log_sum_exp;
607 using stan::math::var;
608 using stan::test::expect_near_rel;
609 using stan::test::relative_tolerance;
610
611 std::vector<double> eta_to_test = {-943, -130, -15, -1, 0, 0.138, 14, 611};
612 double phi_start = 1e-8;
613 double phi_max = 1e20;
614 for (double eta_dbl : eta_to_test) {
615 for (double phi_dbl = phi_start; phi_dbl < phi_max;
616 phi_dbl *= stan::math::pi()) {
617 std::stringstream msg;
618 msg << std::setprecision(20) << ", eta = " << eta_dbl
619 << ", phi = " << phi_dbl;
620
621 var eta0(eta_dbl);
622 var phi0(phi_dbl);
623 var val0 = neg_binomial_2_log_lpmf(0, eta0, phi0);
624
625 std::vector<var> x0;
626 x0.push_back(eta0);
627 x0.push_back(phi0);
628
629 std::vector<double> gradients0;
630 val0.grad(x0, gradients0);
631
632 var eta1(eta_dbl);
633 var phi1(phi_dbl);
634 var val1 = neg_binomial_2_log_lpmf(1, eta1, phi1);
635
636 std::vector<var> x1;
637 x1.push_back(eta1);
638 x1.push_back(phi1);
639
640 std::vector<double> gradients1;
641 val1.grad(x1, gradients1);
642
643 double expected_value_0 = phi_dbl * (-log1p_exp(eta_dbl - log(phi_dbl)));
644 expect_near_rel("value, n = 0 " + msg.str(), val0.val(),
645 expected_value_0);
646
647 double expected_deta_0 = -phi_dbl / (1 + phi_dbl / exp(eta_dbl));
648 expect_near_rel("deta, n = 0 " + msg.str(), gradients0[0],
649 expected_deta_0);
650
651 double expected_dphi_0 = 1.0 / (1.0 + phi_dbl / exp(eta_dbl))
652 - log1p_exp(eta_dbl - log(phi_dbl));
653 expect_near_rel("dphi, n = 0 " + msg.str(), gradients0[1],
654 expected_dphi_0);
655
656 double expected_value_1
657 = (phi_dbl + 1) * (-log1p_exp(eta_dbl - log(phi_dbl))) + eta_dbl;
658 expect_near_rel("value, n = 1 " + msg.str(), val1.val(),
659 expected_value_1);
660
661 double expected_deta_1
662 = exp(log(phi_dbl) - log_sum_exp(eta_dbl, log(phi_dbl)))
663 + expected_deta_0;
664 expect_near_rel("deta, n = 1 " + msg.str(), gradients1[0],
665 expected_deta_1);
666
667 double expected_dphi_1
668 = (1 + phi_dbl) / (phi_dbl + (phi_dbl * phi_dbl) / exp(eta_dbl))
669 - log1p_exp(eta_dbl - log(phi_dbl));
670 expect_near_rel("dphi, n = 1 " + msg.str(), gradients1[1],
671 expected_dphi_1);
672 }
673 }
674 }
675
TEST(ProbDistributionsNegBinomial2Log,derivatives_diff_sizes)676 TEST(ProbDistributionsNegBinomial2Log, derivatives_diff_sizes) {
677 using stan::math::neg_binomial_2_log_lpmf;
678 using stan::math::var;
679
680 int N = 100;
681 double eta_dbl = 1.5;
682 std::vector<double> phi_dbl{2, 4, 6, 8};
683
684 var mu(eta_dbl);
685 std::vector<var> phi;
686 for (double i : phi_dbl) {
687 phi.push_back(var(i));
688 }
689 var val = neg_binomial_2_log_lpmf(N, mu, phi);
690
691 std::vector<var> x{mu};
692 std::vector<double> gradients;
693 val.grad(x, gradients);
694
695 double eps = 1e-6;
696 double grad_diff = (neg_binomial_2_log_lpmf(N, eta_dbl + eps, phi_dbl)
697 - neg_binomial_2_log_lpmf(N, eta_dbl - eps, phi_dbl))
698 / (2 * eps);
699 EXPECT_FLOAT_EQ(grad_diff, gradients[0]);
700 }
701