1 /***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and *
3 * Martin Renou *
4 * Copyright (c) QuantStack *
5 * *
6 * Distributed under the terms of the BSD 3-Clause License. *
7 * *
8 * The full license is in the file LICENSE, distributed with this software. *
9 ****************************************************************************/
10
11 #include "test_utils.hpp"
12
13 template <class B>
14 class complex_power_test : public testing::Test
15 {
16 protected:
17
18 using batch_type = B;
19 using real_batch_type = typename B::real_batch;
20 using value_type = typename B::value_type;
21 using real_value_type = typename value_type::value_type;
22 static constexpr size_t size = B::size;
23 using vector_type = std::vector<value_type>;
24 using real_vector_type = std::vector<real_value_type>;
25
26 size_t nb_input;
27 vector_type lhs_nn;
28 vector_type lhs_pn;
29 vector_type lhs_np;
30 vector_type lhs_pp;
31 vector_type rhs;
32 vector_type expected;
33 vector_type res;
34
complex_power_test()35 complex_power_test()
36 {
37 nb_input = 10000 * size;
38 lhs_nn.resize(nb_input);
39 lhs_pn.resize(nb_input);
40 lhs_np.resize(nb_input);
41 lhs_pp.resize(nb_input);
42 rhs.resize(nb_input);
43 for (size_t i = 0; i < nb_input; ++i)
44 {
45 real_value_type real = (real_value_type(i) / 4 + real_value_type(1.2) * std::sqrt(real_value_type(i + 0.25)))/ 100;
46 real_value_type imag = (real_value_type(i) / 7 + real_value_type(1.7) * std::sqrt(real_value_type(i + 0.37))) / 100;
47 lhs_nn[i] = value_type(-real, -imag);
48 lhs_pn[i] = value_type(real, -imag);
49 lhs_np[i] = value_type(-real, imag);
50 lhs_pp[i] = value_type(real, imag);
51 rhs[i] = value_type(real_value_type(10.2) / (i + 2) + real_value_type(0.25),
52 real_value_type(9.1) / (i + 3) + real_value_type(0.45));
53 }
54 expected.resize(nb_input);
55 res.resize(nb_input);
56 }
57
test_abs()58 void test_abs()
59 {
60 real_vector_type real_expected(nb_input), real_res(nb_input);
61 std::transform(lhs_np.cbegin(), lhs_np.cend(), real_expected.begin(),
62 [](const value_type& v) { using std::abs; return abs(v); });
63 batch_type in;
64 real_batch_type out;
65 for (size_t i = 0; i < nb_input; i += size)
66 {
67 detail::load_batch(in, lhs_np, i);
68 out = abs(in);
69 detail::store_batch(out, real_res, i);
70 }
71 size_t diff = detail::get_nb_diff(real_res, real_expected);
72 EXPECT_EQ(diff, 0) << print_function_name("abs");
73 }
74
test_arg()75 void test_arg()
76 {
77 real_vector_type real_expected(nb_input), real_res(nb_input);
78 std::transform(lhs_np.cbegin(), lhs_np.cend(), real_expected.begin(),
79 [](const value_type& v) { using std::arg; return arg(v); });
80 batch_type in;
81 real_batch_type out;
82 for (size_t i = 0; i < nb_input; i += size)
83 {
84 detail::load_batch(in, lhs_np, i);
85 out = arg(in);
86 detail::store_batch(out, real_res, i);
87 }
88 size_t diff = detail::get_nb_diff(real_res, real_expected);
89 EXPECT_EQ(diff, 0) << print_function_name("arg");
90 }
91
test_pow()92 void test_pow()
93 {
94 test_conditional_pow<real_value_type>();
95 }
96
test_sqrt_nn()97 void test_sqrt_nn()
98 {
99 std::transform(lhs_nn.cbegin(), lhs_nn.cend(), expected.begin(),
100 [](const value_type& v) { using std::sqrt; return sqrt(v); });
101 batch_type in, out;
102 for (size_t i = 0; i < nb_input; i += size)
103 {
104 detail::load_batch(in, lhs_nn, i);
105 out = sqrt(in);
106 detail::store_batch(out, res, i);
107 }
108 size_t diff = detail::get_nb_diff(res, expected);
109 EXPECT_EQ(diff, 0) << print_function_name("sqrt_nn");
110 }
111
test_sqrt_pn()112 void test_sqrt_pn()
113 {
114 std::transform(lhs_pn.cbegin(), lhs_pn.cend(), expected.begin(),
115 [](const value_type& v) { using std::sqrt; return sqrt(v); });
116 batch_type in, out;
117 for (size_t i = 0; i < nb_input; i += size)
118 {
119 detail::load_batch(in, lhs_pn, i);
120 out = sqrt(in);
121 detail::store_batch(out, res, i);
122 }
123 size_t diff = detail::get_nb_diff(res, expected);
124 EXPECT_EQ(diff, 0) << print_function_name("sqrt_pn");
125 }
126
test_sqrt_np()127 void test_sqrt_np()
128 {
129 std::transform(lhs_np.cbegin(), lhs_np.cend(), expected.begin(),
130 [](const value_type& v) { using std::sqrt; return sqrt(v); });
131 batch_type in, out;
132 for (size_t i = 0; i < nb_input; i += size)
133 {
134 detail::load_batch(in, lhs_np, i);
135 out = sqrt(in);
136 detail::store_batch(out, res, i);
137 }
138 size_t diff = detail::get_nb_diff(res, expected);
139 EXPECT_EQ(diff, 0) << print_function_name("sqrt_nn");
140 }
141
test_sqrt_pp()142 void test_sqrt_pp()
143 {
144 std::transform(lhs_pp.cbegin(), lhs_pp.cend(), expected.begin(),
145 [](const value_type& v) { using std::sqrt; return sqrt(v); });
146 batch_type in, out;
147 for (size_t i = 0; i < nb_input; i += size)
148 {
149 detail::load_batch(in, lhs_pp, i);
150 out = sqrt(in);
151 detail::store_batch(out, res, i);
152 }
153 size_t diff = detail::get_nb_diff(res, expected);
154 EXPECT_EQ(diff, 0) << print_function_name("sqrt_pp");
155 }
156
157 private:
158
test_pow_impl()159 void test_pow_impl()
160 {
161 std::transform(lhs_np.cbegin(), lhs_np.cend(), rhs.cbegin(), expected.begin(),
162 [](const value_type& l, const value_type& r) { using std::pow; return pow(l, r); });
163 batch_type lhs_in, rhs_in, out;
164 for (size_t i = 0; i < nb_input; i += size)
165 {
166 detail::load_batch(lhs_in, lhs_np, i);
167 detail::load_batch(rhs_in, rhs, i);
168 out = pow(lhs_in, rhs_in);
169 detail::store_batch(out, res, i);
170 }
171 size_t diff = detail::get_nb_diff(res, expected);
172 EXPECT_EQ(diff, 0) << print_function_name("pow");
173 }
174
175 template <class T, typename std::enable_if<!std::is_same<T, float>::value, int>::type = 0>
test_conditional_pow()176 void test_conditional_pow()
177 {
178 test_pow_impl();
179 }
180
181 template <class T, typename std::enable_if<std::is_same<T, float>::value, int>::type = 0>
test_conditional_pow()182 void test_conditional_pow()
183 {
184
185 #if (XSIMD_X86_INSTR_SET >= XSIMD_X86_AVX512_VERSION) || (XSIMD_ARM_INSTR_SET >= XSIMD_ARM7_NEON_VERSION)
186 #if DEBUG_ACCURACY
187 test_pow_impl();
188 #endif
189 #else
190 test_pow_impl();
191 #endif
192 }
193 };
194
195 TYPED_TEST_SUITE(complex_power_test, batch_complex_types, simd_test_names);
196
// Batch abs of complex values vs. scalar abs (fixture's test_abs).
TYPED_TEST(complex_power_test, abs)
{
    TestFixture::test_abs();
}
201
// Batch arg of complex values vs. scalar arg (fixture's test_arg).
TYPED_TEST(complex_power_test, arg)
{
    TestFixture::test_arg();
}
206
// Batch complex pow vs. scalar pow (fixture's test_pow).
TYPED_TEST(complex_power_test, pow)
{
    TestFixture::test_pow();
}
211
// Batch sqrt for inputs with negative real and negative imaginary parts.
TYPED_TEST(complex_power_test, sqrt_nn)
{
    TestFixture::test_sqrt_nn();
}
216
217
// Batch sqrt for inputs with positive real and negative imaginary parts.
TYPED_TEST(complex_power_test, sqrt_pn)
{
    TestFixture::test_sqrt_pn();
}
222
// Batch sqrt for inputs with negative real and positive imaginary parts.
TYPED_TEST(complex_power_test, sqrt_np)
{
    TestFixture::test_sqrt_np();
}
227
228
// Batch sqrt for inputs with positive real and positive imaginary parts.
TYPED_TEST(complex_power_test, sqrt_pp)
{
    TestFixture::test_sqrt_pp();
}
233