1 #include <stan/math/rev.hpp>
2 #include <test/unit/math/rev/util.hpp>
3 #include <test/unit/math/rev/prob/expect_eq_diffs.hpp>
4 #include <gtest/gtest.h>
5 #include <string>
6 #include <vector>
7 
8 template <typename T_prob>
expect_propto_multinomial_logit_lpmf(std::vector<int> & ns1,T_prob beta1,std::vector<int> & ns2,T_prob beta2,std::string message)9 void expect_propto_multinomial_logit_lpmf(std::vector<int>& ns1, T_prob beta1,
10                                           std::vector<int>& ns2, T_prob beta2,
11                                           std::string message) {
12   expect_eq_diffs(stan::math::multinomial_logit_lpmf<false>(ns1, beta1),
13                   stan::math::multinomial_logit_lpmf<false>(ns2, beta2),
14                   stan::math::multinomial_logit_lpmf<true>(ns1, beta1),
15                   stan::math::multinomial_logit_lpmf<true>(ns2, beta2),
16                   message);
17 }
18 
TEST(AgradDistributionsMultinomialLogit,Propto)19 TEST(AgradDistributionsMultinomialLogit, Propto) {
20   using Eigen::Dynamic;
21   using Eigen::Matrix;
22   using stan::math::var;
23   std::vector<int> ns;
24   ns.push_back(1);
25   ns.push_back(2);
26   ns.push_back(3);
27   Matrix<var, Dynamic, 1> beta1(3, 1);
28   beta1 << log(0.3), log(0.5), log(0.2);
29   Matrix<var, Dynamic, 1> beta2(3, 1);
30   beta2 << log(0.1), log(0.2), log(0.7);
31 
32   expect_propto_multinomial_logit_lpmf(ns, beta1, ns, beta2, "var: beta");
33 }
34 
TEST(AgradDistributionsMultinomialLogit,check_varis_on_stack)35 TEST(AgradDistributionsMultinomialLogit, check_varis_on_stack) {
36   using Eigen::Dynamic;
37   using Eigen::Matrix;
38   using stan::math::var;
39   std::vector<int> ns;
40   ns.push_back(1);
41   ns.push_back(2);
42   ns.push_back(3);
43   Matrix<var, Dynamic, 1> beta(3, 1);
44   beta << log(0.3), log(0.5), log(0.2);
45 
46   test::check_varis_on_stack(
47       stan::math::multinomial_logit_lpmf<false>(ns, beta));
48   test::check_varis_on_stack(
49       stan::math::multinomial_logit_lpmf<true>(ns, beta));
50 }
51