1 /***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and         *
3 * Martin Renou                                                             *
4 * Copyright (c) QuantStack                                                 *
5 *                                                                          *
6 * Distributed under the terms of the BSD 3-Clause License.                 *
7 *                                                                          *
8 * The full license is in the file LICENSE, distributed with this software. *
9 ****************************************************************************/
10 
11 #include "test_utils.hpp"
12 
13 template <class B>
14 class complex_power_test : public testing::Test
15 {
16 protected:
17 
18     using batch_type = B;
19     using real_batch_type = typename B::real_batch;
20     using value_type = typename B::value_type;
21     using real_value_type = typename value_type::value_type;
22     static constexpr size_t size = B::size;
23     using vector_type = std::vector<value_type>;
24     using real_vector_type = std::vector<real_value_type>;
25 
26     size_t nb_input;
27     vector_type lhs_nn;
28     vector_type lhs_pn;
29     vector_type lhs_np;
30     vector_type lhs_pp;
31     vector_type rhs;
32     vector_type expected;
33     vector_type res;
34 
complex_power_test()35     complex_power_test()
36     {
37         nb_input = 10000 * size;
38         lhs_nn.resize(nb_input);
39         lhs_pn.resize(nb_input);
40         lhs_np.resize(nb_input);
41         lhs_pp.resize(nb_input);
42         rhs.resize(nb_input);
43         for (size_t i = 0; i < nb_input; ++i)
44         {
45             real_value_type real = (real_value_type(i) / 4 + real_value_type(1.2) * std::sqrt(real_value_type(i + 0.25)))/ 100;
46             real_value_type imag = (real_value_type(i) / 7 + real_value_type(1.7) * std::sqrt(real_value_type(i + 0.37))) / 100;
47             lhs_nn[i] = value_type(-real, -imag);
48             lhs_pn[i] = value_type(real, -imag);
49             lhs_np[i] = value_type(-real, imag);
50             lhs_pp[i] = value_type(real, imag);
51             rhs[i] = value_type(real_value_type(10.2) / (i + 2) + real_value_type(0.25),
52                                 real_value_type(9.1) / (i + 3) + real_value_type(0.45));
53         }
54         expected.resize(nb_input);
55         res.resize(nb_input);
56     }
57 
test_abs()58     void test_abs()
59     {
60         real_vector_type real_expected(nb_input), real_res(nb_input);
61         std::transform(lhs_np.cbegin(), lhs_np.cend(), real_expected.begin(),
62                     [](const value_type& v) { using std::abs; return abs(v); });
63         batch_type in;
64         real_batch_type out;
65         for (size_t i = 0; i < nb_input; i += size)
66         {
67             detail::load_batch(in, lhs_np, i);
68             out = abs(in);
69             detail::store_batch(out, real_res, i);
70         }
71         size_t diff = detail::get_nb_diff(real_res, real_expected);
72         EXPECT_EQ(diff, 0) << print_function_name("abs");
73     }
74 
test_arg()75     void test_arg()
76     {
77         real_vector_type real_expected(nb_input), real_res(nb_input);
78         std::transform(lhs_np.cbegin(), lhs_np.cend(), real_expected.begin(),
79                     [](const value_type& v) { using std::arg; return arg(v); });
80         batch_type in;
81         real_batch_type out;
82         for (size_t i = 0; i < nb_input; i += size)
83         {
84             detail::load_batch(in, lhs_np, i);
85             out = arg(in);
86             detail::store_batch(out, real_res, i);
87         }
88         size_t diff = detail::get_nb_diff(real_res, real_expected);
89         EXPECT_EQ(diff, 0) << print_function_name("arg");
90     }
91 
test_pow()92     void test_pow()
93     {
94         test_conditional_pow<real_value_type>();
95     }
96 
test_sqrt_nn()97     void test_sqrt_nn()
98     {
99         std::transform(lhs_nn.cbegin(), lhs_nn.cend(), expected.begin(),
100                     [](const value_type& v) { using std::sqrt; return sqrt(v); });
101         batch_type in, out;
102         for (size_t i = 0; i < nb_input; i += size)
103         {
104             detail::load_batch(in, lhs_nn, i);
105             out = sqrt(in);
106             detail::store_batch(out, res, i);
107         }
108         size_t diff = detail::get_nb_diff(res, expected);
109         EXPECT_EQ(diff, 0) << print_function_name("sqrt_nn");
110     }
111 
test_sqrt_pn()112     void test_sqrt_pn()
113     {
114         std::transform(lhs_pn.cbegin(), lhs_pn.cend(), expected.begin(),
115                     [](const value_type& v) { using std::sqrt; return sqrt(v); });
116         batch_type in, out;
117         for (size_t i = 0; i < nb_input; i += size)
118         {
119             detail::load_batch(in, lhs_pn, i);
120             out = sqrt(in);
121             detail::store_batch(out, res, i);
122         }
123         size_t diff = detail::get_nb_diff(res, expected);
124         EXPECT_EQ(diff, 0) << print_function_name("sqrt_pn");
125     }
126 
test_sqrt_np()127     void test_sqrt_np()
128     {
129         std::transform(lhs_np.cbegin(), lhs_np.cend(), expected.begin(),
130                     [](const value_type& v) { using std::sqrt; return sqrt(v); });
131         batch_type in, out;
132         for (size_t i = 0; i < nb_input; i += size)
133         {
134             detail::load_batch(in, lhs_np, i);
135             out = sqrt(in);
136             detail::store_batch(out, res, i);
137         }
138         size_t diff = detail::get_nb_diff(res, expected);
139         EXPECT_EQ(diff, 0) << print_function_name("sqrt_nn");
140     }
141 
test_sqrt_pp()142     void test_sqrt_pp()
143     {
144         std::transform(lhs_pp.cbegin(), lhs_pp.cend(), expected.begin(),
145                     [](const value_type& v) { using std::sqrt; return sqrt(v); });
146         batch_type in, out;
147         for (size_t i = 0; i < nb_input; i += size)
148         {
149             detail::load_batch(in, lhs_pp, i);
150             out = sqrt(in);
151             detail::store_batch(out, res, i);
152         }
153         size_t diff = detail::get_nb_diff(res, expected);
154         EXPECT_EQ(diff, 0) << print_function_name("sqrt_pp");
155     }
156 
157 private:
158 
test_pow_impl()159     void test_pow_impl()
160     {
161         std::transform(lhs_np.cbegin(), lhs_np.cend(), rhs.cbegin(), expected.begin(),
162                     [](const value_type& l, const value_type& r) { using std::pow; return pow(l, r); });
163         batch_type lhs_in, rhs_in, out;
164         for (size_t i = 0; i < nb_input; i += size)
165         {
166             detail::load_batch(lhs_in, lhs_np, i);
167             detail::load_batch(rhs_in, rhs, i);
168             out = pow(lhs_in, rhs_in);
169             detail::store_batch(out, res, i);
170         }
171         size_t diff = detail::get_nb_diff(res, expected);
172         EXPECT_EQ(diff, 0) << print_function_name("pow");
173     }
174 
175     template <class T, typename std::enable_if<!std::is_same<T, float>::value, int>::type = 0>
test_conditional_pow()176     void test_conditional_pow()
177     {
178         test_pow_impl();
179     }
180 
181     template <class T, typename std::enable_if<std::is_same<T, float>::value, int>::type = 0>
test_conditional_pow()182     void test_conditional_pow()
183     {
184 
185 #if (XSIMD_X86_INSTR_SET >= XSIMD_X86_AVX512_VERSION) || (XSIMD_ARM_INSTR_SET >= XSIMD_ARM7_NEON_VERSION)
186 #if DEBUG_ACCURACY
187         test_pow_impl();
188 #endif
189 #else
190         test_pow_impl();
191 #endif
192     }
193 };
194 
195 TYPED_TEST_SUITE(complex_power_test, batch_complex_types, simd_test_names);
196 
TYPED_TEST(complex_power_test,abs)197 TYPED_TEST(complex_power_test, abs)
198 {
199     this->test_abs();
200 }
201 
TYPED_TEST(complex_power_test,arg)202 TYPED_TEST(complex_power_test, arg)
203 {
204     this->test_arg();
205 }
206 
TYPED_TEST(complex_power_test,pow)207 TYPED_TEST(complex_power_test, pow)
208 {
209     this->test_pow();
210 }
211 
TYPED_TEST(complex_power_test,sqrt_nn)212 TYPED_TEST(complex_power_test, sqrt_nn)
213 {
214     this->test_sqrt_nn();
215 }
216 
217 
TYPED_TEST(complex_power_test,sqrt_pn)218 TYPED_TEST(complex_power_test, sqrt_pn)
219 {
220     this->test_sqrt_pn();
221 }
222 
TYPED_TEST(complex_power_test,sqrt_np)223 TYPED_TEST(complex_power_test, sqrt_np)
224 {
225     this->test_sqrt_np();
226 }
227 
228 
TYPED_TEST(complex_power_test,sqrt_pp)229 TYPED_TEST(complex_power_test, sqrt_pp)
230 {
231     this->test_sqrt_pp();
232 }
233