1 #include <stan/math/rev.hpp>
2 #include <test/unit/math/expect_near_rel.hpp>
3 #include <gtest/gtest.h>
4 #include <boost/math/differentiation/finite_difference.hpp>
5 #include <boost/math/special_functions/digamma.hpp>
6 #include <algorithm>
7 #include <limits>
8 #include <vector>
9 
10 namespace neg_binomial_2_log_test_internal {
11 struct TestValue {
12   int n;
13   double eta;
14   double phi;
15   double value;
16   double grad_eta;
17   double grad_phi;
18 
TestValueneg_binomial_2_log_test_internal::TestValue19   TestValue(int _n, double _eta, double _phi, double _value, double _grad_eta,
20             double _grad_phi)
21       : n(_n),
22         eta(_eta),
23         phi(_phi),
24         value(_value),
25         grad_eta(_grad_eta),
26         grad_phi(_grad_phi) {}
27 };
28 
29 // Test data generated in Mathematica (Wolfram Cloud). The code can be re-ran at
30 // https://www.wolframcloud.com/obj/martin.modrak/Published/neg_binomial_2_log_lpmf.nb
31 // but is also presented below for convenience:
32 //
33 // toCString[x_] := ToString[CForm[N[x, 24]]];
34 // nb2log[n_,logmu_,phi_]:= LogGamma[n + phi] - LogGamma[n + 1] -
35 //  LogGamma[phi] + n * (logmu - Log[Exp[logmu] + phi]) +
36 //  phi * (Log[phi] - Log[Exp[logmu] + phi])
37 // nb2logdmu[n_,logmu_,phi_]= D[nb2log[n, logmu, phi],logmu];
38 // nb2logdphi[n_,logmu_,phi_]= D[nb2log[n, logmu, phi],phi];
39 // logmus= SetPrecision[{-36,-26.1, -4.5,-0.2,2.0,13.4, 84.3}, Infinity];
40 // phis=  SetPrecision[{0.0001,0.084,3.76,461.0, 142311.0}, Infinity];
41 // ns = {0,7,11,2338,9611};
42 // out = "std::vector<TestValue> testValues = {\n";
43 // For[k = 1, k <= Length[ns], k++, {
44 //   For[i = 1, i <= Length[logmus], i++, {
45 //     For[j = 1, j <= Length[phis], j++, {
46 //       clogmu = logmus[[i]];
47 //       cphi = phis[[j]];
48 //       cn=ns[[k]];
49 //       val = nb2log[cn,clogmu,cphi];
50 //       ddmu= nb2logdmu[cn,clogmu,cphi];
51 //       ddphi= nb2logdphi[cn,clogmu,cphi];
52 //       out = StringJoin[out,"  {",ToString[cn],",",toCString[clogmu],",",
53 //         toCString[cphi],",", toCString[val],",",toCString[ddmu],",",
54 //         toCString[ddphi],"},\n"];
55 //     }]
56 //   }]
57 // }]
58 // out = StringJoin[out,"};\n"];
59 // out
60 
61 std::vector<TestValue> testValues = {
62     {0, -36., 0.000100000000000000004792174, -2.31952283024087929523226e-16,
63      -2.31952283023818920215225e-16, -2.69009308000224930599517e-24},
64     {0, -36., 0.0840000000000000052180482, -2.3195228302435661858205e-16,
65      -2.31952283024356298332874e-16, -3.81249019275872869665102e-30},
66     {0, -36., 3.75999999999999978683718, -2.31952283024356931676723e-16,
67      -2.31952283024356924522221e-16, -1.9027933171192913158222e-33},
68     {0, -36., 461., -2.31952283024356938772873e-16,
69      -2.3195228302435693871452e-16, -1.26580106437037714160219e-37},
70     {0, -36., 142311., -2.31952283024356938831037e-16,
71      -2.31952283024356938830848e-16, -1.32828224194512041655272e-42},
72     {0, -26.1000000000000014210855, 0.000100000000000000004792174,
73      -4.62289481781287824178924e-12, -4.62289471095709740559354e-12,
74      -1.06855780836195688701657e-15},
75     {0, -26.1000000000000014210855, 0.0840000000000000052180482,
76      -4.62289492454145310046865e-12, -4.62289492441424382973675e-12,
77      -1.51439608014165778339631e-21},
78     {0, -26.1000000000000014210855, 3.75999999999999978683718,
79      -4.62289492466582046196526e-12, -4.6228949246629785527253e-12,
80      -7.55826925521030645308565e-25},
81     {0, -26.1000000000000014210855, 461., -4.62289492466863919207562e-12,
82      -4.62289492466861601294603e-12, -5.02801075764912450865117e-29},
83     {0, -26.1000000000000014210855, 142311., -4.62289492466866229611911e-12,
84      -4.62289492466866222103301e-12, -5.27619828240266701257698e-34},
85     {0, -4.5, 0.000100000000000000004792174, -0.000471930181119562098823563,
86      -0.0000991078594800272766579705, -3.728223216395348042993},
87     {0, -4.5, 0.0840000000000000052180482, -0.0104333684327369739264569,
88      -0.0098114347031002672745465, -0.00740397297186555491995328},
89     {0, -4.5, 3.75999999999999978683718, -0.0110926179127638047017042,
90      -0.0110762714687203261403217, -4.34745852220174529502355e-6},
91     {0, -4.5, 461., -0.0111088626902796869123984, -0.0111087288444673011605548,
92      -2.90337987821587513149328e-10},
93     {0, -4.5, 142311., -0.0111089961046503927849115,
94      -0.0111089956710585016382076, -3.04679112048052440606817e-15},
95     {0, -0.20000000000000001110223, 0.000100000000000000004792174,
96      -0.000901046250479348295492963, -0.000099987787464060911106451,
97      -8.010584630152873459984},
98     {0, -0.20000000000000001110223, 0.0840000000000000052180482,
99      -0.199467033424133669556274, -0.0761837159352978552652188,
100      -1.46765854153375960277075},
101     {0, -0.20000000000000001110223, 3.75999999999999978683718,
102      -0.740730806872720638510927, -0.672332093234306416324947,
103      -0.0181911472442591026765043},
104     {0, -0.20000000000000001110223, 461., -0.818004584479470198072977,
105      -0.81727927438862128024214, -1.57334076106055928597981e-6},
106     {0, -0.20000000000000001110223, 142311., -0.818728397963215257316515,
107      -0.818726042857481443695725, -1.65490069904197201165846e-11},
108     {0, 2., 0.000100000000000000004792174, -0.00112103539054129295184384,
109      -0.000099998646665483029681996, -10.2103674387580987323199},
110     {0, 2., 0.0840000000000000052180482, -0.377012371003043676678775,
111      -0.0830558079711177822673545, -3.49948289323721281132084},
112     {0, 2., 3.75999999999999978683718, -4.08687891675795882940335,
113      -2.49194646483517166950394, -0.424184162745422141042474},
114     {0, 2., 461., -7.33046427302122464456484, -7.27249028826061546955612,
115      -0.000125757016834293221277066},
116     {0, 2., 142311., -7.38886427869112561814738, -7.38867246509109332505865,
117      -1.34784802321881715908971e-9},
118     {0, 13.4000000000000003552714, 0.000100000000000000004792174,
119      -0.00226103403721276985381512, -0.000099999999984848563673037,
120      -21.6103403722792118658158},
121     {0, 13.4000000000000003552714, 0.0840000000000000052180482,
122      -1.33366284302251745111466, -0.0839999893091445105829689,
123      -14.876938734683010272656},
124     {0, 13.4000000000000003552714, 3.75999999999999978683718,
125      -45.4042061406096035346291, -3.75997857962063060636234,
126      -11.0755924364332381088221},
127     {0, 13.4000000000000003552714, 461., -3350.22538971724911576917,
128      -460.678224812138394142983, -6.26799818851433996014357},
129     {0, 13.4000000000000003552714, 142311., -246124.815117870661913638,
130      -117068.49513614471120162, -0.906861170125471331885925},
131     {0, 84.2999999999999971578291, 0.000100000000000000004792174,
132      -0.00935103403719761843271571, -0.000100000000000000004792174,
133      -92.5103403719761798459793},
134     {0, 84.2999999999999971578291, 0.0840000000000000052180482,
135      -7.28926283233166137753023, -0.0840000000000000052180482,
136      -85.7769384801388205324283},
137     {0, 84.2999999999999971578291, 3.75999999999999978683718,
138      -311.988184720169192006433, -3.75999999999999978683718,
139      -81.9755810425981940674846},
140     {0, 84.2999999999999971578291, 461., -36034.8035021785437227894, -461.,
141      -77.1666019570033486394563},
142     {0, 84.2999999999999971578291, 142311., -1.03081876937799361823842e7,
143      -142311., -71.4342299174339030881957},
144     {7, -36., 0.000100000000000000004792174, -198.683622924671001775538,
145      6.99999999998376310823605, -59997.5501489646175798213},
146     {7, -36., 0.0840000000000000052180482, -238.883518369596162432722,
147      6.9999999999999804386908, -69.0960424923444586909594},
148     {7, -36., 3.75999999999999978683718, -256.751996432901803890657,
149      6.99999999999999933622166, -0.718640295127228617069051},
150     {7, -36., 461., -260.47982082426917795879, 6.99999999999999976452566,
151      -0.0000978945853445233788324339},
152     {7, -36., 142311., -260.525013799174310121032, 6.99999999999999976803631,
153      -1.03688150994232402700446e-9},
154     {7, -26.1000000000000014210855, 0.000100000000000000004792174,
155      -129.383623248262034973091, 6.99999967639274733829224,
156      -59997.5469131006865090386},
157     {7, -26.1000000000000014210855, 0.0840000000000000052180482,
158      -169.583518369986016957648, 6.9999999996101358613744,
159      -69.0960424877584835201958},
160     {7, -26.1000000000000014210855, 3.75999999999999978683718,
161      -187.451996432915042522719, 6.99999999998677065175815,
162      -0.718640295124939781566091},
163     {7, -26.1000000000000014210855, 461., -191.17982082427388076164,
164      6.99999999999530690927387, -0.000097894585344371117922657},
165     {7, -26.1000000000000014210855, 142311., -191.225013799178942958982,
166      6.99999999999537687768416, -1.03688150994072626043362e-9},
167     {7, -4.5, 0.000100000000000000004792174, -11.2192075332052297804221,
168      0.0623507285386109391248531, 9374.22326367570748540949},
169     {7, -4.5, 0.0840000000000000052180482, -19.2633991074236276185347,
170      6.17256900670521084397013, -59.3698802916059729529247},
171     {7, -4.5, 3.75999999999999978683718, -36.2837402011840741159725,
172      6.96830301037142800168473, -0.713160409032598892878726},
173     {7, -4.5, 461., -39.9910983681716742384461, 6.9887225919756817853728,
174      -0.0000975289772446567319934955},
175     {7, -4.5, 142311., -40.0361233417087705744711, 6.98889045789915253328171,
176      -1.03304486908480829056641e-9},
177     {7, -0.20000000000000001110223, 0.000100000000000000004792174,
178      -11.1577615044596219717133, 0.000754889728272497048508259,
179      9985.89049108549401489409},
180     {7, -0.20000000000000001110223, 0.0840000000000000052180482,
181      -5.10523818836474804913996, 0.575173289456547933675185,
182      5.01538223526623988873864},
183     {7, -0.20000000000000001110223, 3.75999999999999978683718,
184      -8.27174735895246138021499, 5.07598581829756986530093,
185      -0.403937269906561294876212},
186     {7, -0.20000000000000001110223, 461., -10.7102463026561453773833,
187      6.1703108450892087653941, -0.0000725484455802695111298855},
188     {7, -0.20000000000000001110223, 142311., -10.7437824687875434412923,
189      6.18123368560834338804528, -7.70447935296972460390227e-10},
190     {7, 2., 0.000100000000000000004792174, -11.1572212979358090452568,
191      -5.26523047722532179308571e-6, 9991.29214927237179177473},
192     {7, 2., 0.0840000000000000052180482, -4.67822832418615765300115,
193      -0.00437313889759920793061993, 9.8011095064000150019666},
194     {7, 2., 3.75999999999999978683718, -2.44742652447511137215548,
195      -0.131208500432565998079953, 0.0910218281904859865055242},
196     {7, 2., 461., -1.92159366542738408699561, -0.382918557279757136122044,
197      0.0000158891115286490464064998},
198     {7, 2., 142311., -1.91424152166072074784582, -0.38903589945144240315715,
199      1.69074180881189252797788e-10},
200     {7, 13.4000000000000003552714, 0.000100000000000000004792174,
201      -11.1582665635858573490992, -0.0000999989393839702240949489,
202      9980.83949989472447227895},
203     {7, 13.4000000000000003552714, 0.0840000000000000052180482,
204      -5.55575146449510519129741, -0.083999098404519957659751,
205      -0.639658499701806072156294},
206     {7, 13.4000000000000003552714, 3.75999999999999978683718,
207      -40.8853097501782241578175, -3.75993870125478372781648,
208      -9.93254120984925579863814},
209     {7, 13.4000000000000003552714, 461., -3315.77631190381925371792,
210      -460.673338854838977134308, -6.25292229992679818636882},
211     {7, 13.4000000000000003552714, 142311., -246051.646530372853518603,
212      -117067.253506656850199859, -0.906820707877630374849315},
213     {7, 84.2999999999999971578291, 0.000100000000000000004792174,
214      -11.1653565625252413192581, -0.000100000000000000004792174,
215      9909.93951050103628769457},
216     {7, 84.2999999999999971578291, 0.0840000000000000052180482,
217      -11.5113505628995678711385, -0.0840000000000000052180482,
218      -71.5396476391501811780814},
219     {7, 84.2999999999999971578291, 3.75999999999999978683718,
220      -307.469248451258373600901, -3.75999999999999978683718,
221      -80.8325192100658482257719},
222     {7, 84.2999999999999971578291, 461., -36000.3495367018363608841, -461.,
223      -77.151515469809951297333},
224     {7, 84.2999999999999971578291, 142311., -1.03081131584031573940356e7,
225      -142311., -71.4341807304248851244502},
226     {11, -36., 0.000100000000000000004792174, -306.294198663985101325787,
227      10.9999999999744850169151, -99997.0711864556783212622},
228     {11, -36., 0.0840000000000000052180482, -373.387720780818404406357,
229      10.999999999999969393344, -116.240975127892379549188},
230     {11, -36., 3.75999999999999978683718, -405.018198139691527464642,
231      10.9999999999999990894639, -1.45345494122108727163354},
232     {11, -36., 461., -413.383897627979795451303, 10.9999999999999997625131,
233      -0.000254934049399102214755753},
234     {11, -36., 142311., -413.501921377875076957036, 10.9999999999999997680298,
235      -2.71559116633421423577824e-9},
236     {11, -26.1000000000000014210855, 0.000100000000000000004792174,
237      -197.394199172482654828876, 10.9999994914769589000084,
238      -99997.0661015266437808503},
239     {11, -26.1000000000000014210855, 0.0840000000000000052180482,
240      -264.487720781428391423827, 10.9999999993899980078309,
241      -116.240975120685847137987},
242     {11, -26.1000000000000014210855, 3.75999999999999978683718,
243      -296.118198139709689507612, 10.999999999981852678434,
244      -1.45345494121749053012889},
245     {11, -26.1000000000000014210855, 461., -304.483897627984544048369,
246      10.9999999999952667973873, -0.000254934049398862947611818},
247     {11, -26.1000000000000014210855, 142311., -304.601921377879715609259,
248      10.9999999999953767477464, -2.71559116633170345973834e-9},
249     {11, -4.5, 0.000100000000000000004792174, -11.706990517292534287664,
250      0.0980363493375200627828593, 9017.8460181027726944107},
251     {11, -4.5, 0.0840000000000000052180482, -28.2644285868714287029341,
252      9.70535783036710290753852, -100.952775113605121643751},
253     {11, -4.5, 3.75999999999999978683718, -58.5617425653278014914233,
254      10.9565197428515127604419, -1.44484120738179930405144},
255     {11, -4.5, 461., -66.8952715611464156307853, 10.9886262038729098348204,
256      -0.0002543593564776044431017},
257     {11, -4.5, 142311., -67.0130312326551432919643, 10.9888901456535588389502,
258      -2.70956041824890637536164e-9},
259     {11, -0.20000000000000001110223, 0.000100000000000000004792174,
260      -11.6101872629383735675807, 0.00124339116583624445400238,
261      9981.48443912601296955335},
262     {11, -0.20000000000000001110223, 0.0840000000000000052180482,
263      -5.90787076264096269638241, 0.947377292537602670212559,
264      1.05849718208658084906379},
265     {11, -0.20000000000000001110223, 3.75999999999999978683718,
266      -14.1259605624152918587705, 8.36073891060149916908714,
267      -0.948526674591890563801079},
268     {11, -0.20000000000000001110223, 461., -20.4214207600281899855949,
269      10.1632194847908259343291, -0.000214205349334668674488331},
270     {11, -0.20000000000000001110223, 142311., -20.520713059859749302951,
271      10.1812106733031004347544, -2.2874532593255647910321e-9},
272     {11, 2., 0.000100000000000000004792174, -11.6092126588923464290597,
273      0.0000488681502017790827148631, 9991.22977788173817623571},
274     {11, 2., 0.0840000000000000052180482, -5.13540078317237206980856,
275      0.0405883862872685488332284, 9.73996823769877214616492},
276     {11, 2., 3.75999999999999978683718, -3.06137175973074859416703,
277      1.21778462208320867130519, 0.0612622027041340431507753},
278     {11, 2., 461., -2.88927536521627684556178, 3.55397957470930476869743,
279      -4.26994469309227722196226e-6},
280     {11, 2., 142311., -2.89135678253022200706033, 3.61075642377121526650085,
281      -5.03190674865121337520085e-11},
282     {11, 13.4000000000000003552714, 0.000100000000000000004792174,
283      -11.6102037914014672054394, -0.0000999983333263254586217559,
284      9981.31845625030445310433},
285     {11, 13.4000000000000003552714, 0.0840000000000000052180482,
286      -5.96770830536101890173932, -0.0839985893161630702750551,
287      -0.165549576777919563836645},
288     {11, 13.4000000000000003552714, 3.75999999999999978683718,
289      -39.8538584150632717115069, -3.759915913617156940076,
290      -9.60353212925099392334374},
291     {11, 13.4000000000000003552714, 461., -3300.14958948598756276581,
292      -460.670546879239310272208, -6.24440860614920345254555},
293     {11, 13.4000000000000003552714, 142311., -246013.941380067564168339,
294      -117066.544004092358198853, -0.906797587679355764396099},
295     {11, 84.2999999999999971578291, 0.000100000000000000004792174,
296      -11.617293789734793530787, -0.000100000000000000004792174,
297      9910.41847291719271617468},
298     {11, 84.2999999999999971578291, 0.0840000000000000052180482,
299      -11.9233068946770922978236, -0.0840000000000000052180482,
300      -71.065532655650617438992},
301     {11, 84.2999999999999971578291, 3.75999999999999978683718,
302      -306.43777432844088456675, -3.75999999999999978683718,
303      -80.503504068925664332461},
304     {11, 84.2999999999999971578291, 461., -35984.7200213335603843011, -461.,
305      -77.1429957196861533815963},
306     {11, 84.2999999999999971578291, 142311., -1.03080746722304058304262e7,
307      -142311., -71.4341526246487950815918},
308     {2338, -36., 0.000100000000000000004792174, -62651.1907684420448777443,
309      2337.99999999457695539095, -2.33699916660572733525724e7},
310     {2338, -36., 0.0840000000000000052180482, -78386.4573396010374044213,
311      2337.99999999999354376284, -27813.2247691071302770408},
312     {2338, -36., 3.75999999999999978683718, -87244.5784991555256038266,
313      2337.99999999999985553814, -615.235652274809944062375},
314     {2338, -36., 461., -97261.3541997290761812182, 2337.99999999999999859168,
315      -3.26705772262828347042341},
316     {2338, -36., 142311., -99951.6902594622698079577, 2337.99999999999999976424,
317      -0.000133435966326975414974749},
318     {2338, -26.1000000000000014210855, 0.000100000000000000004792174,
319      -39504.990876519910618993, 2337.99989191671703492836,
320      -2.33699905852787203745752e7},
321     {2338, -26.1000000000000014210855, 0.0840000000000000052180482,
322      -55240.2573397297094689766, 2337.99999987132480170888,
323      -27813.2247675754145700052},
324     {2338, -26.1000000000000014210855, 3.75999999999999978683718,
325      -64098.3784991584079601655, 2337.99999999712082169713,
326      -615.235652274045473004386},
327     {2338, -26.1000000000000014210855, 461., -74115.1541997291075706003,
328      2337.99999999997193170739, -3.26705772262823261527954},
329     {2338, -26.1000000000000014210855, 142311., -76805.4902594622778290634,
330      2337.99999999999530115643, -0.000133435966326974881320714},
331     {2338, -4.5, 0.000100000000000000004792174, -37.9188749421642944369079,
332      20.858146249102902750828, -198577.847904345129887008},
333     {2338, -4.5, 0.0840000000000000052180482, -5029.86319434730944140515,
334      2064.90525599567281088845, -24562.2210710608448107488},
335     {2338, -4.5, 3.75999999999999978683718, -13604.4870759968535893564,
336      2331.10160386314082116743, -613.403922615515723001624},
337     {2338, -4.5, 461., -23614.4216481166468801215, 2337.93255242508532760098,
338      -3.26693551284037806579056},
339     {2338, -4.5, 142311., -26304.7015509659310958491, 2337.98870849677942716161,
340      -0.000133434683874347561598764},
341     {2338, -0.20000000000000001110223, 0.000100000000000000004792174,
342      -17.2530056900344775246763, 0.285429102468546297600207,
343      7145.03240130482692470614},
344     {2338, -0.20000000000000001110223, 0.0840000000000000052180482,
345      -238.089236939508565582574, 217.47705608494119565083,
346      -2571.27861575441500369895},
347     {2338, -0.20000000000000001110223, 3.75999999999999978683718,
348      -4005.5119497678293099549, 1919.26585035841237164672,
349      -504.067189818768777258691},
350     {2338, -0.20000000000000001110223, 461., -13565.92078287865980031,
351      2333.03782063120661396226, -3.25806818947358951236853},
352     {2338, -0.20000000000000001110223, 142311., -16252.1224385913391334784,
353      2337.1678232647280123578, -0.000133341466693716057784701},
354     {2338, 2., 0.000100000000000000004792174, -16.9993203274403325176464,
355      0.0315409623602125914052141, 9681.71391098754441853037},
356     {2338, 2., 0.0840000000000000052180482, -36.345344890082340900417,
357      26.196955662584086046202, -296.248198078484197433633},
358     {2338, 2., 3.75999999999999978683718, -945.921470460610089695847,
359      785.994533645635122586109, -203.555176892467693501323},
360     {2338, 2., 461., -8461.86172575984929863877, 2293.84446785934606789742,
361      -3.18717688126682423770834},
362     {2338, 2., 142311., -11115.2005139685862037106, 2330.48994045855231459004,
363      -0.000132584343734508495425318},
364     {2338, 13.4000000000000003552714, 0.000100000000000000004792174,
365      -16.9688195052148641926364, -0.0000996457592914831445917036,
366      9986.7200057158703803923},
367     {2338, 13.4000000000000003552714, 0.0840000000000000052180482,
368      -10.8734665707860518042194, -0.0837024271645438342281827,
369      5.22808308495811843394431},
370     {2338, 13.4000000000000003552714, 3.75999999999999978683718,
371      -25.5045023028520732667892, -3.74665920542777317204743,
372      -4.50627645969873246867534},
373     {2338, 13.4000000000000003552714, 461., -2105.32744445484601087282,
374      -459.046315074133113245574, -4.46701233139860018126931},
375     {2338, 13.4000000000000003552714, 142311., -234622.84254414070827656,
376      -116653.790887199136613496, -0.893479868994494336336804},
377     {2338, 84.2999999999999971578291, 0.000100000000000000004792174,
378      -16.9759091509741556489599, -0.000100000000000000004792174,
379      9915.82354812310706660278},
380     {2338, 84.2999999999999971578291, 0.0840000000000000052180482,
381      -16.8287689979316593747652, -0.0840000000000000052180482,
382      -65.668374254014350440872},
383     {2338, 84.2999999999999971578291, 3.75999999999999978683718,
384      -292.075161470279026145921, -3.75999999999999978683718,
385      -75.4027226791102688967341},
386     {2338, 84.2999999999999971578291, 461., -34788.2730773814556666438, -461.,
387      -75.3620761655318490321712},
388     {2338, 84.2999999999999971578291, 142311., -1.02962292135863589242571e7,
389      -142311., -71.4179345460698058518492},
390     {9611, -36., 0.000100000000000000004792174, -257493.798714111504824811,
391      9610.99999997770706584663, -9.60999902521143419375755e7},
392     {9611, -36., 0.0840000000000000052180482, -322200.978630068417237347,
393      9610.999999999973460561, -114395.144355373590193239},
394     {9611, -36., 3.75999999999999978683718, -358701.177684561134391123,
395      9610.99999999999940687087, -2548.13160475324762913113},
396     {9611, -36., 461., -403079.072217515117383779, 9610.99999999999999493227,
397      -17.7630043020635701626676},
398     {9611, -36., 142311., -424212.312866130535985746, 9610.99999999999999975238,
399      -0.00218253709318427116222024},
400     {9611, -26.1000000000000014210855, 0.000100000000000000004792174,
401      -162344.899158395651111601, 9610.99955569358470701868,
402      -9.60999858092731645759248e7},
403     {9611, -26.1000000000000014210855, 0.0840000000000000052180482,
404      -227052.078630597345206472, 9610.99999947105914950335,
405      -114395.144349077046283779},
406     {9611, -26.1000000000000014210855, 3.75999999999999978683718,
407      -263552.277684572968739345, 9610.99999998817871670057,
408      -2548.13160475010505798557},
409     {9611, -26.1000000000000014210855, 461., -307930.172217515232038494,
410      9610.99999999989899826967, -17.7630043020633611084385},
411     {9611, -26.1000000000000014210855, 142311., -329063.412866130554578654,
412      9610.999999999995064897, -0.00218253709318426896848674},
413     {9611, -4.5, 0.000100000000000000004792174, -104.508893420504154556264,
414      85.743526266719416841998, -847430.234306281205498662},
415     {9611, -4.5, 0.0840000000000000052180482, -20648.2403016158121274787,
416      8488.39853461890805547167, -101030.965402842249733567},
417     {9611, -4.5, 3.75999999999999978683718, -55983.0428066363797886683,
418      9582.67667769505493377765, -2540.60175643222855702327},
419     {9611, -4.5, 461., -100332.814925682181363541, 9610.75729475722022850913,
420      -17.7625019237987339653397},
421     {9611, -4.5, 142311., -121465.824725374770167655, 9610.98814075622869244341,
422      -0.00218253182129617637581483},
423     {9611, -0.20000000000000001110223, 0.000100000000000000004792174,
424      -19.5547487056230065372451, 1.17364684131883001763989,
425      -1735.73121296897564419702},
426     {9611, -0.20000000000000001110223, 0.0840000000000000052180482,
427      -949.731171379774212295606, 894.23698468706897035991,
428      -10626.5306853797828683123},
429     {9611, -0.20000000000000001110223, 3.75999999999999978683718,
430      -16521.513038999314726487, 7891.768160440031828256,
431      -2091.0860971061479068882},
432     {9611, -0.20000000000000001110223, 461., -59023.1441094345908559896,
433      9593.14395476867203137831, -17.7260454286430745149716},
434     {9611, -0.20000000000000001110223, 142311., -80139.3868875039743196636,
435      9610.1259811407200125367, -0.00218214857464869223866328},
436     {9611, 2., 0.000100000000000000004792174, -18.5112200482628704439815,
437      0.129968981779812349901792, 8698.84749102058108657733},
438     {9611, 2., 0.0840000000000000052180482, -119.854424704305068624943,
439      107.948248829969884782069, -1268.06413157597183571962},
440     {9611, 2., 3.75999999999999978683718, -3933.80532649936636751687,
441      3238.80127865994241519564, -854.484838151306300912966},
442     {9611, 2., 461., -38021.2293458402145037005, 9452.10949634845787633542,
443      -17.4342406591600541232224},
444     {9611, 2., 142311., -59002.2007387401537465017, 9603.1123321581496224457,
445      -0.0021790320685329142613448},
446     {9611, 13.4000000000000003552714, 0.000100000000000000004792174,
447      -18.3822916425459253100299, -0.0000985437949788883229584548,
448      9988.12276030177863802081},
449     {9611, 13.4000000000000003552714, 0.0840000000000000052180482,
450      -12.1692487377720201796787, -0.0827767772596333470047606,
451      6.63081050986194646155585},
452     {9611, 13.4000000000000003552714, 3.75999999999999978683718,
453      -21.6460442652836866611064, -3.70522558331286636290617,
454      -3.10473788018826412150356},
455     {9611, 13.4000000000000003552714, 461., -1491.91977767158488399654,
456      -453.969805440038841232377, -3.19739819420502413093782},
457     {9611, 13.4000000000000003552714, 142311., -212175.813403243944037729,
458      -115363.73784931155578397, -0.853487617294708440971187},
459     {9611, 84.2999999999999971578291, 0.000100000000000000004792174,
460      -18.3893801863409040880501, -0.000100000000000000004792174,
461      9917.23732235214127244709},
462     {9611, 84.2999999999999971578291, 0.0840000000000000052180482,
463      -18.1236255149538125593602, -0.0840000000000000052180482,
464      -64.2546271873853975660471},
465     {9611, 84.2999999999999971578291, 3.75999999999999978683718,
466      -288.17526969257348869956, -3.75999999999999978683718,
467      -73.9901645192500912911873},
468     {9611, 84.2999999999999971578291, 461., -34169.7871284528721914203, -461.,
469      -74.0814500768543374677113},
470     {9611, 84.2999999999999971578291, 142311., -1.02723620903825239882665e7,
471      -142311., -71.3688772675071099858592},
472 };
473 
474 }  // namespace neg_binomial_2_log_test_internal
475 
TEST(ProbDistributionsNegBinomial2Log,derivativesPrecomputed)476 TEST(ProbDistributionsNegBinomial2Log, derivativesPrecomputed) {
477   using neg_binomial_2_log_test_internal::TestValue;
478   using neg_binomial_2_log_test_internal::testValues;
479   using stan::math::is_nan;
480   using stan::math::neg_binomial_2_log_lpmf;
481   using stan::math::value_of;
482   using stan::math::var;
483 
484   for (TestValue t : testValues) {
485     int n = t.n;
486     var eta(t.eta);
487     var phi(t.phi);
488     var val = neg_binomial_2_log_lpmf(n, eta, phi);
489 
490     std::vector<var> x;
491     x.push_back(eta);
492     x.push_back(phi);
493 
494     std::vector<double> gradients;
495     val.grad(x, gradients);
496 
497     for (int i = 0; i < 2; ++i) {
498       EXPECT_FALSE(is_nan(gradients[i]));
499     }
500 
501     auto tolerance = [](double x) { return std::max(fabs(x * 1e-8), 1e-14); };
502 
503     EXPECT_NEAR(value_of(val), t.value, tolerance(t.value))
504         << "value n = " << n << ", eta = " << t.eta << ", phi = " << t.phi;
505     EXPECT_NEAR(gradients[0], t.grad_eta, tolerance(t.grad_eta))
506         << "grad_mu n = " << n << ", eta = " << t.eta << ", phi = " << t.phi;
507     EXPECT_NEAR(gradients[1], t.grad_phi, tolerance(t.grad_phi))
508         << "grad_phi n = " << n << ", eta = " << t.eta << ", phi = " << t.phi;
509   }
510 }
511 
TEST(ProbDistributionsNegBinomial2Log,derivativesComplexStep)512 TEST(ProbDistributionsNegBinomial2Log, derivativesComplexStep) {
513   using boost::math::differentiation::complex_step_derivative;
514   using stan::math::is_nan;
515   using stan::math::log1p_exp;
516   using stan::math::log_sum_exp;
517   using stan::math::neg_binomial_2_log_lpmf;
518   using stan::math::var;
519 
520   std::vector<int> n_to_test = {0, 7, 100, 835, 14238, 385000, 1000000};
521   std::vector<double> eta_to_test = {-124.5, -13, -2, 0, 0.536844, 1.26845, 11};
522 
523   auto nb2_log_for_test = [](int n, const std::complex<double>& eta,
524                              const std::complex<double>& phi) {
525     // Using first-order Taylor expansion of lgamma(a + b*i) around b = 0
526     // Which happens to work nice in this case, as b is always 0 or the very
527     // small complex step
528     auto lgamma_c_approx = [](const std::complex<double>& x) {
529       return std::complex<double>(lgamma(x.real()),
530                                   x.imag() * boost::math::digamma(x.real()));
531     };
532 
533     const double n_(n);
534     return lgamma_c_approx(n_ + phi) - lgamma(n + 1) - lgamma_c_approx(phi)
535            + phi * (log(phi) - log(exp(eta) + phi)) - n_ * log(exp(eta) + phi)
536            + n_ * eta;
537   };
538 
539   double phi_cutoff = 1e10;
540   for (double eta_dbl : eta_to_test) {
541     for (int n : n_to_test) {
542       for (double phi_dbl = 1.5; phi_dbl < 1e22; phi_dbl *= 10) {
543         var eta(eta_dbl);
544         var phi(phi_dbl);
545         var val = neg_binomial_2_log_lpmf(n, eta, phi);
546 
547         std::vector<var> x;
548         x.push_back(eta);
549         x.push_back(phi);
550 
551         std::vector<double> gradients;
552         val.grad(x, gradients);
553 
554         EXPECT_TRUE(value_of(val) < 0)
555             << "for n = " << n << ", eta = " << eta_dbl
556             << ", phi = " << phi_dbl;
557 
558         for (int i = 0; i < 2; ++i) {
559           EXPECT_FALSE(is_nan(gradients[i]));
560         }
561 
562         auto nb2_log_eta
563             = [n, phi_dbl, nb2_log_for_test](const std::complex<double>& eta) {
564                 return nb2_log_for_test(n, eta, phi_dbl);
565               };
566         auto nb2_log_phi
567             = [n, eta_dbl, nb2_log_for_test](const std::complex<double>& phi) {
568                 return nb2_log_for_test(n, eta_dbl, phi);
569               };
570         double complex_step_deta
571             = complex_step_derivative(nb2_log_eta, eta_dbl);
572         double complex_step_dphi
573             = complex_step_derivative(nb2_log_phi, phi_dbl);
574 
575         std::ostringstream message;
576         message << ", n = " << n << ", eta = " << eta_dbl
577                 << ", phi = " << phi_dbl;
578 
579         double tolerance_phi;
580         double tolerance_eta;
581         if (phi < phi_cutoff || n < 100000) {
582           tolerance_phi = std::max(1e-10, fabs(gradients[1]) * 1e-8);
583         } else {
584           tolerance_phi = std::max(1e-8, fabs(gradients[1]) * 1e-5);
585         }
586 
587         if (phi < phi_cutoff) {
588           tolerance_eta = std::max(1e-10, fabs(gradients[0]) * 1e-8);
589         } else {
590           tolerance_eta = std::max(1e-8, fabs(gradients[0]) * 1e-5);
591         }
592 
593         EXPECT_NEAR(gradients[0], complex_step_deta, tolerance_eta)
594             << "grad_mu" << message.str();
595 
596         EXPECT_NEAR(gradients[1], complex_step_dphi, tolerance_phi)
597             << "grad_phi" << message.str();
598       }
599     }
600   }
601 }
602 
TEST(ProbDistributionsNegBinomial2Log,derivativesZeroOne)603 TEST(ProbDistributionsNegBinomial2Log, derivativesZeroOne) {
604   using stan::math::log1p_exp;
605   using stan::math::log_diff_exp;
606   using stan::math::log_sum_exp;
607   using stan::math::var;
608   using stan::test::expect_near_rel;
609   using stan::test::relative_tolerance;
610 
611   std::vector<double> eta_to_test = {-943, -130, -15, -1, 0, 0.138, 14, 611};
612   double phi_start = 1e-8;
613   double phi_max = 1e20;
614   for (double eta_dbl : eta_to_test) {
615     for (double phi_dbl = phi_start; phi_dbl < phi_max;
616          phi_dbl *= stan::math::pi()) {
617       std::stringstream msg;
618       msg << std::setprecision(20) << ", eta = " << eta_dbl
619           << ", phi = " << phi_dbl;
620 
621       var eta0(eta_dbl);
622       var phi0(phi_dbl);
623       var val0 = neg_binomial_2_log_lpmf(0, eta0, phi0);
624 
625       std::vector<var> x0;
626       x0.push_back(eta0);
627       x0.push_back(phi0);
628 
629       std::vector<double> gradients0;
630       val0.grad(x0, gradients0);
631 
632       var eta1(eta_dbl);
633       var phi1(phi_dbl);
634       var val1 = neg_binomial_2_log_lpmf(1, eta1, phi1);
635 
636       std::vector<var> x1;
637       x1.push_back(eta1);
638       x1.push_back(phi1);
639 
640       std::vector<double> gradients1;
641       val1.grad(x1, gradients1);
642 
643       double expected_value_0 = phi_dbl * (-log1p_exp(eta_dbl - log(phi_dbl)));
644       expect_near_rel("value, n = 0 " + msg.str(), val0.val(),
645                       expected_value_0);
646 
647       double expected_deta_0 = -phi_dbl / (1 + phi_dbl / exp(eta_dbl));
648       expect_near_rel("deta, n = 0 " + msg.str(), gradients0[0],
649                       expected_deta_0);
650 
651       double expected_dphi_0 = 1.0 / (1.0 + phi_dbl / exp(eta_dbl))
652                                - log1p_exp(eta_dbl - log(phi_dbl));
653       expect_near_rel("dphi, n = 0 " + msg.str(), gradients0[1],
654                       expected_dphi_0);
655 
656       double expected_value_1
657           = (phi_dbl + 1) * (-log1p_exp(eta_dbl - log(phi_dbl))) + eta_dbl;
658       expect_near_rel("value, n = 1 " + msg.str(), val1.val(),
659                       expected_value_1);
660 
661       double expected_deta_1
662           = exp(log(phi_dbl) - log_sum_exp(eta_dbl, log(phi_dbl)))
663             + expected_deta_0;
664       expect_near_rel("deta, n = 1 " + msg.str(), gradients1[0],
665                       expected_deta_1);
666 
667       double expected_dphi_1
668           = (1 + phi_dbl) / (phi_dbl + (phi_dbl * phi_dbl) / exp(eta_dbl))
669             - log1p_exp(eta_dbl - log(phi_dbl));
670       expect_near_rel("dphi, n = 1 " + msg.str(), gradients1[1],
671                       expected_dphi_1);
672     }
673   }
674 }
675 
TEST(ProbDistributionsNegBinomial2Log,derivatives_diff_sizes)676 TEST(ProbDistributionsNegBinomial2Log, derivatives_diff_sizes) {
677   using stan::math::neg_binomial_2_log_lpmf;
678   using stan::math::var;
679 
680   int N = 100;
681   double eta_dbl = 1.5;
682   std::vector<double> phi_dbl{2, 4, 6, 8};
683 
684   var mu(eta_dbl);
685   std::vector<var> phi;
686   for (double i : phi_dbl) {
687     phi.push_back(var(i));
688   }
689   var val = neg_binomial_2_log_lpmf(N, mu, phi);
690 
691   std::vector<var> x{mu};
692   std::vector<double> gradients;
693   val.grad(x, gradients);
694 
695   double eps = 1e-6;
696   double grad_diff = (neg_binomial_2_log_lpmf(N, eta_dbl + eps, phi_dbl)
697                       - neg_binomial_2_log_lpmf(N, eta_dbl - eps, phi_dbl))
698                      / (2 * eps);
699   EXPECT_FLOAT_EQ(grad_diff, gradients[0]);
700 }
701