1 
2 #include <stan/math/rev.hpp>
3 #include <gtest/gtest.h>
4 
TEST(MathFunctions,ReturnVarMatrix)5 TEST(MathFunctions, ReturnVarMatrix) {
6   using stan::return_var_matrix_t;
7   using stan::math::var;
8   using stan::math::var_value;
9   using std::is_same;
10   using var_matrix = var_value<Eigen::MatrixXd>;
11   using var_vector = var_value<Eigen::VectorXd>;
12   using var_row_vector = var_value<Eigen::RowVectorXd>;
13   using matrix_var = Eigen::Matrix<var, -1, -1>;
14   using vector_var = Eigen::Matrix<var, -1, 1>;
15   using row_vector_var = Eigen::Matrix<var, 1, -1>;
16   var_matrix A_vm(Eigen::MatrixXd::Zero(10, 10));
17   matrix_var A_mv(Eigen::MatrixXd::Zero(10, 10));
18 
19   EXPECT_TRUE((is_same<var_value<Eigen::MatrixXd>,
20                        return_var_matrix_t<Eigen::MatrixXd, var_matrix,
21                                            matrix_var>>::value));
22   EXPECT_TRUE((is_same<var_value<Eigen::MatrixXd>,
23                        return_var_matrix_t<Eigen::MatrixXd, var_matrix,
24                                            vector_var, var>>::value));
25 
26   EXPECT_TRUE(
27       (is_same<var_value<Eigen::MatrixXd>,
28                return_var_matrix_t<decltype(A_vm.block(0, 0, 2, 2))>>::value));
29 
30   EXPECT_TRUE(
31       (is_same<matrix_var,
32                return_var_matrix_t<decltype(A_mv.block(0, 0, 2, 2))>>::value));
33 
34   EXPECT_TRUE(
35       (is_same<matrix_var, return_var_matrix_t<decltype(A_mv * A_mv)>>::value));
36 
37   EXPECT_TRUE((is_same<var_value<Eigen::MatrixXd>,
38                        return_var_matrix_t<Eigen::MatrixXd, var_vector,
39                                            vector_var, double>>::value));
40 
41   EXPECT_TRUE((is_same<var_value<Eigen::VectorXd>,
42                        return_var_matrix_t<Eigen::VectorXd, var_matrix,
43                                            matrix_var>>::value));
44   EXPECT_TRUE((is_same<var_value<Eigen::VectorXd>,
45                        return_var_matrix_t<Eigen::VectorXd, var_matrix,
46                                            vector_var, var>>::value));
47   EXPECT_TRUE((is_same<var_value<Eigen::VectorXd>,
48                        return_var_matrix_t<Eigen::VectorXd, var_vector,
49                                            vector_var, double>>::value));
50 
51   EXPECT_TRUE((is_same<var_value<Eigen::RowVectorXd>,
52                        return_var_matrix_t<Eigen::RowVectorXd, var_matrix,
53                                            matrix_var>>::value));
54   EXPECT_TRUE((is_same<var_value<Eigen::RowVectorXd>,
55                        return_var_matrix_t<Eigen::RowVectorXd, var_matrix,
56                                            vector_var, var>>::value));
57   EXPECT_TRUE((is_same<var_value<Eigen::RowVectorXd>,
58                        return_var_matrix_t<Eigen::RowVectorXd, var_vector,
59                                            row_vector_var, double>>::value));
60 
61   EXPECT_TRUE((is_same<Eigen::Matrix<var, -1, -1>,
62                        return_var_matrix_t<Eigen::MatrixXd, vector_var,
63                                            vector_var, double>>::value));
64   EXPECT_TRUE((is_same<Eigen::Matrix<var, 1, -1>,
65                        return_var_matrix_t<Eigen::RowVectorXd, vector_var,
66                                            row_vector_var, double>>::value));
67 }
68