1 #include <stan/math/rev.hpp>
2 #include <test/unit/util.hpp>
3 #include <gtest/gtest.h>
4
TEST(AgradRevMatrix,read_var_mat)5 TEST(AgradRevMatrix, read_var_mat) {
6 using Eigen::MatrixXd;
7 using stan::math::matrix_v;
8 using stan::math::matrix_vi;
9 using stan::math::read_val_adj;
10 using stan::math::read_vi_adj;
11 using stan::math::read_vi_val;
12 using stan::math::read_vi_val_adj;
13
14 matrix_v matrix_var(100, 100);
15 matrix_vi matrix_vari(100, 100);
16 MatrixXd matrix_val(100, 100), matrix_deriv(100, 100);
17 matrix_var = MatrixXd::Random(100, 100);
18 matrix_var.adj() = MatrixXd::Random(100, 100);
19
20 read_vi_val_adj(matrix_var, matrix_vari, matrix_val, matrix_deriv);
21 EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_val);
22 EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_deriv);
23
24 matrix_val.setZero();
25 matrix_deriv.setZero();
26 read_val_adj(matrix_var, matrix_val, matrix_deriv);
27 EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_val);
28 EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_deriv);
29
30 matrix_val.setZero();
31 matrix_deriv.setZero();
32 read_val_adj(matrix_vari, matrix_val, matrix_deriv);
33 EXPECT_MATRIX_FLOAT_EQ(matrix_vari.val(), matrix_val);
34 EXPECT_MATRIX_FLOAT_EQ(matrix_vari.adj(), matrix_deriv);
35
36 matrix_val.setZero();
37 matrix_vi matrix_vi2(100, 100);
38 read_vi_val(matrix_var, matrix_vi2, matrix_val);
39 EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_val);
40 EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_vi2.val());
41 EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_vi2.adj());
42
43 matrix_deriv.setZero();
44 matrix_vi matrix_vi3(100, 100);
45 read_vi_adj(matrix_var, matrix_vi3, matrix_deriv);
46 EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_deriv);
47 EXPECT_MATRIX_FLOAT_EQ(matrix_var.val(), matrix_vi3.val());
48 EXPECT_MATRIX_FLOAT_EQ(matrix_var.adj(), matrix_vi3.adj());
49 }
50
TEST(AgradRevMatrix,read_var_vec)51 TEST(AgradRevMatrix, read_var_vec) {
52 using Eigen::VectorXd;
53 using stan::math::read_val_adj;
54 using stan::math::read_vi_adj;
55 using stan::math::read_vi_val;
56 using stan::math::read_vi_val_adj;
57 using stan::math::vector_v;
58 using stan::math::vector_vi;
59
60 vector_v vector_var(100);
61 vector_vi vector_vari(100);
62 VectorXd vector_val(100), vector_deriv(100);
63 vector_var = VectorXd::Random(100);
64 vector_var.adj() = VectorXd::Random(100);
65
66 read_vi_val_adj(vector_var, vector_vari, vector_val, vector_deriv);
67 EXPECT_MATRIX_FLOAT_EQ(vector_var.val(), vector_val);
68 EXPECT_MATRIX_FLOAT_EQ(vector_var.adj(), vector_deriv);
69
70 vector_val.setZero();
71 vector_deriv.setZero();
72 read_val_adj(vector_vari, vector_val, vector_deriv);
73 EXPECT_MATRIX_FLOAT_EQ(vector_vari.val(), vector_val);
74 EXPECT_MATRIX_FLOAT_EQ(vector_vari.adj(), vector_deriv);
75
76 vector_val.setZero();
77 vector_vi vector_vi2(100);
78 read_vi_val(vector_var, vector_vi2, vector_val);
79 EXPECT_MATRIX_FLOAT_EQ(vector_var.val(), vector_val);
80 EXPECT_MATRIX_FLOAT_EQ(vector_var.val(), vector_vi2.val());
81 EXPECT_MATRIX_FLOAT_EQ(vector_var.adj(), vector_vi2.adj());
82
83 vector_deriv.setZero();
84 vector_vi vector_vi3(100);
85 read_vi_adj(vector_var, vector_vi3, vector_deriv);
86 EXPECT_MATRIX_FLOAT_EQ(vector_var.adj(), vector_deriv);
87 EXPECT_MATRIX_FLOAT_EQ(vector_var.val(), vector_vi3.val());
88 EXPECT_MATRIX_FLOAT_EQ(vector_var.adj(), vector_vi3.adj());
89 }
90
TEST(AgradRevMatrix,read_var_rowvec)91 TEST(AgradRevMatrix, read_var_rowvec) {
92 using Eigen::RowVectorXd;
93 using stan::math::read_val_adj;
94 using stan::math::read_vi_adj;
95 using stan::math::read_vi_val;
96 using stan::math::read_vi_val_adj;
97 using stan::math::row_vector_v;
98 using stan::math::row_vector_vi;
99
100 row_vector_v row_vector_var(100);
101 row_vector_vi row_vector_vari(100);
102 RowVectorXd row_vector_val(100), row_vector_deriv(100);
103 row_vector_var = RowVectorXd::Random(100);
104 row_vector_var.adj() = RowVectorXd::Random(100);
105
106 read_vi_val_adj(row_vector_var, row_vector_vari, row_vector_val,
107 row_vector_deriv);
108 EXPECT_MATRIX_FLOAT_EQ(row_vector_var.val(), row_vector_val);
109 EXPECT_MATRIX_FLOAT_EQ(row_vector_var.adj(), row_vector_deriv);
110
111 row_vector_val.setZero();
112 row_vector_deriv.setZero();
113 read_val_adj(row_vector_vari, row_vector_val, row_vector_deriv);
114 EXPECT_MATRIX_FLOAT_EQ(row_vector_vari.val(), row_vector_val);
115 EXPECT_MATRIX_FLOAT_EQ(row_vector_vari.adj(), row_vector_deriv);
116
117 row_vector_val.setZero();
118 row_vector_vi row_vector_vi2(100);
119 read_vi_val(row_vector_var, row_vector_vi2, row_vector_val);
120 EXPECT_MATRIX_FLOAT_EQ(row_vector_var.val(), row_vector_val);
121 EXPECT_MATRIX_FLOAT_EQ(row_vector_var.val(), row_vector_vi2.val());
122 EXPECT_MATRIX_FLOAT_EQ(row_vector_var.adj(), row_vector_vi2.adj());
123
124 row_vector_deriv.setZero();
125 row_vector_vi row_vector_vi3(100);
126 read_vi_adj(row_vector_var, row_vector_vi3, row_vector_deriv);
127 EXPECT_MATRIX_FLOAT_EQ(row_vector_var.adj(), row_vector_deriv);
128 EXPECT_MATRIX_FLOAT_EQ(row_vector_var.val(), row_vector_vi3.val());
129 EXPECT_MATRIX_FLOAT_EQ(row_vector_var.adj(), row_vector_vi3.adj());
130 }
131
TEST(AgradRevMatrix,read_var_expr)132 TEST(AgradRevMatrix, read_var_expr) {
133 using Eigen::MatrixXd;
134 using Eigen::VectorXd;
135 using stan::math::matrix_v;
136 using stan::math::matrix_vi;
137 using stan::math::read_val_adj;
138 using stan::math::read_vi_adj;
139 using stan::math::read_vi_val;
140 using stan::math::read_vi_val_adj;
141 using stan::math::vector_vi;
142
143 matrix_v matrix_var(100, 100);
144 matrix_vi matrix_vari(100, 100);
145 vector_vi vector_vari(100);
146 VectorXd vector_val(100), vector_deriv(100);
147 matrix_var = MatrixXd::Random(100, 100);
148 matrix_var.adj() = MatrixXd::Random(100, 100);
149 matrix_vari = matrix_var.vi();
150
151 read_vi_val_adj(matrix_var.diagonal(), vector_vari, vector_val, vector_deriv);
152 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_val);
153 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_deriv);
154
155 vector_val.setZero();
156 vector_deriv.setZero();
157 read_val_adj(matrix_var.diagonal(), vector_val, vector_deriv);
158 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_val);
159 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_deriv);
160
161 vector_val.setZero();
162 vector_deriv.setZero();
163 read_val_adj(matrix_vari.diagonal(), vector_val, vector_deriv);
164 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_val);
165 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_deriv);
166
167 vector_val.setZero();
168 vector_vi vector_vari2(100);
169 read_vi_val(matrix_var.diagonal(), vector_vari2, vector_val);
170 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_val);
171 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_vari2.val());
172 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_vari2.adj());
173
174 vector_deriv.setZero();
175 vector_vi vector_vari3(100);
176 read_vi_adj(matrix_var.diagonal(), vector_vari3, vector_deriv);
177 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_deriv);
178 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().val(), vector_vari3.val());
179 EXPECT_MATRIX_FLOAT_EQ(matrix_var.diagonal().adj(), vector_vari3.adj());
180 }
181