1 #ifdef STAN_OPENCL
2 
3 #include <stan/math/opencl/kernel_generator.hpp>
4 #include <stan/math/opencl/matrix_cl.hpp>
5 #include <stan/math/opencl/copy.hpp>
6 #include <test/unit/math/opencl/kernel_generator/reference_kernel.hpp>
7 #include <test/unit/util.hpp>
8 #include <Eigen/Dense>
9 #include <gtest/gtest.h>
10 #include <string>
11 
12 using Eigen::Matrix;
13 using Eigen::MatrixXd;
14 using Eigen::MatrixXi;
15 using stan::math::matrix_cl;
16 
// Casting a zero-size matrix_cl must not throw. `cast<int>(...)` only builds a
// lazy expression, so the result is assigned to a matrix_cl to force the
// kernel to actually be generated and run; the result must stay 0x0.
TEST(KernelGenerator, cast_zero_size) {
  matrix_cl<double> m_zero(0, 0);

  EXPECT_NO_THROW({
    matrix_cl<int> res_cl = stan::math::cast<int>(m_zero);
    EXPECT_EQ(res_cl.rows(), 0);
    EXPECT_EQ(res_cl.cols(), 0);
  });
}
22 
// Casting a double matrix_cl to int must match Eigen's element-wise
// cast<int>(), which truncates toward zero. The last row holds negative
// values: for positive inputs truncation and flooring agree, so only
// negative inputs pin down the rounding direction.
TEST(KernelGenerator, cast_test) {
  using stan::math::cast;
  MatrixXd m1(3, 3);
  m1 << 1, 2.5, 3.4,
        4.7, 5.9, 6.3,
        -1.5, -0.4, -7.9;

  matrix_cl<double> m1_cl(m1);

  // Assigning to matrix_cl evaluates the lazy cast expression on the device.
  matrix_cl<int> res_cl = cast<int>(m1_cl);

  MatrixXi res = stan::math::from_matrix_cl(res_cl);

  MatrixXi correct = m1.cast<int>();
  EXPECT_MATRIX_NEAR(res, correct, 1e-9);
}
37 
// Nesting two casts in one expression (double -> int -> double) must drop the
// fractional parts exactly as the equivalent Eigen cast chain does. The
// combined expression is stored in a named variable and then assigned.
TEST(KernelGenerator, cast_multiple_operations) {
  using stan::math::cast;
  MatrixXd input(2, 3);
  input << 1, 2.5, 3.4, 4.7, 5.9, 6.3;
  matrix_cl<double> input_cl(input);

  // Lazy expression; evaluated when assigned to a matrix_cl below.
  auto round_trip = cast<double>(cast<int>(input_cl));
  matrix_cl<double> result_cl = round_trip;
  MatrixXd result = stan::math::from_matrix_cl(result_cl);

  MatrixXd expected = input.cast<int>().cast<double>();
  EXPECT_MATRIX_NEAR(result, expected, 1e-9);
}
52 
// Same double -> int -> double round trip, but the inner cast is first bound
// to a named (lvalue) expression and then used as the operand of the outer
// cast, exercising the lvalue-argument path of the expression templates.
TEST(KernelGenerator, cast_multiple_operations_lvalue) {
  using stan::math::cast;
  MatrixXd input(2, 3);
  input << 1, 2.5, 3.4, 4.7, 5.9, 6.3;
  matrix_cl<double> input_cl(input);

  auto as_int = cast<int>(input_cl);
  matrix_cl<double> result_cl = cast<double>(as_int);
  MatrixXd result = stan::math::from_matrix_cl(result_cl);

  MatrixXd expected = input.cast<int>().cast<double>();
  EXPECT_MATRIX_NEAR(result, expected, 1e-9);
}
67 
68 #endif
69