1 #include <stan/math/rev.hpp>
2 #include <test/unit/util.hpp>
3 #include <gtest/gtest.h>
4 
TEST(AgradRevMatrix,read_var_mat)5 TEST(AgradRevMatrix, read_var_mat) {
6   using Eigen::MatrixXd;
7   using stan::math::matrix_v;
8   using stan::math::matrix_vi;
9   using stan::math::read_val_adj;
10   using stan::math::read_vi_adj;
11   using stan::math::read_vi_val;
12   using stan::math::read_vi_val_adj;
13 
14   matrix_v matrix_var(100, 100);
15   matrix_vi matrix_vari(100, 100);
16   MatrixXd matrix_val(100, 100), matrix_deriv(100, 100);
17   matrix_var = MatrixXd::Random(100, 100);
18   matrix_var.adj() = MatrixXd::Random(100, 100);
19 
20   read_vi_val_adj(matrix_var, matrix_vari, matrix_val, matrix_deriv);
21   EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_val);
22   EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_deriv);
23 
24   matrix_val.setZero();
25   matrix_deriv.setZero();
26   read_val_adj(matrix_var, matrix_val, matrix_deriv);
27   EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_val);
28   EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_deriv);
29 
30   matrix_val.setZero();
31   matrix_deriv.setZero();
32   read_val_adj(matrix_vari, matrix_val, matrix_deriv);
33   EXPECT_MATRIX_FLOAT_EQ(matrix_vari.val(), matrix_val);
34   EXPECT_MATRIX_FLOAT_EQ(matrix_vari.adj(), matrix_deriv);
35 
36   matrix_val.setZero();
37   matrix_vi matrix_vi2(100, 100);
38   read_vi_val(matrix_var, matrix_vi2, matrix_val);
39   EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_val);
40   EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_vi2.val());
41   EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_vi2.adj());
42 
43   matrix_deriv.setZero();
44   matrix_vi matrix_vi3(100, 100);
45   read_vi_adj(matrix_var, matrix_vi3, matrix_deriv);
46   EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_deriv);
47   EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_vi3.val());
48   EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_vi3.adj());
49 }
50 
TEST(AgradRevMatrix,read_var_vec)51 TEST(AgradRevMatrix, read_var_vec) {
52   using Eigen::VectorXd;
53   using stan::math::read_val_adj;
54   using stan::math::read_vi_adj;
55   using stan::math::read_vi_val;
56   using stan::math::read_vi_val_adj;
57   using stan::math::vector_v;
58   using stan::math::vector_vi;
59 
60   vector_v vector_var(100);
61   vector_vi vector_vari(100);
62   VectorXd vector_val(100), vector_deriv(100);
63   vector_var = VectorXd::Random(100);
64   vector_var.adj() = VectorXd::Random(100);
65 
66   read_vi_val_adj(vector_var, vector_vari, vector_val, vector_deriv);
67   EXPECT_MATRIX_FLOAT_EQ(vector_var.val(), vector_val);
68   EXPECT_MATRIX_FLOAT_EQ(vector_var.adj(), vector_deriv);
69 
70   vector_val.setZero();
71   vector_deriv.setZero();
72   read_val_adj(vector_vari, vector_val, vector_deriv);
73   EXPECT_MATRIX_FLOAT_EQ(vector_vari.val(), vector_val);
74   EXPECT_MATRIX_FLOAT_EQ(vector_vari.adj(), vector_deriv);
75 
76   vector_val.setZero();
77   vector_vi vector_vi2(100);
78   read_vi_val(vector_var, vector_vi2, vector_val);
79   EXPECT_MATRIX_FLOAT_EQ(vector_var.val(), vector_val);
80   EXPECT_MATRIX_FLOAT_EQ(vector_var.val(), vector_vi2.val());
81   EXPECT_MATRIX_FLOAT_EQ(vector_var.adj(), vector_vi2.adj());
82 
83   vector_deriv.setZero();
84   vector_vi vector_vi3(100);
85   read_vi_adj(vector_var, vector_vi3, vector_deriv);
86   EXPECT_MATRIX_FLOAT_EQ(vector_var.adj(), vector_deriv);
87   EXPECT_MATRIX_FLOAT_EQ(vector_var.val(), vector_vi3.val());
88   EXPECT_MATRIX_FLOAT_EQ(vector_var.adj(), vector_vi3.adj());
89 }
90 
TEST(AgradRevMatrix,read_var_rowvec)91 TEST(AgradRevMatrix, read_var_rowvec) {
92   using Eigen::RowVectorXd;
93   using stan::math::read_val_adj;
94   using stan::math::read_vi_adj;
95   using stan::math::read_vi_val;
96   using stan::math::read_vi_val_adj;
97   using stan::math::row_vector_v;
98   using stan::math::row_vector_vi;
99 
100   row_vector_v row_vector_var(100);
101   row_vector_vi row_vector_vari(100);
102   RowVectorXd row_vector_val(100), row_vector_deriv(100);
103   row_vector_var = RowVectorXd::Random(100);
104   row_vector_var.adj() = RowVectorXd::Random(100);
105 
106   read_vi_val_adj(row_vector_var, row_vector_vari, row_vector_val,
107                   row_vector_deriv);
108   EXPECT_MATRIX_FLOAT_EQ(row_vector_var.val(), row_vector_val);
109   EXPECT_MATRIX_FLOAT_EQ(row_vector_var.adj(), row_vector_deriv);
110 
111   row_vector_val.setZero();
112   row_vector_deriv.setZero();
113   read_val_adj(row_vector_vari, row_vector_val, row_vector_deriv);
114   EXPECT_MATRIX_FLOAT_EQ(row_vector_vari.val(), row_vector_val);
115   EXPECT_MATRIX_FLOAT_EQ(row_vector_vari.adj(), row_vector_deriv);
116 
117   row_vector_val.setZero();
118   row_vector_vi row_vector_vi2(100);
119   read_vi_val(row_vector_var, row_vector_vi2, row_vector_val);
120   EXPECT_MATRIX_FLOAT_EQ(row_vector_var.val(), row_vector_val);
121   EXPECT_MATRIX_FLOAT_EQ(row_vector_var.val(), row_vector_vi2.val());
122   EXPECT_MATRIX_FLOAT_EQ(row_vector_var.adj(), row_vector_vi2.adj());
123 
124   row_vector_deriv.setZero();
125   row_vector_vi row_vector_vi3(100);
126   read_vi_adj(row_vector_var, row_vector_vi3, row_vector_deriv);
127   EXPECT_MATRIX_FLOAT_EQ(row_vector_var.adj(), row_vector_deriv);
128   EXPECT_MATRIX_FLOAT_EQ(row_vector_var.val(), row_vector_vi3.val());
129   EXPECT_MATRIX_FLOAT_EQ(row_vector_var.adj(), row_vector_vi3.adj());
130 }
131 
TEST(AgradRevMatrix,read_var_expr)132 TEST(AgradRevMatrix, read_var_expr) {
133   using Eigen::MatrixXd;
134   using Eigen::VectorXd;
135   using stan::math::matrix_v;
136   using stan::math::matrix_vi;
137   using stan::math::read_val_adj;
138   using stan::math::read_vi_adj;
139   using stan::math::read_vi_val;
140   using stan::math::read_vi_val_adj;
141   using stan::math::vector_vi;
142 
143   matrix_v matrix_var(100, 100);
144   matrix_vi matrix_vari(100, 100);
145   vector_vi vector_vari(100);
146   VectorXd vector_val(100), vector_deriv(100);
147   matrix_var = MatrixXd::Random(100, 100);
148   matrix_var.adj() = MatrixXd::Random(100, 100);
149   matrix_vari = matrix_var.vi();
150 
151   read_vi_val_adj(matrix_var.diagonal(), vector_vari, vector_val, vector_deriv);
152   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_val);
153   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_deriv);
154 
155   vector_val.setZero();
156   vector_deriv.setZero();
157   read_val_adj(matrix_var.diagonal(), vector_val, vector_deriv);
158   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_val);
159   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_deriv);
160 
161   vector_val.setZero();
162   vector_deriv.setZero();
163   read_val_adj(matrix_vari.diagonal(), vector_val, vector_deriv);
164   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_val);
165   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_deriv);
166 
167   vector_val.setZero();
168   vector_vi vector_vari2(100);
169   read_vi_val(matrix_var.diagonal(), vector_vari2, vector_val);
170   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_val);
171   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_vari2.val());
172   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_vari2.adj());
173 
174   vector_deriv.setZero();
175   vector_vi vector_vari3(100);
176   read_vi_adj(matrix_var.diagonal(), vector_vari3, vector_deriv);
177   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_deriv);
178   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_vari3.val());
179   EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_vari3.adj());
180 }
181