1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <heyoka/config.hpp>
10 
11 #include <algorithm>
12 #include <cmath>
13 #include <cstdint>
14 #include <initializer_list>
15 #include <random>
16 #include <stdexcept>
17 #include <tuple>
18 #include <vector>
19 
20 #include <boost/math/constants/constants.hpp>
21 
22 #include <fmt/format.h>
23 
24 #include <llvm/IR/BasicBlock.h>
25 #include <llvm/IR/DerivedTypes.h>
26 #include <llvm/IR/Function.h>
27 #include <llvm/IR/IRBuilder.h>
28 #include <llvm/IR/LLVMContext.h>
29 #include <llvm/IR/Module.h>
30 #include <llvm/IR/Type.h>
31 
32 #if defined(HEYOKA_HAVE_REAL128)
33 
34 #include <mp++/real128.hpp>
35 
36 #endif
37 
38 #include <heyoka/detail/llvm_helpers.hpp>
39 #include <heyoka/llvm_state.hpp>
40 
41 #include "catch.hpp"
42 #include "test_utils.hpp"
43 
44 #if defined(_MSC_VER) && !defined(__clang__)
45 
46 // NOTE: MSVC has issues with the other "using"
47 // statement form.
48 using namespace fmt::literals;
49 
50 #else
51 
52 using fmt::literals::operator""_format;
53 
54 #endif
55 
56 using namespace heyoka;
57 using namespace heyoka_test;
58 
// Floating-point types iterated over (via tuple_for_each()) by the
// type-generic test cases: double always, long double unless
// HEYOKA_ARCH_PPC is defined, mppp::real128 if HEYOKA_HAVE_REAL128
// is defined.
const std::tuple<double
#if !defined(HEYOKA_ARCH_PPC)
                 ,
                 long double
#endif
#if defined(HEYOKA_HAVE_REAL128)
                 ,
                 mppp::real128
#endif
                 >
    fp_types{};

// Random engine shared by the randomised tests (default-seeded).
std::mt19937 rng{};

// Base number of iterations used by the randomised tests.
constexpr int ntrials = 100;
73 
74 TEST_CASE("sgn scalar")
75 {
76     using detail::llvm_sgn;
77     using detail::to_llvm_type;
78 
__anon5160b0c00102(auto fp_x) 79     auto tester = [](auto fp_x) {
80         using fp_t = decltype(fp_x);
81 
82         for (auto opt_level : {0u, 1u, 2u, 3u}) {
83             llvm_state s{kw::opt_level = opt_level};
84 
85             auto &md = s.module();
86             auto &builder = s.builder();
87             auto &context = s.context();
88 
89             auto val_t = to_llvm_type<fp_t>(context);
90 
91             auto *ft = llvm::FunctionType::get(builder.getInt32Ty(), {val_t}, false);
92             auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "sgn", &md);
93 
94             auto x = f->args().begin();
95 
96             builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
97 
98             // Create the return value.
99             builder.CreateRet(llvm_sgn(s, x));
100 
101             // Verify.
102             s.verify_function(f);
103 
104             // Run the optimisation pass.
105             s.optimise();
106 
107             // Compile.
108             s.compile();
109 
110             // Fetch the function pointer.
111             auto f_ptr = reinterpret_cast<std::int32_t (*)(fp_t)>(s.jit_lookup("sgn"));
112 
113             REQUIRE(f_ptr(0) == 0);
114             REQUIRE(f_ptr(-42) == -1);
115             REQUIRE(f_ptr(123) == 1);
116         }
117     };
118 
119     tuple_for_each(fp_types, tester);
120 }
121 
// Branch-free sign of val: 1 if val is greater than zero,
// -1 if it is less than zero, 0 otherwise.
template <typename T>
int sgn(T val)
{
    const int is_positive = T(0) < val;
    const int is_negative = val < T(0);

    return is_positive - is_negative;
}
128 
129 TEST_CASE("sgn batch")
130 {
131     using detail::llvm_sgn;
132     using detail::to_llvm_type;
133 
__anon5160b0c00202(auto fp_x) 134     auto tester = [](auto fp_x) {
135         using fp_t = decltype(fp_x);
136 
137         for (auto batch_size : {1u, 2u, 4u, 13u}) {
138             for (auto opt_level : {0u, 1u, 2u, 3u}) {
139                 llvm_state s{kw::opt_level = opt_level};
140 
141                 auto &md = s.module();
142                 auto &builder = s.builder();
143                 auto &context = s.context();
144 
145                 auto val_t = to_llvm_type<fp_t>(context);
146 
147                 auto *ft = llvm::FunctionType::get(
148                     builder.getVoidTy(),
149                     {llvm::PointerType::getUnqual(builder.getInt32Ty()), llvm::PointerType::getUnqual(val_t)}, false);
150                 auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "sgn", &md);
151 
152                 auto out = f->args().begin();
153                 auto x = f->args().begin() + 1;
154 
155                 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
156 
157                 // Load the vector from memory.
158                 auto v = detail::load_vector_from_memory(builder, x, batch_size);
159 
160                 // Create and store the return value.
161                 detail::store_vector_to_memory(builder, out, llvm_sgn(s, v));
162 
163                 builder.CreateRetVoid();
164 
165                 // Verify.
166                 s.verify_function(f);
167 
168                 // Run the optimisation pass.
169                 s.optimise();
170 
171                 // Compile.
172                 s.compile();
173 
174                 // Fetch the function pointer.
175                 auto f_ptr = reinterpret_cast<void (*)(std::int32_t *, const fp_t *)>(s.jit_lookup("sgn"));
176 
177                 std::uniform_real_distribution<double> rdist(-10., 10.);
178                 std::vector<fp_t> values(batch_size);
179                 std::generate(values.begin(), values.end(), [&rdist]() { return rdist(rng); });
180                 std::vector<std::int32_t> signs(batch_size);
181 
182                 f_ptr(signs.data(), values.data());
183 
184                 for (auto i = 0u; i < batch_size; ++i) {
185                     REQUIRE(signs[i] == sgn(values[i]));
186                 }
187 
188                 values[0] = 0;
189 
190                 f_ptr(signs.data(), values.data());
191                 REQUIRE(signs[0] == 0);
192             }
193         }
194     };
195 
196     tuple_for_each(fp_types, tester);
197 }
198 
199 TEST_CASE("sincos scalar")
200 {
201     using detail::llvm_sincos;
202     using detail::to_llvm_type;
203     using std::cos;
204     using std::sin;
205 
__anon5160b0c00402(auto fp_x) 206     auto tester = [](auto fp_x) {
207         using fp_t = decltype(fp_x);
208 
209         for (auto opt_level : {0u, 1u, 2u, 3u}) {
210             llvm_state s{kw::opt_level = opt_level};
211 
212             auto &md = s.module();
213             auto &builder = s.builder();
214             auto &context = s.context();
215 
216             auto val_t = to_llvm_type<fp_t>(context);
217 
218             std::vector<llvm::Type *> fargs{val_t, llvm::PointerType::getUnqual(val_t),
219                                             llvm::PointerType::getUnqual(val_t)};
220             auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false);
221             auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "sc", &md);
222 
223             auto x = f->args().begin();
224             auto sptr = f->args().begin() + 1;
225             auto cptr = f->args().begin() + 2;
226 
227             builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
228 
229             auto ret = llvm_sincos(s, x);
230             builder.CreateStore(ret.first, sptr);
231             builder.CreateStore(ret.second, cptr);
232 
233             // Create the return value.
234             builder.CreateRetVoid();
235 
236             // Verify.
237             s.verify_function(f);
238 
239             // Run the optimisation pass.
240             s.optimise();
241 
242             // Compile.
243             s.compile();
244 
245             // Fetch the function pointer.
246             auto f_ptr = reinterpret_cast<void (*)(fp_t, fp_t *, fp_t *)>(s.jit_lookup("sc"));
247 
248             fp_t sn, cs;
249             f_ptr(fp_t(2), &sn, &cs);
250             REQUIRE(sn == approximately(sin(fp_t(2))));
251             REQUIRE(cs == approximately(cos(fp_t(2))));
252 
253             f_ptr(fp_t(-123.45), &sn, &cs);
254             REQUIRE(sn == approximately(sin(fp_t(-123.45))));
255             REQUIRE(cs == approximately(cos(fp_t(-123.45))));
256         }
257     };
258 
259     tuple_for_each(fp_types, tester);
260 }
261 
262 TEST_CASE("sincos batch")
263 {
264     using detail::llvm_sincos;
265     using detail::to_llvm_type;
266     using std::cos;
267     using std::sin;
268 
__anon5160b0c00502(auto fp_x) 269     auto tester = [](auto fp_x) {
270         using fp_t = decltype(fp_x);
271 
272         for (auto batch_size : {1u, 2u, 4u, 13u}) {
273             for (auto opt_level : {0u, 1u, 2u, 3u}) {
274                 llvm_state s{kw::opt_level = opt_level};
275 
276                 auto &md = s.module();
277                 auto &builder = s.builder();
278                 auto &context = s.context();
279 
280                 auto val_t = to_llvm_type<fp_t>(context);
281 
282                 std::vector<llvm::Type *> fargs{llvm::PointerType::getUnqual(val_t),
283                                                 llvm::PointerType::getUnqual(val_t),
284                                                 llvm::PointerType::getUnqual(val_t)};
285                 auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false);
286                 auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "sc", &md);
287 
288                 auto xptr = f->args().begin();
289                 auto sptr = f->args().begin() + 1;
290                 auto cptr = f->args().begin() + 2;
291 
292                 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
293 
294                 auto x = detail::load_vector_from_memory(builder, xptr, batch_size);
295 
296                 auto ret = llvm_sincos(s, x);
297                 detail::store_vector_to_memory(builder, sptr, ret.first);
298                 detail::store_vector_to_memory(builder, cptr, ret.second);
299 
300                 // Create the return value.
301                 builder.CreateRetVoid();
302 
303                 // Verify.
304                 s.verify_function(f);
305 
306                 // Run the optimisation pass.
307                 s.optimise();
308 
309                 // Compile.
310                 s.compile();
311 
312                 // Fetch the function pointer.
313                 auto f_ptr = reinterpret_cast<void (*)(fp_t *, fp_t *, fp_t *)>(s.jit_lookup("sc"));
314 
315                 // Setup the argument and the output values.
316                 std::vector<fp_t> x_vec(batch_size), s_vec(x_vec), c_vec(x_vec);
317                 for (auto i = 0u; i < batch_size; ++i) {
318                     x_vec[i] = i + 1u;
319                 }
320 
321                 f_ptr(x_vec.data(), s_vec.data(), c_vec.data());
322 
323                 for (auto i = 0u; i < batch_size; ++i) {
324                     REQUIRE(s_vec[i] == approximately(sin(x_vec[i])));
325                     REQUIRE(c_vec[i] == approximately(cos(x_vec[i])));
326                 }
327             }
328         }
329     };
330 
331     tuple_for_each(fp_types, tester);
332 }
333 
334 TEST_CASE("modulus scalar")
335 {
336     using detail::llvm_modulus;
337     using detail::to_llvm_type;
338     using std::floor;
339 
__anon5160b0c00602(auto fp_x) 340     auto tester = [](auto fp_x) {
341         using fp_t = decltype(fp_x);
342 
343         for (auto opt_level : {0u, 1u, 2u, 3u}) {
344             llvm_state s{kw::opt_level = opt_level};
345 
346             auto &md = s.module();
347             auto &builder = s.builder();
348             auto &context = s.context();
349 
350             auto val_t = to_llvm_type<fp_t>(context);
351 
352             std::vector<llvm::Type *> fargs{val_t, val_t};
353             auto *ft = llvm::FunctionType::get(val_t, fargs, false);
354             auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_rem", &md);
355 
356             auto x = f->args().begin();
357             auto y = f->args().begin() + 1;
358 
359             builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
360 
361             builder.CreateRet(llvm_modulus(s, x, y));
362 
363             // Verify.
364             s.verify_function(f);
365 
366             // Run the optimisation pass.
367             s.optimise();
368 
369             // Compile.
370             s.compile();
371 
372             // Fetch the function pointer.
373             auto f_ptr = reinterpret_cast<fp_t (*)(fp_t, fp_t)>(s.jit_lookup("hey_rem"));
374 
375             auto a = fp_t(123);
376             auto b = fp_t(2) / fp_t(7);
377 
378             REQUIRE(f_ptr(a, b) == approximately(a - b * floor(a / b), fp_t(1000)));
379 
380             a = fp_t(-4);
381             b = fp_t(314) / fp_t(100);
382 
383             REQUIRE(f_ptr(a, b) == approximately(a - b * floor(a / b), fp_t(1000)));
384         }
385     };
386 
387     tuple_for_each(fp_types, tester);
388 }
389 
390 TEST_CASE("modulus batch")
391 {
392     using detail::llvm_modulus;
393     using detail::to_llvm_type;
394     using std::floor;
395 
__anon5160b0c00702(auto fp_x) 396     auto tester = [](auto fp_x) {
397         using fp_t = decltype(fp_x);
398 
399         for (auto batch_size : {1u, 2u, 4u, 13u}) {
400             for (auto opt_level : {0u, 1u, 2u, 3u}) {
401                 llvm_state s{kw::opt_level = opt_level};
402 
403                 auto &md = s.module();
404                 auto &builder = s.builder();
405                 auto &context = s.context();
406 
407                 auto val_t = to_llvm_type<fp_t>(context);
408 
409                 std::vector<llvm::Type *> fargs{llvm::PointerType::getUnqual(val_t),
410                                                 llvm::PointerType::getUnqual(val_t),
411                                                 llvm::PointerType::getUnqual(val_t)};
412                 auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false);
413                 auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_rem", &md);
414 
415                 auto ret_ptr = f->args().begin();
416                 auto x_ptr = f->args().begin() + 1;
417                 auto y_ptr = f->args().begin() + 2;
418 
419                 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
420 
421                 auto ret = llvm_modulus(s, detail::load_vector_from_memory(builder, x_ptr, batch_size),
422                                         detail::load_vector_from_memory(builder, y_ptr, batch_size));
423 
424                 detail::store_vector_to_memory(builder, ret_ptr, ret);
425 
426                 builder.CreateRetVoid();
427 
428                 // Verify.
429                 s.verify_function(f);
430 
431                 // Run the optimisation pass.
432                 s.optimise();
433 
434                 // Compile.
435                 s.compile();
436 
437                 // Fetch the function pointer.
438                 auto f_ptr = reinterpret_cast<void (*)(fp_t *, fp_t *, fp_t *)>(s.jit_lookup("hey_rem"));
439 
440                 // Setup the arguments and the output value.
441                 std::vector<fp_t> ret_vec(batch_size), a_vec(ret_vec), b_vec(ret_vec);
442                 for (auto i = 0u; i < batch_size; ++i) {
443                     a_vec[i] = i + 1u;
444                     b_vec[i] = a_vec[i] * 10 * (i + 1u);
445                 }
446 
447                 f_ptr(ret_vec.data(), a_vec.data(), b_vec.data());
448 
449                 for (auto i = 0u; i < batch_size; ++i) {
450                     auto a = a_vec[i];
451                     auto b = b_vec[i];
452 
453                     REQUIRE(ret_vec[i] == approximately(a - b * floor(a / b), fp_t(1000)));
454                 }
455             }
456         }
457     };
458 
459     tuple_for_each(fp_types, tester);
460 }
461 
462 TEST_CASE("inv_kep_E_scalar")
463 {
464     using detail::llvm_add_inv_kep_E;
465     using detail::to_llvm_type;
466     namespace bmt = boost::math::tools;
467     using std::cos;
468     using std::sin;
469 
__anon5160b0c00802(auto fp_x) 470     auto tester = [](auto fp_x) {
471         using fp_t = decltype(fp_x);
472 
473         for (auto opt_level : {0u, 1u, 2u, 3u}) {
474             llvm_state s{kw::opt_level = opt_level};
475 
476             auto fkep = llvm_add_inv_kep_E<fp_t>(s, 1);
477 
478             auto &md = s.module();
479             auto &builder = s.builder();
480             auto &context = s.context();
481 
482             {
483                 auto val_t = to_llvm_type<fp_t>(context);
484 
485                 std::vector<llvm::Type *> fargs{val_t, val_t};
486                 auto *ft = llvm::FunctionType::get(val_t, fargs, false);
487                 auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_kep", &md);
488 
489                 auto e = f->args().begin();
490                 auto M = f->args().begin() + 1;
491 
492                 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
493 
494                 builder.CreateRet(builder.CreateCall(fkep, {e, M}));
495 
496                 // Verify.
497                 s.verify_function(f);
498 
499                 // Run the optimisation pass.
500                 s.optimise();
501 
502                 // Compile.
503                 s.compile();
504             }
505 
506             // Fetch the function pointer.
507             auto f_ptr = reinterpret_cast<fp_t (*)(fp_t, fp_t)>(s.jit_lookup("hey_kep"));
508 
509             std::uniform_real_distribution<double> e_dist(0., 1.), M_dist(0., 2 * boost::math::constants::pi<double>());
510 
511             // First set of tests with zero eccentricity.
512             for (auto i = 0; i < ntrials; ++i) {
513                 const auto M = M_dist(rng);
514                 const auto E = f_ptr(0, M);
515                 REQUIRE(fp_t(M) == approximately(E));
516             }
517 
518             // Non-zero eccentricities.
519             for (auto i = 0; i < ntrials * 10; ++i) {
520                 const auto M = M_dist(rng);
521                 const auto e = e_dist(rng);
522                 const auto E = f_ptr(e, M);
523                 REQUIRE(fp_t(M) == approximately(E - e * sin(E), fp_t(10000)));
524             }
525         }
526     };
527 
528     tuple_for_each(fp_types, tester);
529 }
530 
531 TEST_CASE("inv_kep_E_batch")
532 {
533     using detail::llvm_add_inv_kep_E;
534     using detail::to_llvm_type;
535     namespace bmt = boost::math::tools;
536     using std::cos;
537     using std::sin;
538 
__anon5160b0c00902(auto fp_x) 539     auto tester = [](auto fp_x) {
540         using fp_t = decltype(fp_x);
541 
542         for (auto batch_size : {1u, 2u, 4u, 13u}) {
543             for (auto opt_level : {0u, 1u, 2u, 3u}) {
544                 llvm_state s{kw::opt_level = opt_level};
545 
546                 auto fkep = llvm_add_inv_kep_E<fp_t>(s, batch_size);
547 
548                 auto &md = s.module();
549                 auto &builder = s.builder();
550                 auto &context = s.context();
551 
552                 auto val_t = to_llvm_type<fp_t>(context);
553 
554                 std::vector<llvm::Type *> fargs{llvm::PointerType::getUnqual(val_t),
555                                                 llvm::PointerType::getUnqual(val_t),
556                                                 llvm::PointerType::getUnqual(val_t)};
557                 auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false);
558                 auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "hey_kep", &md);
559 
560                 auto ret_ptr = f->args().begin();
561                 auto e_ptr = f->args().begin() + 1;
562                 auto M_ptr = f->args().begin() + 2;
563 
564                 builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
565 
566                 auto ret = builder.CreateCall(fkep, {detail::load_vector_from_memory(builder, e_ptr, batch_size),
567                                                      detail::load_vector_from_memory(builder, M_ptr, batch_size)});
568 
569                 detail::store_vector_to_memory(builder, ret_ptr, ret);
570 
571                 builder.CreateRetVoid();
572 
573                 // Verify.
574                 s.verify_function(f);
575 
576                 // Run the optimisation pass.
577                 s.optimise();
578 
579                 // Compile.
580                 s.compile();
581 
582                 // Fetch the function pointer.
583                 auto f_ptr = reinterpret_cast<void (*)(fp_t *, fp_t *, fp_t *)>(s.jit_lookup("hey_kep"));
584 
585                 std::uniform_real_distribution<double> e_dist(0., 1.),
586                     M_dist(0., 2 * boost::math::constants::pi<double>());
587 
588                 std::vector<fp_t> ret_vec(batch_size), e_vec(ret_vec), M_vec(ret_vec);
589 
590                 // First set of tests with zero eccentricity.
591                 for (auto i = 0; i < ntrials; ++i) {
592                     for (auto j = 0u; j < batch_size; ++j) {
593                         M_vec[j] = M_dist(rng);
594                     }
595                     f_ptr(ret_vec.data(), e_vec.data(), M_vec.data());
596 
597                     for (auto j = 0u; j < batch_size; ++j) {
598                         REQUIRE(M_vec[j] == approximately(ret_vec[j]));
599                     }
600                 }
601 
602                 // Non-zero eccentricities.
603                 for (auto i = 0; i < ntrials * 10; ++i) {
604                     for (auto j = 0u; j < batch_size; ++j) {
605                         M_vec[j] = M_dist(rng);
606                         e_vec[j] = e_dist(rng);
607                     }
608                     f_ptr(ret_vec.data(), e_vec.data(), M_vec.data());
609 
610                     for (auto j = 0u; j < batch_size; ++j) {
611                         REQUIRE(M_vec[j] == approximately(ret_vec[j] - e_vec[j] * sin(ret_vec[j]), fp_t(10000)));
612                     }
613                 }
614             }
615         }
616     };
617 
618     tuple_for_each(fp_types, tester);
619 }
620 
621 TEST_CASE("while_loop")
622 {
623     using detail::llvm_while_loop;
624 
625     for (auto opt_level : {0u, 1u, 2u, 3u}) {
626         llvm_state s{kw::opt_level = opt_level};
627 
628         auto &md = s.module();
629         auto &builder = s.builder();
630         auto &context = s.context();
631 
632         auto val_t = builder.getInt32Ty();
633 
634         std::vector<llvm::Type *> fargs{val_t};
635         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
636         auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "count_n", &md);
637 
638         auto final_n = f->args().begin();
639 
640         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
641 
642         auto retval = builder.CreateAlloca(val_t);
643         builder.CreateStore(builder.getInt32(0), retval);
644 
645         llvm_while_loop(
__anon5160b0c00a02() 646             s, [&]() -> llvm::Value * { return builder.CreateICmpULT(builder.CreateLoad(retval), final_n); },
__anon5160b0c00b02() 647             [&]() { builder.CreateStore(builder.CreateAdd(builder.CreateLoad(retval), builder.getInt32(1)), retval); });
648 
649         // Return the result.
650         builder.CreateRet(builder.CreateLoad(retval));
651 
652         // Verify.
653         s.verify_function(f);
654 
655         // Run the optimisation pass.
656         s.optimise();
657 
658         // Compile.
659         s.compile();
660 
661         // Fetch the function pointer.
662         auto f_ptr = reinterpret_cast<std::uint32_t (*)(std::uint32_t)>(s.jit_lookup("count_n"));
663 
664         REQUIRE(f_ptr(0) == 0u);
665         REQUIRE(f_ptr(1) == 1u);
666         REQUIRE(f_ptr(2) == 2u);
667         REQUIRE(f_ptr(3) == 3u);
668         REQUIRE(f_ptr(4) == 4u);
669     }
670 
671     // Error handling.
672     {
673         llvm_state s;
674 
675         auto &md = s.module();
676         auto &builder = s.builder();
677         auto &context = s.context();
678 
679         auto val_t = builder.getInt32Ty();
680 
681         std::vector<llvm::Type *> fargs{val_t};
682         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
683         auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "count_n", &md);
684 
685         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
686 
687         auto retval = builder.CreateAlloca(val_t);
688         builder.CreateStore(builder.getInt32(0), retval);
689 
690         // NOTE: if we don't do the cleanup of f before re-throwing,
691         // on the OSX CI the test will hang on shutdown (i.e., all the tests
692         // run correctly but the test program hangs on exit). Not sure what is
693         // going on with that, perhaps another bad interaction between LLVM and
694         // exceptions?
__anon5160b0c00c02() 695         auto thrower = [&]() {
696             try {
697                 llvm_while_loop(
698                     s, [&]() -> llvm::Value * { throw std::runtime_error{"aa"}; }, [&]() {});
699             } catch (...) {
700                 f->eraseFromParent();
701 
702                 throw;
703             }
704         };
705 
706         REQUIRE_THROWS_AS(thrower(), std::runtime_error);
707     }
708 
709     {
710         llvm_state s;
711 
712         auto &md = s.module();
713         auto &builder = s.builder();
714         auto &context = s.context();
715 
716         auto val_t = builder.getInt32Ty();
717 
718         std::vector<llvm::Type *> fargs{val_t};
719         auto *ft = llvm::FunctionType::get(val_t, fargs, false);
720         auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "count_n", &md);
721 
722         auto final_n = f->args().begin();
723 
724         builder.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", f));
725 
726         auto retval = builder.CreateAlloca(val_t);
727         builder.CreateStore(builder.getInt32(0), retval);
728 
__anon5160b0c00f02() 729         auto thrower = [&]() {
730             try {
731                 llvm_while_loop(
732                     s, [&]() -> llvm::Value * { return builder.CreateICmpULT(builder.CreateLoad(retval), final_n); },
733                     [&]() { throw std::runtime_error{"aa"}; });
734             } catch (...) {
735                 f->eraseFromParent();
736 
737                 throw;
738             }
739         };
740 
741         REQUIRE_THROWS_AS(thrower(), std::runtime_error);
742     }
743 }
744 
745 TEST_CASE("csc_scalar")
746 {
747     using detail::llvm_add_csc;
748     using detail::llvm_mangle_type;
749     using detail::to_llvm_type;
750 
__anon5160b0c01202(auto fp_x) 751     auto tester = [](auto fp_x) {
752         using fp_t = decltype(fp_x);
753 
754         for (auto opt_level : {0u, 1u, 2u, 3u}) {
755             llvm_state s{kw::opt_level = opt_level};
756 
757             const auto degree = 4u;
758 
759             llvm_add_csc<fp_t>(s, degree, 1);
760 
761             s.optimise();
762 
763             s.compile();
764 
765             auto f_ptr = reinterpret_cast<void (*)(std::uint32_t *, const fp_t *)>(s.jit_lookup(
766                 "heyoka_csc_degree_{}_{}"_format(degree, llvm_mangle_type(to_llvm_type<fp_t>(s.context())))));
767 
768             // Random testing.
769             std::uniform_real_distribution<double> rdist(-10., 10.);
770             std::uniform_int_distribution<int> idist(0, 9);
771             std::uint32_t out = 0;
772             std::vector<fp_t> cfs(degree + 1u), nz_values;
773 
774             for (auto i = 0; i < ntrials * 10; ++i) {
775                 nz_values.clear();
776 
777                 // Generate random coefficients, putting
778                 // in a zero every once in a while.
779                 std::generate(cfs.begin(), cfs.end(), [&idist, &rdist, &nz_values]() {
780                     auto ret = idist(rng) == 0 ? fp_t(0) : fp_t(rdist(rng));
781                     if (ret != 0) {
782                         nz_values.push_back(ret);
783                     }
784 
785                     return ret;
786                 });
787 
788                 // Determine the number of sign changes.
789                 auto n_sc = 0u;
790                 for (decltype(nz_values.size()) j = 1; j < nz_values.size(); ++j) {
791                     n_sc += sgn(nz_values[j]) != sgn(nz_values[j - 1u]);
792                 }
793 
794                 // Check it.
795                 f_ptr(&out, cfs.data());
796                 REQUIRE(out == n_sc);
797             }
798 
799             // A full zero test.
800             std::fill(cfs.begin(), cfs.end(), fp_t(0));
801             f_ptr(&out, cfs.data());
802             REQUIRE(out == 0u);
803 
804             // Full 1.
805             std::fill(cfs.begin(), cfs.end(), fp_t(1));
806             f_ptr(&out, cfs.data());
807             REQUIRE(out == 0u);
808 
809             // Full -1.
810             std::fill(cfs.begin(), cfs.end(), fp_t(-1));
811             f_ptr(&out, cfs.data());
812             REQUIRE(out == 0u);
813         }
814     };
815 
816     tuple_for_each(fp_types, tester);
817 }
818 
819 TEST_CASE("csc_batch")
820 {
821     using detail::llvm_add_csc;
822     using detail::llvm_mangle_type;
823     using detail::make_vector_type;
824     using detail::to_llvm_type;
825 
__anon5160b0c01402(auto fp_x) 826     auto tester = [](auto fp_x) {
827         using fp_t = decltype(fp_x);
828 
829         for (auto batch_size : {1u, 2u, 4u, 13u}) {
830             for (auto opt_level : {0u, 1u, 2u, 3u}) {
831                 llvm_state s{kw::opt_level = opt_level};
832 
833                 const auto degree = 4u;
834 
835                 llvm_add_csc<fp_t>(s, degree, batch_size);
836 
837                 s.optimise();
838 
839                 s.compile();
840 
841                 auto f_ptr = reinterpret_cast<void (*)(std::uint32_t *, const fp_t *)>(
842                     s.jit_lookup("heyoka_csc_degree_{}_{}"_format(
843                         degree, llvm_mangle_type(make_vector_type(to_llvm_type<fp_t>(s.context()), batch_size)))));
844 
845                 // Random testing.
846                 std::uniform_real_distribution<double> rdist(-10., 10.);
847                 std::uniform_int_distribution<int> idist(0, 9);
848                 std::vector<std::uint32_t> out(batch_size), n_sc(batch_size);
849                 std::vector<fp_t> cfs((degree + 1u) * batch_size), nz_values;
850 
851                 for (auto i = 0; i < ntrials * 10; ++i) {
852                     // Generate random coefficients, putting
853                     // in a zero every once in a while.
854                     std::generate(cfs.begin(), cfs.end(),
855                                   [&idist, &rdist]() { return idist(rng) == 0 ? fp_t(0) : fp_t(rdist(rng)); });
856 
857                     // Determine the number of sign changes for each batch element.
858                     for (auto batch_idx = 0u; batch_idx < batch_size; ++batch_idx) {
859                         nz_values.clear();
860 
861                         for (auto j = 0u; j <= degree; ++j) {
862                             if (cfs[batch_size * j + batch_idx] != 0) {
863                                 nz_values.push_back(cfs[batch_size * j + batch_idx]);
864                             }
865                         }
866 
867                         n_sc[batch_idx] = 0;
868                         for (decltype(nz_values.size()) j = 1; j < nz_values.size(); ++j) {
869                             n_sc[batch_idx] += sgn(nz_values[j]) != sgn(nz_values[j - 1u]);
870                         }
871                     }
872 
873                     // Check the result.
874                     f_ptr(out.data(), cfs.data());
875                     REQUIRE(std::equal(out.begin(), out.end(), n_sc.begin()));
876                 }
877 
878                 // A full zero test.
879                 std::fill(cfs.begin(), cfs.end(), fp_t(0));
880                 f_ptr(out.data(), cfs.data());
881                 REQUIRE(std::all_of(out.begin(), out.end(), [](auto x) { return x == 0; }));
882 
883                 // Full 1.
884                 std::fill(cfs.begin(), cfs.end(), fp_t(1));
885                 f_ptr(out.data(), cfs.data());
886                 REQUIRE(std::all_of(out.begin(), out.end(), [](auto x) { return x == 0; }));
887 
888                 // Full -1.
889                 std::fill(cfs.begin(), cfs.end(), fp_t(-1));
890                 f_ptr(out.data(), cfs.data());
891                 REQUIRE(std::all_of(out.begin(), out.end(), [](auto x) { return x == 0; }));
892             }
893         }
894     };
895 
896     tuple_for_each(fp_types, tester);
897 }
898