1 // Copyright 2020, 2021 Francesco Biscani (bluescarni@gmail.com), Dario Izzo (dario.izzo@gmail.com)
2 //
3 // This file is part of the heyoka library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #include <heyoka/config.hpp>
10 
11 #include <algorithm>
12 #include <cassert>
13 #include <chrono> // NOTE: needed for the spdlog stopwatch.
14 #include <cmath>
15 #include <cstddef>
16 #include <cstdint>
17 #include <deque>
18 #include <initializer_list>
19 #include <ios>
20 #include <iterator>
21 #include <limits>
22 #include <locale>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <sstream>
27 #include <stdexcept>
28 #include <string>
29 #include <type_traits>
30 #include <unordered_map>
31 #include <unordered_set>
32 #include <utility>
33 #include <variant>
34 #include <vector>
35 
36 #include <boost/graph/adjacency_list.hpp>
37 #include <boost/numeric/conversion/cast.hpp>
38 
39 #include <fmt/format.h>
40 #include <fmt/ostream.h>
41 
42 #include <spdlog/stopwatch.h>
43 
44 #include <llvm/IR/Attributes.h>
45 #include <llvm/IR/BasicBlock.h>
46 #include <llvm/IR/Constants.h>
47 #include <llvm/IR/DerivedTypes.h>
48 #include <llvm/IR/Function.h>
49 #include <llvm/IR/IRBuilder.h>
50 #include <llvm/IR/LLVMContext.h>
51 #include <llvm/IR/Module.h>
52 #include <llvm/IR/Type.h>
53 #include <llvm/IR/Value.h>
54 
55 #if defined(HEYOKA_HAVE_REAL128)
56 
57 #include <mp++/real128.hpp>
58 
59 #endif
60 
61 #include <heyoka/detail/llvm_helpers.hpp>
62 #include <heyoka/detail/logging_impl.hpp>
63 #include <heyoka/detail/string_conv.hpp>
64 #include <heyoka/detail/type_traits.hpp>
65 #include <heyoka/detail/visibility.hpp>
66 #include <heyoka/expression.hpp>
67 #include <heyoka/llvm_state.hpp>
68 #include <heyoka/number.hpp>
69 #include <heyoka/param.hpp>
70 #include <heyoka/s11n.hpp>
71 #include <heyoka/taylor.hpp>
72 #include <heyoka/variable.hpp>
73 
74 #if defined(_MSC_VER) && !defined(__clang__)
75 
76 // NOTE: MSVC has issues with the other "using"
77 // statement form.
78 using namespace fmt::literals;
79 
80 #else
81 
82 using fmt::literals::operator""_format;
83 
84 #endif
85 
86 namespace heyoka
87 {
88 
89 namespace detail
90 {
91 
92 namespace
93 {
94 
taylor_c_diff_mangle(const variable &)95 std::string taylor_c_diff_mangle(const variable &)
96 {
97     return "var";
98 }
99 
taylor_c_diff_mangle(const number &)100 std::string taylor_c_diff_mangle(const number &)
101 {
102     return "num";
103 }
104 
taylor_c_diff_mangle(const param &)105 std::string taylor_c_diff_mangle(const param &)
106 {
107     return "par";
108 }
109 
110 } // namespace
111 
112 // NOTE: precondition on name: must be conforming to LLVM requirements for
113 // function names, and must not contain "." (as we use it as a separator in
114 // the mangling scheme).
115 std::pair<std::string, std::vector<llvm::Type *>>
taylor_c_diff_func_name_args_impl(llvm::LLVMContext & context,const std::string & name,llvm::Type * val_t,std::uint32_t n_uvars,const std::vector<std::variant<variable,number,param>> & args,std::uint32_t n_hidden_deps)116 taylor_c_diff_func_name_args_impl(llvm::LLVMContext &context, const std::string &name, llvm::Type *val_t,
117                                   std::uint32_t n_uvars, const std::vector<std::variant<variable, number, param>> &args,
118                                   std::uint32_t n_hidden_deps)
119 {
120     assert(val_t != nullptr);
121     assert(n_uvars > 0u);
122 
123     // Init the name.
124     auto fname = "heyoka.taylor_c_diff.{}."_format(name);
125 
126     // Init the vector of arguments:
127     // - diff order,
128     // - idx of the u variable whose diff is being computed,
129     // - diff array (pointer to val_t),
130     // - par ptr (pointer to scalar),
131     // - time ptr (pointer to scalar).
132     std::vector<llvm::Type *> fargs{
133         llvm::Type::getInt32Ty(context), llvm::Type::getInt32Ty(context), llvm::PointerType::getUnqual(val_t),
134         llvm::PointerType::getUnqual(val_t->getScalarType()), llvm::PointerType::getUnqual(val_t->getScalarType())};
135 
136     // Add the mangling and LLVM arg types for the argument types. Also, detect if
137     // we have variables in the arguments.
138     bool with_var = false;
139     for (decltype(args.size()) i = 0; i < args.size(); ++i) {
140         // Detect variable.
141         if (std::holds_alternative<variable>(args[i])) {
142             with_var = true;
143         }
144 
145         // Name mangling.
146         fname += std::visit([](const auto &v) { return taylor_c_diff_mangle(v); }, args[i]);
147 
148         // Add the arguments separator, if we are not at the
149         // last argument.
150         if (i != args.size() - 1u) {
151             fname += '_';
152         }
153 
154         // Add the LLVM function argument type.
155         fargs.push_back(std::visit(
156             [&](const auto &v) -> llvm::Type * {
157                 using type = detail::uncvref_t<decltype(v)>;
158 
159                 if constexpr (std::is_same_v<type, number>) {
160                     // For numbers, the argument is passed as a scalar
161                     // floating-point value.
162                     return val_t->getScalarType();
163                 } else {
164                     // For vars and params, the argument is an index
165                     // in an array.
166                     return llvm::Type::getInt32Ty(context);
167                 }
168             },
169             args[i]));
170     }
171 
172     // Close the argument list with a ".".
173     // NOTE: this will result in a ".." in the name
174     // if the function has zero arguments.
175     fname += '.';
176 
177     // If we have variables in the arguments, add mangling
178     // for n_uvars.
179     if (with_var) {
180         fname += "n_uvars_{}."_format(n_uvars);
181     }
182 
183     // Finally, add the mangling for the floating-point type.
184     fname += llvm_mangle_type(val_t);
185 
186     // Fill in the hidden dependency arguments. These are all indices.
187     fargs.insert(fargs.end(), boost::numeric_cast<decltype(fargs.size())>(n_hidden_deps),
188                  llvm::Type::getInt32Ty(context));
189 
190     return std::make_pair(std::move(fname), std::move(fargs));
191 }
192 
193 namespace
194 {
195 
196 template <typename T>
taylor_codegen_numparam_num(llvm_state & s,const number & num,std::uint32_t batch_size)197 llvm::Value *taylor_codegen_numparam_num(llvm_state &s, const number &num, std::uint32_t batch_size)
198 {
199     return vector_splat(s.builder(), codegen<T>(s, num), batch_size);
200 }
201 
taylor_codegen_numparam_par(llvm_state & s,const param & p,llvm::Value * par_ptr,std::uint32_t batch_size)202 llvm::Value *taylor_codegen_numparam_par(llvm_state &s, const param &p, llvm::Value *par_ptr, std::uint32_t batch_size)
203 {
204     assert(batch_size > 0u);
205 
206     auto &builder = s.builder();
207 
208     // Determine the index into the parameter array.
209     // LCOV_EXCL_START
210     if (p.idx() > std::numeric_limits<std::uint32_t>::max() / batch_size) {
211         throw std::overflow_error("Overflow detected in the computation of the index into a parameter array");
212     }
213     // LCOV_EXCL_STOP
214     const auto arr_idx = static_cast<std::uint32_t>(p.idx() * batch_size);
215 
216     // Compute the pointer to load from.
217     auto *ptr = builder.CreateInBoundsGEP(par_ptr, {builder.getInt32(arr_idx)});
218 
219     // Load.
220     return load_vector_from_memory(builder, ptr, batch_size);
221 }
222 
223 } // namespace
224 
taylor_codegen_numparam_dbl(llvm_state & s,const number & num,llvm::Value *,std::uint32_t batch_size)225 llvm::Value *taylor_codegen_numparam_dbl(llvm_state &s, const number &num, llvm::Value *, std::uint32_t batch_size)
226 {
227     return taylor_codegen_numparam_num<double>(s, num, batch_size);
228 }
229 
taylor_codegen_numparam_ldbl(llvm_state & s,const number & num,llvm::Value *,std::uint32_t batch_size)230 llvm::Value *taylor_codegen_numparam_ldbl(llvm_state &s, const number &num, llvm::Value *, std::uint32_t batch_size)
231 {
232     return taylor_codegen_numparam_num<long double>(s, num, batch_size);
233 }
234 
235 #if defined(HEYOKA_HAVE_REAL128)
236 
taylor_codegen_numparam_f128(llvm_state & s,const number & num,llvm::Value *,std::uint32_t batch_size)237 llvm::Value *taylor_codegen_numparam_f128(llvm_state &s, const number &num, llvm::Value *, std::uint32_t batch_size)
238 {
239     return taylor_codegen_numparam_num<mppp::real128>(s, num, batch_size);
240 }
241 
242 #endif
243 
taylor_codegen_numparam_dbl(llvm_state & s,const param & p,llvm::Value * par_ptr,std::uint32_t batch_size)244 llvm::Value *taylor_codegen_numparam_dbl(llvm_state &s, const param &p, llvm::Value *par_ptr, std::uint32_t batch_size)
245 {
246     return taylor_codegen_numparam_par(s, p, par_ptr, batch_size);
247 }
248 
taylor_codegen_numparam_ldbl(llvm_state & s,const param & p,llvm::Value * par_ptr,std::uint32_t batch_size)249 llvm::Value *taylor_codegen_numparam_ldbl(llvm_state &s, const param &p, llvm::Value *par_ptr, std::uint32_t batch_size)
250 {
251     return taylor_codegen_numparam_par(s, p, par_ptr, batch_size);
252 }
253 
254 #if defined(HEYOKA_HAVE_REAL128)
255 
taylor_codegen_numparam_f128(llvm_state & s,const param & p,llvm::Value * par_ptr,std::uint32_t batch_size)256 llvm::Value *taylor_codegen_numparam_f128(llvm_state &s, const param &p, llvm::Value *par_ptr, std::uint32_t batch_size)
257 {
258     return taylor_codegen_numparam_par(s, p, par_ptr, batch_size);
259 }
260 
261 #endif
262 
263 // Codegen helpers for number/param for use in the generic c_diff implementations.
taylor_c_diff_numparam_codegen(llvm_state & s,const number &,llvm::Value * n,llvm::Value *,std::uint32_t batch_size)264 llvm::Value *taylor_c_diff_numparam_codegen(llvm_state &s, const number &, llvm::Value *n, llvm::Value *,
265                                             std::uint32_t batch_size)
266 {
267     return vector_splat(s.builder(), n, batch_size);
268 }
269 
taylor_c_diff_numparam_codegen(llvm_state & s,const param &,llvm::Value * p,llvm::Value * par_ptr,std::uint32_t batch_size)270 llvm::Value *taylor_c_diff_numparam_codegen(llvm_state &s, const param &, llvm::Value *p, llvm::Value *par_ptr,
271                                             std::uint32_t batch_size)
272 {
273     auto &builder = s.builder();
274 
275     // Fetch the pointer into par_ptr.
276     // NOTE: the overflow check is done in taylor_compute_jet().
277     auto *ptr = builder.CreateInBoundsGEP(par_ptr, {builder.CreateMul(p, builder.getInt32(batch_size))});
278 
279     return load_vector_from_memory(builder, ptr, batch_size);
280 }
281 
282 // Helper to fetch the derivative of order 'order' of the u variable at index u_idx from the
283 // derivative array 'arr'. The total number of u variables is n_uvars.
taylor_fetch_diff(const std::vector<llvm::Value * > & arr,std::uint32_t u_idx,std::uint32_t order,std::uint32_t n_uvars)284 llvm::Value *taylor_fetch_diff(const std::vector<llvm::Value *> &arr, std::uint32_t u_idx, std::uint32_t order,
285                                std::uint32_t n_uvars)
286 {
287     // Sanity check.
288     assert(u_idx < n_uvars);
289 
290     // Compute the index.
291     const auto idx = static_cast<decltype(arr.size())>(order) * n_uvars + u_idx;
292     assert(idx < arr.size());
293 
294     return arr[idx];
295 }
296 
297 // Load the derivative of order 'order' of the u variable u_idx from the array of Taylor derivatives diff_arr.
298 // n_uvars is the total number of u variables.
taylor_c_load_diff(llvm_state & s,llvm::Value * diff_arr,std::uint32_t n_uvars,llvm::Value * order,llvm::Value * u_idx)299 llvm::Value *taylor_c_load_diff(llvm_state &s, llvm::Value *diff_arr, std::uint32_t n_uvars, llvm::Value *order,
300                                 llvm::Value *u_idx)
301 {
302     auto &builder = s.builder();
303 
304     // NOTE: overflow check has already been done to ensure that the
305     // total size of diff_arr fits in a 32-bit unsigned integer.
306     auto *ptr = builder.CreateInBoundsGEP(
307         diff_arr, {builder.CreateAdd(builder.CreateMul(order, builder.getInt32(n_uvars)), u_idx)});
308 
309     return builder.CreateLoad(ptr);
310 }
311 
312 namespace
313 {
314 
315 // Simplify a Taylor decomposition by removing
316 // common subexpressions.
317 // NOTE: the hidden deps are not considered for CSE
318 // purposes, only the actual subexpressions.
taylor_decompose_cse(taylor_dc_t & v_ex,std::vector<std::uint32_t> & sv_funcs_dc,taylor_dc_t::size_type n_eq)319 taylor_dc_t taylor_decompose_cse(taylor_dc_t &v_ex, std::vector<std::uint32_t> &sv_funcs_dc,
320                                  taylor_dc_t::size_type n_eq)
321 {
322     using idx_t = taylor_dc_t::size_type;
323 
324     // Log runtime in trace mode.
325     spdlog::stopwatch sw;
326 
327     // Cache the original size for logging later.
328     const auto orig_size = v_ex.size();
329 
330     // A Taylor decomposition is supposed
331     // to have n_eq variables at the beginning,
332     // n_eq variables at the end and possibly
333     // extra variables in the middle.
334     assert(v_ex.size() >= n_eq * 2u);
335 
336     // Init the return value.
337     taylor_dc_t retval;
338 
339     // expression -> idx map. This will end up containing
340     // all the unique expressions from v_ex, and it will
341     // map them to their indices in retval (which will
342     // in general differ from their indices in v_ex).
343     std::unordered_map<expression, idx_t> ex_map;
344 
345     // Map for the renaming of u variables
346     // in the expressions.
347     std::unordered_map<std::string, std::string> uvars_rename;
348 
349     // The first n_eq definitions are just renaming
350     // of the state variables into u variables.
351     for (idx_t i = 0; i < n_eq; ++i) {
352         assert(std::holds_alternative<variable>(v_ex[i].first.value()));
353         // NOTE: no hidden deps allowed here.
354         assert(v_ex[i].second.empty());
355         retval.push_back(std::move(v_ex[i]));
356 
357         // NOTE: the u vars that correspond to state
358         // variables are never simplified,
359         // thus map them onto themselves.
360         [[maybe_unused]] const auto res = uvars_rename.emplace("u_{}"_format(i), "u_{}"_format(i));
361         assert(res.second);
362     }
363 
364     // Handle the u variables which do not correspond to state variables.
365     for (auto i = n_eq; i < v_ex.size() - n_eq; ++i) {
366         auto &[ex, deps] = v_ex[i];
367 
368         // Rename the u variables in ex.
369         rename_variables(ex, uvars_rename);
370 
371         if (auto it = ex_map.find(ex); it == ex_map.end()) {
372             // This is the first occurrence of ex in the
373             // decomposition. Add it to retval.
374             retval.emplace_back(ex, std::move(deps));
375 
376             // Add ex to ex_map, mapping it to
377             // the index it corresponds to in retval
378             // (let's call it j).
379             ex_map.emplace(std::move(ex), retval.size() - 1u);
380 
381             // Update uvars_rename. This will ensure that
382             // occurrences of the variable 'u_i' in the next
383             // elements of v_ex will be renamed to 'u_j'.
384             [[maybe_unused]] const auto res = uvars_rename.emplace("u_{}"_format(i), "u_{}"_format(retval.size() - 1u));
385             assert(res.second);
386         } else {
387             // ex is redundant. This means
388             // that it already appears in retval at index
389             // it->second. Don't add anything to retval,
390             // and remap the variable name 'u_i' to
391             // 'u_{it->second}'.
392             [[maybe_unused]] const auto res = uvars_rename.emplace("u_{}"_format(i), "u_{}"_format(it->second));
393             assert(res.second);
394         }
395     }
396 
397     // Handle the derivatives of the state variables at the
398     // end of the decomposition. We just need to ensure that
399     // the u variables in their definitions are renamed with
400     // the new indices.
401     for (auto i = v_ex.size() - n_eq; i < v_ex.size(); ++i) {
402         auto &[ex, deps] = v_ex[i];
403 
404         // NOTE: here we expect only vars, numbers or params,
405         // and no hidden dependencies.
406         assert(std::holds_alternative<variable>(ex.value()) || std::holds_alternative<number>(ex.value())
407                || std::holds_alternative<param>(ex.value()));
408         assert(deps.empty());
409 
410         rename_variables(ex, uvars_rename);
411 
412         retval.emplace_back(std::move(ex), std::move(deps));
413     }
414 
415     // Re-adjust all indices in the hidden dependencies in order to account
416     // for the renaming of the uvars.
417     for (auto &[_, deps] : retval) {
418         for (auto &idx : deps) {
419             auto it = uvars_rename.find("u_{}"_format(idx));
420             assert(it != uvars_rename.end());
421             idx = uname_to_index(it->second);
422         }
423     }
424 
425     // Same for the indices in sv_funcs_dc.
426     for (auto &idx : sv_funcs_dc) {
427         auto it = uvars_rename.find("u_{}"_format(idx));
428         assert(it != uvars_rename.end());
429         idx = uname_to_index(it->second);
430     }
431 
432     get_logger()->debug("Taylor CSE reduced decomposition size from {} to {}", orig_size, retval.size());
433     get_logger()->trace("Taylor CSE runtime: {}", sw);
434 
435     return retval;
436 }
437 
438 // Perform a topological sort on a graph representation
439 // of the Taylor decomposition. This can improve performance
440 // by grouping together operations that can be performed in parallel,
441 // and it also makes compact mode much more effective by creating
442 // clusters of subexpressions whose derivatives can be computed in
443 // parallel.
444 // NOTE: the original decomposition dc is already topologically sorted,
445 // in the sense that the definitions of the u variables are already
446 // ordered according to dependency. However, because the original decomposition
447 // comes from a depth-first search, it has the tendency to group together
448 // expressions which are dependent on each other. By doing another topological
449 // sort, this time based on breadth-first search, we determine another valid
450 // sorting in which independent operations tend to be clustered together.
taylor_sort_dc(taylor_dc_t & dc,std::vector<std::uint32_t> & sv_funcs_dc,taylor_dc_t::size_type n_eq)451 auto taylor_sort_dc(taylor_dc_t &dc, std::vector<std::uint32_t> &sv_funcs_dc, taylor_dc_t::size_type n_eq)
452 {
453     // A Taylor decomposition is supposed
454     // to have n_eq variables at the beginning,
455     // n_eq variables at the end and possibly
456     // extra variables in the middle
457     assert(dc.size() >= n_eq * 2u);
458 
459     // Log runtime in trace mode.
460     spdlog::stopwatch sw;
461 
462     // The graph type that we will use for the topological sorting.
463     using graph_t = boost::adjacency_list<boost::vecS,           // std::vector for list of adjacent vertices
464                                           boost::vecS,           // std::vector for the list of vertices
465                                           boost::bidirectionalS, // directed graph with efficient access
466                                                                  // to in-edges
467                                           boost::no_property,    // no vertex properties
468                                           boost::no_property,    // no edge properties
469                                           boost::no_property,    // no graph properties
470                                           boost::listS           // std::list for of the graph's edge list
471                                           >;
472 
473     graph_t g;
474 
475     // Add the root node.
476     const auto root_v = boost::add_vertex(g);
477 
478     // Add the nodes corresponding to the state variables.
479     for (decltype(n_eq) i = 0; i < n_eq; ++i) {
480         auto v = boost::add_vertex(g);
481 
482         // Add a dependency on the root node.
483         boost::add_edge(root_v, v, g);
484     }
485 
486     // Add the rest of the u variables.
487     for (decltype(n_eq) i = n_eq; i < dc.size() - n_eq; ++i) {
488         auto v = boost::add_vertex(g);
489 
490         // Fetch the list of variables in the current expression.
491         const auto vars = get_variables(dc[i].first);
492 
493         if (vars.empty()) {
494             // The current expression does not contain
495             // any variable: make it depend on the root
496             // node. This means that in the topological
497             // sort below, the current u var will appear
498             // immediately after the state variables.
499             boost::add_edge(root_v, v, g);
500         } else {
501             // Mark the current u variable as depending on all the
502             // variables in the current expression.
503             for (const auto &var : vars) {
504                 // Extract the index.
505                 const auto idx = uname_to_index(var);
506 
507                 // Add the dependency.
508                 // NOTE: add +1 because the i-th vertex
509                 // corresponds to the (i-1)-th u variable
510                 // due to the presence of the root node.
511                 boost::add_edge(boost::vertex(idx + 1u, g), v, g);
512             }
513         }
514     }
515 
516     assert(boost::num_vertices(g) - 1u == dc.size() - n_eq);
517 
518     // Run the BF topological sort on the graph. This is Kahn's algorithm:
519     // https://en.wikipedia.org/wiki/Topological_sorting
520 
521     // The result of the sort.
522     std::vector<decltype(dc.size())> v_idx;
523 
524     // Temp variable used to sort a list of edges in the loop below.
525     std::vector<boost::graph_traits<graph_t>::edge_descriptor> tmp_edges;
526 
527     // The set of all nodes with no incoming edge.
528     std::deque<decltype(dc.size())> tmp;
529     // The root node has no incoming edge.
530     tmp.push_back(0);
531 
532     // Main loop.
533     while (!tmp.empty()) {
534         // Pop the first element from tmp
535         // and append it to the result.
536         const auto v = tmp.front();
537         tmp.pop_front();
538         v_idx.push_back(v);
539 
540         // Fetch all the out edges of v and sort them according
541         // to the target vertex.
542         // NOTE: the sorting is important to ensure that all the state
543         // variables are insered into v_idx in the correct order.
544         const auto e_range = boost::out_edges(v, g);
545         tmp_edges.assign(e_range.first, e_range.second);
546         std::sort(tmp_edges.begin(), tmp_edges.end(),
547                   [&g](const auto &e1, const auto &e2) { return boost::target(e1, g) < boost::target(e2, g); });
548 
549         // For each out edge of v:
550         // - eliminate it;
551         // - check if the target vertex of the edge
552         //   has other incoming edges;
553         // - if it does not, insert it into tmp.
554         for (auto &e : tmp_edges) {
555             // Fetch the target of the edge.
556             const auto t = boost::target(e, g);
557 
558             // Remove the edge.
559             boost::remove_edge(e, g);
560 
561             // Get the range of vertices connecting to t.
562             const auto iav = boost::inv_adjacent_vertices(t, g);
563 
564             if (iav.first == iav.second) {
565                 // t does not have any incoming edges, add it to tmp.
566                 tmp.push_back(t);
567             }
568         }
569     }
570 
571     assert(v_idx.size() == boost::num_vertices(g));
572     assert(boost::num_edges(g) == 0u);
573 
574     // Adjust v_idx: remove the index of the root node,
575     // decrease by one all other indices, insert the final
576     // n_eq indices.
577     for (decltype(v_idx.size()) i = 0; i < v_idx.size() - 1u; ++i) {
578         v_idx[i] = v_idx[i + 1u] - 1u;
579     }
580     v_idx.resize(boost::numeric_cast<decltype(v_idx.size())>(dc.size()));
581     std::iota(v_idx.data() + dc.size() - n_eq, v_idx.data() + dc.size(), dc.size() - n_eq);
582 
583     // Create the remapping dictionary.
584     std::unordered_map<std::string, std::string> remap;
585     // NOTE: the u vars that correspond to state
586     // variables were inserted into v_idx in the original
587     // order, thus they are not re-sorted and they do not
588     // need renaming.
589     for (decltype(v_idx.size()) i = 0; i < n_eq; ++i) {
590         [[maybe_unused]] const auto res = remap.emplace("u_{}"_format(i), "u_{}"_format(i));
591         assert(res.second);
592     }
593     // Establish the remapping for the u variables that are not
594     // state variables.
595     for (decltype(v_idx.size()) i = n_eq; i < v_idx.size() - n_eq; ++i) {
596         [[maybe_unused]] const auto res = remap.emplace("u_{}"_format(v_idx[i]), "u_{}"_format(i));
597         assert(res.second);
598     }
599 
600     // Do the remap for the definitions of the u variables, the
601     // derivatives and the hidden deps.
602     for (auto *it = dc.data() + n_eq; it != dc.data() + dc.size(); ++it) {
603         // Remap the expression.
604         rename_variables(it->first, remap);
605 
606         // Remap the hidden dependencies.
607         for (auto &idx : it->second) {
608             auto it_remap = remap.find("u_{}"_format(idx));
609             assert(it_remap != remap.end());
610             idx = uname_to_index(it_remap->second);
611         }
612     }
613 
614     // Do the remap for sv_funcs.
615     for (auto &idx : sv_funcs_dc) {
616         auto it_remap = remap.find("u_{}"_format(idx));
617         assert(it_remap != remap.end());
618         idx = uname_to_index(it_remap->second);
619     }
620 
621     // Reorder the decomposition.
622     taylor_dc_t retval;
623     retval.reserve(v_idx.size());
624     for (auto idx : v_idx) {
625         retval.push_back(std::move(dc[idx]));
626     }
627 
628     get_logger()->trace("Taylor topological sort runtime: {}", sw);
629 
630     return retval;
631 }
632 
633 // LCOV_EXCL_START
634 
635 #if !defined(NDEBUG)
636 
637 // Helper to verify a Taylor decomposition.
verify_taylor_dec(const std::vector<expression> & orig,const taylor_dc_t & dc)638 void verify_taylor_dec(const std::vector<expression> &orig, const taylor_dc_t &dc)
639 {
640     using idx_t = taylor_dc_t::size_type;
641 
642     const auto n_eq = orig.size();
643 
644     assert(dc.size() >= n_eq * 2u);
645 
646     // The first n_eq expressions of u variables
647     // must be just variables. No hidden dependencies
648     // are allowed.
649     for (idx_t i = 0; i < n_eq; ++i) {
650         assert(std::holds_alternative<variable>(dc[i].first.value()));
651         assert(dc[i].second.empty());
652     }
653 
654     // From n_eq to dc.size() - n_eq, the expressions
655     // must be functions whose arguments
656     // are either variables in the u_n form,
657     // where n < i, or numbers/params.
658     // The hidden dependencies must contain indices
659     // only in the [n_eq, dc.size() - n_eq) range.
660     for (auto i = n_eq; i < dc.size() - n_eq; ++i) {
661         std::visit(
662             [i](const auto &v) {
663                 using type = detail::uncvref_t<decltype(v)>;
664 
665                 if constexpr (std::is_same_v<type, func>) {
666                     auto check_arg = [i](const auto &arg) {
667                         if (auto p_var = std::get_if<variable>(&arg.value())) {
668                             assert(p_var->name().rfind("u_", 0) == 0);
669                             assert(uname_to_index(p_var->name()) < i);
670                         } else if (std::get_if<number>(&arg.value()) == nullptr
671                                    && std::get_if<param>(&arg.value()) == nullptr) {
672                             assert(false);
673                         }
674                     };
675 
676                     for (const auto &arg : v.args()) {
677                         check_arg(arg);
678                     }
679                 } else {
680                     assert(false);
681                 }
682             },
683             dc[i].first.value());
684 
685         for (auto idx : dc[i].second) {
686             assert(idx >= n_eq);
687             assert(idx < dc.size() - n_eq);
688 
689             // Hidden dep onto itself does not make any sense.
690             assert(idx != i);
691         }
692     }
693 
694     // From dc.size() - n_eq to dc.size(), the expressions
695     // must be either variables in the u_n form, where n < i,
696     // or numbers/params.
697     for (auto i = dc.size() - n_eq; i < dc.size(); ++i) {
698         std::visit(
699             [i](const auto &v) {
700                 using type = detail::uncvref_t<decltype(v)>;
701 
702                 if constexpr (std::is_same_v<type, variable>) {
703                     assert(v.name().rfind("u_", 0) == 0);
704                     assert(uname_to_index(v.name()) < i);
705                 } else if constexpr (!std::is_same_v<type, number> && !std::is_same_v<type, param>) {
706                     assert(false);
707                 }
708             },
709             dc[i].first.value());
710 
711         // No hidden dependencies.
712         assert(dc[i].second.empty());
713     }
714 
715     std::unordered_map<std::string, expression> subs_map;
716 
717     // For each u variable, expand its definition
718     // in terms of state variables or other u variables,
719     // and store it in subs_map.
720     for (idx_t i = 0; i < dc.size() - n_eq; ++i) {
721         subs_map.emplace("u_{}"_format(i), subs(dc[i].first, subs_map));
722     }
723 
724     // Reconstruct the right-hand sides of the system
725     // and compare them to the original ones.
726     for (auto i = dc.size() - n_eq; i < dc.size(); ++i) {
727         assert(subs(dc[i].first, subs_map) == orig[i - (dc.size() - n_eq)]);
728     }
729 }
730 
731 // Helper to verify the decomposition of the sv funcs.
verify_taylor_dec_sv_funcs(const std::vector<std::uint32_t> & sv_funcs_dc,const std::vector<expression> & sv_funcs,const taylor_dc_t & dc,std::vector<expression>::size_type n_eq)732 void verify_taylor_dec_sv_funcs(const std::vector<std::uint32_t> &sv_funcs_dc, const std::vector<expression> &sv_funcs,
733                                 const taylor_dc_t &dc, std::vector<expression>::size_type n_eq)
734 {
735     assert(sv_funcs.size() == sv_funcs_dc.size());
736 
737     std::unordered_map<std::string, expression> subs_map;
738 
739     // For each u variable, expand its definition
740     // in terms of state variables or other u variables,
741     // and store it in subs_map.
742     for (decltype(dc.size()) i = 0; i < dc.size() - n_eq; ++i) {
743         subs_map.emplace("u_{}"_format(i), subs(dc[i].first, subs_map));
744     }
745 
746     // Reconstruct the sv functions and compare them to the
747     // original ones.
748     for (decltype(sv_funcs.size()) i = 0; i < sv_funcs.size(); ++i) {
749         assert(sv_funcs_dc[i] < dc.size());
750 
751         auto sv_func = subs(dc[sv_funcs_dc[i]].first, subs_map);
752         assert(sv_func == sv_funcs[i]);
753     }
754 }
755 
756 #endif
757 
758 // LCOV_EXCL_STOP
759 
760 // A couple of helpers for deep-copying containers of expressions.
copy(const std::vector<expression> & v_ex)761 auto copy(const std::vector<expression> &v_ex)
762 {
763     std::vector<expression> ret;
764     ret.reserve(v_ex.size());
765 
766     std::transform(v_ex.begin(), v_ex.end(), std::back_inserter(ret), [](const expression &e) { return copy(e); });
767 
768     return ret;
769 }
770 
copy(const std::vector<std::pair<expression,expression>> & v)771 auto copy(const std::vector<std::pair<expression, expression>> &v)
772 {
773     std::vector<std::pair<expression, expression>> ret;
774     ret.reserve(v.size());
775 
776     std::transform(v.begin(), v.end(), std::back_inserter(ret), [](const auto &p) {
777         return std::pair{copy(p.first), copy(p.second)};
778     });
779 
780     return ret;
781 }
782 
783 } // namespace
784 
785 } // namespace detail
786 
// Taylor decomposition with automatic deduction
// of variables.
//
// v_ex_ contains the equations of the system, sv_funcs_ the extra
// functions of the state variables to be decomposed together with
// the system. The state variables are deduced from the variables
// appearing in v_ex_ and they are renamed to u_i in alphabetical order.
//
// The return value is a pair containing the decomposition and, for each
// element of sv_funcs_, the index of the u variable which represents it.
//
// std::invalid_argument is thrown if:
// - v_ex_ is empty,
// - the number of deduced variables differs from the number of equations,
// - sv_funcs_ contains a variable which is not a state variable,
// - an element of sv_funcs_ is a constant or a parameter.
//
// NOTE: when dealing with functions with hidden deps,
// we should consider avoiding adding hidden deps if the
// function argument(s) is a number/param: the hidden deps
// won't be used for the computation of the derivatives
// and thus they can be optimised out. Note that a straightforward
// implementation of this idea will only work when the argument
// is a number/param, not when, e.g., the argument is par[0] + par[1] - in
// order to simplify this out, it should be recognized that the definition
// of a u variable depends only on numbers/params.
std::pair<taylor_dc_t, std::vector<std::uint32_t>> taylor_decompose(const std::vector<expression> &v_ex_,
                                                                    const std::vector<expression> &sv_funcs_)
{
    // Need to operate on copies due to in-place mutation
    // via rename_variables().
    auto v_ex = detail::copy(v_ex_);
    auto sv_funcs = detail::copy(sv_funcs_);

    if (v_ex.empty()) {
        throw std::invalid_argument("Cannot decompose a system of zero equations");
    }

    // Determine the variables in the system of equations.
    // NOTE: std::set keeps the variable names sorted and unique.
    std::set<std::string> vars;
    for (const auto &ex : v_ex) {
        for (const auto &var : get_variables(ex)) {
            vars.emplace(var);
        }
    }
    if (vars.size() != v_ex.size()) {
        throw std::invalid_argument(
            "The number of deduced variables for a Taylor decomposition ({}) differs from the number of equations ({})"_format(
                vars.size(), v_ex.size()));
    }

    // Check that the expressions in sv_funcs contain only
    // state variables.
    for (const auto &ex : sv_funcs) {
        for (const auto &var : get_variables(ex)) {
            if (vars.find(var) == vars.end()) {
                throw std::invalid_argument("The extra functions in a Taylor decomposition contain the variable '{}', "
                                            "which is not a state variable"_format(var));
            }
        }
    }

    // Cache the number of equations/variables
    // for later use.
    const auto n_eq = v_ex.size();

    // Create the map for renaming the variables to u_i.
    // The renaming will be done in alphabetical order.
    std::unordered_map<std::string, std::string> repl_map;
    {
        decltype(vars.size()) var_idx = 0;
        for (const auto &var : vars) {
            [[maybe_unused]] const auto eres = repl_map.emplace(var, "u_{}"_format(var_idx++));
            assert(eres.second);
        }
    }

#if !defined(NDEBUG)

    // Store a copy of the original system and
    // sv_funcs for checking later.
    auto orig_v_ex = detail::copy(v_ex);
    auto orig_sv_funcs = detail::copy(sv_funcs);

#endif

    // Rename the variables in the original equations.
    for (auto &ex : v_ex) {
        rename_variables(ex, repl_map);
    }

    // Rename the variables in sv_funcs.
    for (auto &ex : sv_funcs) {
        rename_variables(ex, repl_map);
    }

    // Init the decomposition. It begins with a list
    // of the original variables of the system.
    taylor_dc_t u_vars_defs;
    u_vars_defs.reserve(vars.size());
    for (const auto &var : vars) {
        u_vars_defs.emplace_back(variable{var}, std::vector<std::uint32_t>{});
    }

    // Log the construction runtime in trace mode.
    spdlog::stopwatch sw;

    // Run the decomposition on each equation.
    for (auto &ex : v_ex) {
        // Decompose the current equation.
        if (const auto dres = taylor_decompose(ex, u_vars_defs)) {
            // NOTE: if the equation was decomposed
            // (that is, it is not constant or a single variable),
            // we have to update the original definition
            // of the equation in v_ex
            // so that it points to the u variable
            // that now represents it.
            // NOTE: all functions are forced to return dres != 0
            // in the func API, so the only entities that
            // can return dres == 0 are const/params or
            // variables.
            ex = expression{"u_{}"_format(dres)};
        }
    }

    // Decompose sv_funcs, and write into sv_funcs_dc the index
    // of the u variable which each sv_func corresponds to.
    std::vector<std::uint32_t> sv_funcs_dc;
    for (auto &sv_ex : sv_funcs) {
        if (const auto *var_ptr = std::get_if<variable>(&sv_ex.value())) {
            // The current sv_func is a variable, add its index to sv_funcs_dc.
            sv_funcs_dc.push_back(detail::uname_to_index(var_ptr->name()));
        } else if (const auto dres = taylor_decompose(sv_ex, u_vars_defs)) {
            // The sv_func was decomposed, add to sv_funcs_dc
            // the index of the u variable which represents
            // the result of the decomposition.
            sv_funcs_dc.push_back(boost::numeric_cast<std::uint32_t>(dres));
        } else {
            // The sv_func was not decomposed, meaning it's a const/param.
            throw std::invalid_argument(
                "The extra functions in a Taylor decomposition cannot be constants or parameters");
        }
    }
    assert(sv_funcs_dc.size() == sv_funcs.size());

    // Append the (possibly updated) definitions of the diff equations
    // in terms of u variables.
    // NOTE: the elements of v_ex are moved from here, v_ex is not
    // used any more after this loop.
    for (auto &ex : v_ex) {
        u_vars_defs.emplace_back(std::move(ex), std::vector<std::uint32_t>{});
    }

    detail::get_logger()->trace("Taylor decomposition construction runtime: {}", sw);

#if !defined(NDEBUG)
    // Verify the decomposition.
    detail::verify_taylor_dec(orig_v_ex, u_vars_defs);
    detail::verify_taylor_dec_sv_funcs(sv_funcs_dc, orig_sv_funcs, u_vars_defs, n_eq);
#endif

    // Simplify the decomposition.
    u_vars_defs = detail::taylor_decompose_cse(u_vars_defs, sv_funcs_dc, n_eq);

#if !defined(NDEBUG)
    // Verify the simplified decomposition.
    detail::verify_taylor_dec(orig_v_ex, u_vars_defs);
    detail::verify_taylor_dec_sv_funcs(sv_funcs_dc, orig_sv_funcs, u_vars_defs, n_eq);
#endif

    // Run the breadth-first topological sort on the decomposition.
    u_vars_defs = detail::taylor_sort_dc(u_vars_defs, sv_funcs_dc, n_eq);

#if !defined(NDEBUG)
    // Verify the reordered decomposition.
    detail::verify_taylor_dec(orig_v_ex, u_vars_defs);
    detail::verify_taylor_dec_sv_funcs(sv_funcs_dc, orig_sv_funcs, u_vars_defs, n_eq);
#endif

    return std::make_pair(std::move(u_vars_defs), std::move(sv_funcs_dc));
}
951 
// Taylor decomposition from lhs and rhs
// of a system of equations.
//
// sys_ is a list of (lhs, rhs) pairs in which every lhs must be
// a variable. The state variables are taken from the lhs expressions
// and they are renamed to u_i following the order in which they
// appear in sys_. sv_funcs_ contains the extra functions of the state
// variables to be decomposed together with the system.
//
// The return value is a pair containing the decomposition and, for each
// element of sv_funcs_, the index of the u variable which represents it.
//
// std::invalid_argument is thrown if:
// - sys_ is empty,
// - a lhs expression is not a variable,
// - the same variable appears twice in the lhs expressions,
// - a variable in the rhs expressions does not appear in the lhs,
// - sv_funcs_ contains a variable which is not a state variable,
// - an element of sv_funcs_ is a constant or a parameter.
std::pair<taylor_dc_t, std::vector<std::uint32_t>>
taylor_decompose(const std::vector<std::pair<expression, expression>> &sys_, const std::vector<expression> &sv_funcs_)
{
    // Need to operate on copies due to in-place mutation
    // via rename_variables().
    auto sys = detail::copy(sys_);
    auto sv_funcs = detail::copy(sv_funcs_);

    if (sys.empty()) {
        throw std::invalid_argument("Cannot decompose a system of zero equations");
    }

    // Determine the variables in the system of equations
    // from the lhs of the equations. We need to ensure that:
    // - all the lhs expressions are variables
    //   and there are no duplicates,
    // - all the variables in the rhs expressions
    //   appear in the lhs expressions.
    // Note that not all variables in the lhs
    // need to appear in the rhs.

    // This will eventually contain the list
    // of all variables in the system.
    std::vector<std::string> lhs_vars;
    // Maintain a set as well to check for duplicates.
    std::unordered_set<std::string> lhs_vars_set;
    // The set of variables in the rhs.
    std::unordered_set<std::string> rhs_vars_set;

    for (const auto &p : sys) {
        const auto &lhs = p.first;
        const auto &rhs = p.second;

        // Infer the variable from the current lhs.
        // NOTE: the visitor throws for any lhs whose
        // value is not a variable.
        std::visit(
            [&lhs, &lhs_vars, &lhs_vars_set](const auto &v) {
                if constexpr (std::is_same_v<detail::uncvref_t<decltype(v)>, variable>) {
                    // Check if this is a duplicate variable.
                    if (auto res = lhs_vars_set.emplace(v.name()); res.second) {
                        // Not a duplicate, add it to lhs_vars.
                        lhs_vars.emplace_back(v.name());
                    } else {
                        // Duplicate, error out.
                        throw std::invalid_argument(
                            "Error in the Taylor decomposition of a system of equations: the variable '{}' "
                            "appears in the left-hand side twice"_format(v.name()));
                    }
                } else {
                    throw std::invalid_argument(
                        "Error in the Taylor decomposition of a system of equations: the "
                        "left-hand side contains the expression '{}', which is not a variable"_format(lhs));
                }
            },
            lhs.value());

        // Update the global list of variables
        // for the rhs.
        for (auto &var : get_variables(rhs)) {
            rhs_vars_set.emplace(std::move(var));
        }
    }

    // Check that all variables in the rhs appear in the lhs.
    for (const auto &var : rhs_vars_set) {
        if (lhs_vars_set.find(var) == lhs_vars_set.end()) {
            throw std::invalid_argument("Error in the Taylor decomposition of a system of equations: the variable '{}' "
                                        "appears in the right-hand side but not in the left-hand side"_format(var));
        }
    }

    // Check that the expressions in sv_funcs contain only
    // state variables.
    for (const auto &ex : sv_funcs) {
        for (const auto &var : get_variables(ex)) {
            if (lhs_vars_set.find(var) == lhs_vars_set.end()) {
                throw std::invalid_argument("The extra functions in a Taylor decomposition contain the variable '{}', "
                                            "which is not a state variable"_format(var));
            }
        }
    }

    // Cache the number of equations/variables.
    const auto n_eq = sys.size();
    assert(n_eq == lhs_vars.size());

    // Create the map for renaming the variables to u_i.
    // The renaming will be done following the order of the lhs
    // variables.
    std::unordered_map<std::string, std::string> repl_map;
    for (decltype(lhs_vars.size()) i = 0; i < lhs_vars.size(); ++i) {
        [[maybe_unused]] const auto eres = repl_map.emplace(lhs_vars[i], "u_{}"_format(i));
        assert(eres.second);
    }

#if !defined(NDEBUG)

    // Store a copy of the original rhs and sv_funcs
    // for checking later.
    std::vector<expression> orig_rhs;
    orig_rhs.reserve(sys.size());
    for (const auto &[_, rhs_ex] : sys) {
        orig_rhs.push_back(copy(rhs_ex));
    }

    auto orig_sv_funcs = detail::copy(sv_funcs);

#endif

    // Rename the variables in the original equations.
    // NOTE: only the rhs expressions are renamed.
    for (auto &[_, rhs_ex] : sys) {
        rename_variables(rhs_ex, repl_map);
    }

    // Rename the variables in sv_funcs.
    for (auto &ex : sv_funcs) {
        rename_variables(ex, repl_map);
    }

    // Log the construction runtime in trace mode.
    spdlog::stopwatch sw;

    // Init the decomposition. It begins with a list
    // of the original lhs variables of the system.
    taylor_dc_t u_vars_defs;
    u_vars_defs.reserve(lhs_vars.size());
    for (const auto &var : lhs_vars) {
        u_vars_defs.emplace_back(variable{var}, std::vector<std::uint32_t>{});
    }

    // Run the decomposition on each equation.
    for (auto &[_, ex] : sys) {
        // Decompose the current equation.
        if (const auto dres = taylor_decompose(ex, u_vars_defs)) {
            // NOTE: if the equation was decomposed
            // (that is, it is not constant or a single variable),
            // we have to update the original definition
            // of the equation in sys
            // so that it points to the u variable
            // that now represents it.
            // NOTE: all functions are forced to return dres != 0
            // in the func API, so the only entities that
            // can return dres == 0 are const/params or
            // variables.
            ex = expression{"u_{}"_format(dres)};
        }
    }

    // Decompose sv_funcs, and write into sv_funcs_dc the index
    // of the u variable which each sv_func corresponds to.
    std::vector<std::uint32_t> sv_funcs_dc;
    for (auto &sv_ex : sv_funcs) {
        if (auto *const var_ptr = std::get_if<variable>(&sv_ex.value())) {
            // The current sv_func is a variable, add its index to sv_funcs_dc.
            sv_funcs_dc.push_back(detail::uname_to_index(var_ptr->name()));
        } else if (const auto dres = taylor_decompose(sv_ex, u_vars_defs)) {
            // The sv_func was decomposed, add to sv_funcs_dc
            // the index of the u variable which represents
            // the result of the decomposition.
            sv_funcs_dc.push_back(boost::numeric_cast<std::uint32_t>(dres));
        } else {
            // The sv_func was not decomposed, meaning it's a const/param.
            throw std::invalid_argument(
                "The extra functions in a Taylor decomposition cannot be constants or parameters");
        }
    }
    assert(sv_funcs_dc.size() == sv_funcs.size());

    // Append the (possibly updated) definitions of the diff equations
    // in terms of u variables.
    // NOTE: the rhs expressions in sys are moved from here, sys is not
    // used any more after this loop.
    for (auto &[_, rhs] : sys) {
        u_vars_defs.emplace_back(std::move(rhs), std::vector<std::uint32_t>{});
    }

    detail::get_logger()->trace("Taylor decomposition construction runtime: {}", sw);

#if !defined(NDEBUG)
    // Verify the decomposition.
    detail::verify_taylor_dec(orig_rhs, u_vars_defs);
    detail::verify_taylor_dec_sv_funcs(sv_funcs_dc, orig_sv_funcs, u_vars_defs, n_eq);
#endif

    // Simplify the decomposition.
    u_vars_defs = detail::taylor_decompose_cse(u_vars_defs, sv_funcs_dc, n_eq);

#if !defined(NDEBUG)
    // Verify the simplified decomposition.
    detail::verify_taylor_dec(orig_rhs, u_vars_defs);
    detail::verify_taylor_dec_sv_funcs(sv_funcs_dc, orig_sv_funcs, u_vars_defs, n_eq);
#endif

    // Run the breadth-first topological sort on the decomposition.
    u_vars_defs = detail::taylor_sort_dc(u_vars_defs, sv_funcs_dc, n_eq);

#if !defined(NDEBUG)
    // Verify the reordered decomposition.
    detail::verify_taylor_dec(orig_rhs, u_vars_defs);
    detail::verify_taylor_dec_sv_funcs(sv_funcs_dc, orig_sv_funcs, u_vars_defs, n_eq);
#endif

    return std::make_pair(std::move(u_vars_defs), std::move(sv_funcs_dc));
}
1155 
1156 namespace detail
1157 {
1158 
1159 namespace
1160 {
1161 
1162 // Implementation of the streaming operator for the scalar integrators.
1163 template <typename T>
taylor_adaptive_stream_impl(std::ostream & os,const taylor_adaptive_impl<T> & ta)1164 std::ostream &taylor_adaptive_stream_impl(std::ostream &os, const taylor_adaptive_impl<T> &ta)
1165 {
1166     std::ostringstream oss;
1167     oss.exceptions(std::ios_base::failbit | std::ios_base::badbit);
1168     oss.imbue(std::locale::classic());
1169     oss << std::showpoint;
1170     oss.precision(std::numeric_limits<T>::max_digits10);
1171     oss << std::boolalpha;
1172 
1173     oss << "Tolerance               : " << ta.get_tol() << '\n';
1174     oss << "High accuracy           : " << ta.get_high_accuracy() << '\n';
1175     oss << "Compact mode            : " << ta.get_compact_mode() << '\n';
1176     oss << "Taylor order            : " << ta.get_order() << '\n';
1177     oss << "Dimension               : " << ta.get_dim() << '\n';
1178     oss << "Time                    : " << ta.get_time() << '\n';
1179     oss << "State                   : [";
1180     for (decltype(ta.get_state().size()) i = 0; i < ta.get_state().size(); ++i) {
1181         oss << ta.get_state()[i];
1182         if (i != ta.get_state().size() - 1u) {
1183             oss << ", ";
1184         }
1185     }
1186     oss << "]\n";
1187 
1188     if (!ta.get_pars().empty()) {
1189         oss << "Parameters              : [";
1190         for (decltype(ta.get_pars().size()) i = 0; i < ta.get_pars().size(); ++i) {
1191             oss << ta.get_pars()[i];
1192             if (i != ta.get_pars().size() - 1u) {
1193                 oss << ", ";
1194             }
1195         }
1196         oss << "]\n";
1197     }
1198 
1199     if (ta.with_events()) {
1200         if (!ta.get_t_events().empty()) {
1201             oss << "N of terminal events    : " << ta.get_t_events().size() << '\n';
1202         }
1203 
1204         if (!ta.get_nt_events().empty()) {
1205             oss << "N of non-terminal events: " << ta.get_nt_events().size() << '\n';
1206         }
1207     }
1208 
1209     return os << oss.str();
1210 }
1211 
1212 // Implementation of the streaming operator for the batch integrators.
1213 template <typename T>
taylor_adaptive_batch_stream_impl(std::ostream & os,const taylor_adaptive_batch_impl<T> & ta)1214 std::ostream &taylor_adaptive_batch_stream_impl(std::ostream &os, const taylor_adaptive_batch_impl<T> &ta)
1215 {
1216     std::ostringstream oss;
1217     oss.exceptions(std::ios_base::failbit | std::ios_base::badbit);
1218     oss.imbue(std::locale::classic());
1219     oss << std::showpoint;
1220     oss.precision(std::numeric_limits<T>::max_digits10);
1221     oss << std::boolalpha;
1222 
1223     oss << "Tolerance               : " << ta.get_tol() << '\n';
1224     oss << "High accuracy           : " << ta.get_high_accuracy() << '\n';
1225     oss << "Compact mode            : " << ta.get_compact_mode() << '\n';
1226     oss << "Taylor order            : " << ta.get_order() << '\n';
1227     oss << "Dimension               : " << ta.get_dim() << '\n';
1228     oss << "Batch size              : " << ta.get_batch_size() << '\n';
1229     oss << "Time                    : [";
1230     for (decltype(ta.get_time().size()) i = 0; i < ta.get_time().size(); ++i) {
1231         oss << ta.get_time()[i];
1232         if (i != ta.get_time().size() - 1u) {
1233             oss << ", ";
1234         }
1235     }
1236     oss << "]\n";
1237     oss << "State                   : [";
1238     for (decltype(ta.get_state().size()) i = 0; i < ta.get_state().size(); ++i) {
1239         oss << ta.get_state()[i];
1240         if (i != ta.get_state().size() - 1u) {
1241             oss << ", ";
1242         }
1243     }
1244     oss << "]\n";
1245 
1246     if (!ta.get_pars().empty()) {
1247         oss << "Parameters              : [";
1248         for (decltype(ta.get_pars().size()) i = 0; i < ta.get_pars().size(); ++i) {
1249             oss << ta.get_pars()[i];
1250             if (i != ta.get_pars().size() - 1u) {
1251                 oss << ", ";
1252             }
1253         }
1254         oss << "]\n";
1255     }
1256 
1257     if (ta.with_events()) {
1258         if (!ta.get_t_events().empty()) {
1259             oss << "N of terminal events    : " << ta.get_t_events().size() << '\n';
1260         }
1261 
1262         if (!ta.get_nt_events().empty()) {
1263             oss << "N of non-terminal events: " << ta.get_nt_events().size() << '\n';
1264         }
1265     }
1266 
1267     return os << oss.str();
1268 }
1269 
1270 } // namespace
1271 
// Explicit specialisation of the stream insertion operator for the
// scalar integrator in double precision: forwards to
// taylor_adaptive_stream_impl().
template <>
std::ostream &operator<<(std::ostream &os, const taylor_adaptive_impl<double> &ta)
{
    return taylor_adaptive_stream_impl(os, ta);
}
1277 
// Explicit specialisation of the stream insertion operator for the
// scalar integrator in long double precision: forwards to
// taylor_adaptive_stream_impl().
template <>
std::ostream &operator<<(std::ostream &os, const taylor_adaptive_impl<long double> &ta)
{
    return taylor_adaptive_stream_impl(os, ta);
}
1283 
1284 #if defined(HEYOKA_HAVE_REAL128)
1285 
// Explicit specialisation of the stream insertion operator for the
// scalar integrator using mppp::real128: forwards to
// taylor_adaptive_stream_impl().
template <>
std::ostream &operator<<(std::ostream &os, const taylor_adaptive_impl<mppp::real128> &ta)
{
    return taylor_adaptive_stream_impl(os, ta);
}
1291 
1292 #endif
1293 
// Explicit specialisation of the stream insertion operator for the
// batch integrator in double precision: forwards to
// taylor_adaptive_batch_stream_impl().
template <>
std::ostream &operator<<(std::ostream &os, const taylor_adaptive_batch_impl<double> &ta)
{
    return taylor_adaptive_batch_stream_impl(os, ta);
}
1299 
// Explicit specialisation of the stream insertion operator for the
// batch integrator in long double precision: forwards to
// taylor_adaptive_batch_stream_impl().
template <>
std::ostream &operator<<(std::ostream &os, const taylor_adaptive_batch_impl<long double> &ta)
{
    return taylor_adaptive_batch_stream_impl(os, ta);
}
1305 
1306 #if defined(HEYOKA_HAVE_REAL128)
1307 
// Explicit specialisation of the stream insertion operator for the
// batch integrator using mppp::real128: forwards to
// taylor_adaptive_batch_stream_impl().
template <>
std::ostream &operator<<(std::ostream &os, const taylor_adaptive_batch_impl<mppp::real128> &ta)
{
    return taylor_adaptive_batch_stream_impl(os, ta);
}
1313 
1314 #endif
1315 
1316 } // namespace detail
1317 
// Helper macro expanding to a switch case for the enumerator val:
// the stringified form of val is inserted into a stream named os
// (which must be in scope at the point of use), then the switch is exited.
// NOTE: the trailing semicolon after the break is
// supplied at the point of use.
#define HEYOKA_TAYLOR_ENUM_STREAM_CASE(val)                                                                            \
    case val:                                                                                                          \
        os << #val;                                                                                                    \
        break
1322 
operator <<(std::ostream & os,taylor_outcome oc)1323 std::ostream &operator<<(std::ostream &os, taylor_outcome oc)
1324 {
1325     switch (oc) {
1326         HEYOKA_TAYLOR_ENUM_STREAM_CASE(taylor_outcome::success);
1327         HEYOKA_TAYLOR_ENUM_STREAM_CASE(taylor_outcome::step_limit);
1328         HEYOKA_TAYLOR_ENUM_STREAM_CASE(taylor_outcome::time_limit);
1329         HEYOKA_TAYLOR_ENUM_STREAM_CASE(taylor_outcome::err_nf_state);
1330         HEYOKA_TAYLOR_ENUM_STREAM_CASE(taylor_outcome::cb_stop);
1331         default:
1332             if (oc >= taylor_outcome{0}) {
1333                 // Continuing terminal event.
1334                 os << "taylor_outcome::terminal_event_{} (continuing)"_format(static_cast<std::int64_t>(oc));
1335             } else if (oc > taylor_outcome::success) {
1336                 // Stopping terminal event.
1337                 os << "taylor_outcome::terminal_event_{} (stopping)"_format(-static_cast<std::int64_t>(oc) - 1);
1338             } else {
1339                 // Unknown value.
1340                 os << "taylor_outcome::??";
1341             }
1342     }
1343 
1344     return os;
1345 }
1346 
operator <<(std::ostream & os,event_direction dir)1347 std::ostream &operator<<(std::ostream &os, event_direction dir)
1348 {
1349     switch (dir) {
1350         HEYOKA_TAYLOR_ENUM_STREAM_CASE(event_direction::any);
1351         HEYOKA_TAYLOR_ENUM_STREAM_CASE(event_direction::positive);
1352         HEYOKA_TAYLOR_ENUM_STREAM_CASE(event_direction::negative);
1353         default:
1354             // Unknown value.
1355             os << "event_direction::??";
1356     }
1357 
1358     return os;
1359 }
1360 
// NOTE: undefine the helper macro used by the enum streaming operators,
// so that it does not leak into the rest of the translation unit. The name
// must match the one used in the #define (HEYOKA_TAYLOR_ENUM_STREAM_CASE).
#undef HEYOKA_TAYLOR_ENUM_STREAM_CASE
1362 
1363 namespace detail
1364 {
1365 
1366 namespace
1367 {
1368 
1369 // Helper to create the callback used in the default
1370 // constructor of a non-terminal event.
1371 template <typename T, bool B>
nt_event_def_cb()1372 auto nt_event_def_cb()
1373 {
1374     if constexpr (B) {
1375         return [](taylor_adaptive_batch_impl<T> &, T, int, std::uint32_t) {};
1376     } else {
1377         return [](taylor_adaptive_impl<T> &, T, int) {};
1378     }
1379 }
1380 
1381 } // namespace
1382 
// Default constructor: delegates to the constructor taking an expression
// and a callback, passing a default-constructed expression and the
// empty-bodied callback returned by nt_event_def_cb().
template <typename T, bool B>
nt_event_impl<T, B>::nt_event_impl() : nt_event_impl(expression{}, nt_event_def_cb<T, B>())
{
}
1387 
1388 template <typename T, bool B>
finalise_ctor(event_direction d)1389 void nt_event_impl<T, B>::finalise_ctor(event_direction d)
1390 {
1391     if (!callback) {
1392         throw std::invalid_argument("Cannot construct a non-terminal event with an empty callback");
1393     }
1394 
1395     if (d < event_direction::negative || d > event_direction::positive) {
1396         throw std::invalid_argument("Invalid value selected for the direction of a non-terminal event");
1397     }
1398     dir = d;
1399 }
1400 
// Copy constructor.
// NOTE: the event equation is duplicated via copy() rather than via
// the copy constructor of expression; callback and dir are copied
// as usual.
template <typename T, bool B>
nt_event_impl<T, B>::nt_event_impl(const nt_event_impl &o) : eq(copy(o.eq)), callback(o.callback), dir(o.dir)
{
}
1405 
// Defaulted (noexcept) move constructor.
template <typename T, bool B>
nt_event_impl<T, B>::nt_event_impl(nt_event_impl &&) noexcept = default;
1408 
1409 template <typename T, bool B>
operator =(const nt_event_impl & o)1410 nt_event_impl<T, B> &nt_event_impl<T, B>::operator=(const nt_event_impl &o)
1411 {
1412     if (this != &o) {
1413         *this = nt_event_impl(o);
1414     }
1415 
1416     return *this;
1417 }
1418 
// Defaulted (noexcept) move assignment operator.
template <typename T, bool B>
nt_event_impl<T, B> &nt_event_impl<T, B>::operator=(nt_event_impl &&) noexcept = default;
1421 
// Defaulted destructor.
template <typename T, bool B>
nt_event_impl<T, B>::~nt_event_impl() = default;
1424 
// Return a const reference to the event equation.
template <typename T, bool B>
const expression &nt_event_impl<T, B>::get_expression() const
{
    return eq;
}
1430 
// Return a const reference to the callback of the event.
template <typename T, bool B>
const typename nt_event_impl<T, B>::callback_t &nt_event_impl<T, B>::get_callback() const
{
    return callback;
}
1436 
// Return the direction of the event.
template <typename T, bool B>
event_direction nt_event_impl<T, B>::get_direction() const
{
    return dir;
}
1442 
// Default constructor: delegates to the constructor invoked with
// a single default-constructed expression.
template <typename T, bool B>
t_event_impl<T, B>::t_event_impl() : t_event_impl(expression{})
{
}
1447 
1448 template <typename T, bool B>
finalise_ctor(callback_t cb,T cd,event_direction d)1449 void t_event_impl<T, B>::finalise_ctor(callback_t cb, T cd, event_direction d)
1450 {
1451     using std::isfinite;
1452 
1453     callback = std::move(cb);
1454 
1455     if (!isfinite(cd)) {
1456         throw std::invalid_argument("Cannot set a non-finite cooldown value for a terminal event");
1457     }
1458     cooldown = cd;
1459 
1460     if (d < event_direction::negative || d > event_direction::positive) {
1461         throw std::invalid_argument("Invalid value selected for the direction of a terminal event");
1462     }
1463     dir = d;
1464 }
1465 
// Copy constructor.
// NOTE: the event equation is duplicated via copy() rather than via
// the copy constructor of expression; callback, cooldown and dir
// are copied as usual.
template <typename T, bool B>
t_event_impl<T, B>::t_event_impl(const t_event_impl &o)
    : eq(copy(o.eq)), callback(o.callback), cooldown(o.cooldown), dir(o.dir)
{
}
1471 
// Defaulted (noexcept) move constructor.
template <typename T, bool B>
t_event_impl<T, B>::t_event_impl(t_event_impl &&) noexcept = default;
1474 
1475 template <typename T, bool B>
operator =(const t_event_impl & o)1476 t_event_impl<T, B> &t_event_impl<T, B>::operator=(const t_event_impl &o)
1477 {
1478     if (this != &o) {
1479         *this = t_event_impl(o);
1480     }
1481 
1482     return *this;
1483 }
1484 
// Defaulted (noexcept) move assignment operator.
template <typename T, bool B>
t_event_impl<T, B> &t_event_impl<T, B>::operator=(t_event_impl &&) noexcept = default;
1487 
// Defaulted destructor.
template <typename T, bool B>
t_event_impl<T, B>::~t_event_impl() = default;
1490 
// Return a const reference to the event equation.
template <typename T, bool B>
const expression &t_event_impl<T, B>::get_expression() const
{
    return eq;
}
1496 
// Return a const reference to the callback of the event.
template <typename T, bool B>
const typename t_event_impl<T, B>::callback_t &t_event_impl<T, B>::get_callback() const
{
    return callback;
}
1502 
// Return the direction of the event.
template <typename T, bool B>
event_direction t_event_impl<T, B>::get_direction() const
{
    return dir;
}
1508 
// Return the cooldown value of the event.
template <typename T, bool B>
T t_event_impl<T, B>::get_cooldown() const
{
    return cooldown;
}
1514 
1515 namespace
1516 {
1517 
1518 // Implementation of stream insertion for the non-terminal event class.
nt_event_impl_stream_impl(std::ostream & os,const expression & eq,event_direction dir)1519 std::ostream &nt_event_impl_stream_impl(std::ostream &os, const expression &eq, event_direction dir)
1520 {
1521     os << "Event type     : non-terminal\n";
1522     os << "Event equation : " << eq << '\n';
1523     os << "Event direction: " << dir << '\n';
1524 
1525     return os;
1526 }
1527 
1528 // Implementation of stream insertion for the terminal event class.
1529 template <typename C, typename T>
t_event_impl_stream_impl(std::ostream & os,const expression & eq,event_direction dir,const C & callback,const T & cooldown)1530 std::ostream &t_event_impl_stream_impl(std::ostream &os, const expression &eq, event_direction dir, const C &callback,
1531                                        const T &cooldown)
1532 {
1533     os << "Event type     : terminal\n";
1534     os << "Event equation : " << eq << '\n';
1535     os << "Event direction: " << dir << '\n';
1536     os << "With callback  : " << (callback ? "yes" : "no") << '\n';
1537     os << "Cooldown       : " << (cooldown < 0 ? "auto" : "{}"_format(cooldown)) << '\n';
1538 
1539     return os;
1540 }
1541 
1542 } // namespace
1543 
// Explicit specialisation of the stream insertion operator for the
// non-terminal event class (double, B == false): forwards to
// nt_event_impl_stream_impl().
template <>
std::ostream &operator<<(std::ostream &os, const nt_event_impl<double, false> &e)
{
    return nt_event_impl_stream_impl(os, e.get_expression(), e.get_direction());
}
1549 
1550 template <>
operator <<(std::ostream & os,const nt_event_impl<double,true> & e)1551 std::ostream &operator<<(std::ostream &os, const nt_event_impl<double, true> &e)
1552 {
1553     return nt_event_impl_stream_impl(os, e.get_expression(), e.get_direction());
1554 }
1555 
1556 template <>
operator <<(std::ostream & os,const nt_event_impl<long double,false> & e)1557 std::ostream &operator<<(std::ostream &os, const nt_event_impl<long double, false> &e)
1558 {
1559     return nt_event_impl_stream_impl(os, e.get_expression(), e.get_direction());
1560 }
1561 
1562 template <>
operator <<(std::ostream & os,const nt_event_impl<long double,true> & e)1563 std::ostream &operator<<(std::ostream &os, const nt_event_impl<long double, true> &e)
1564 {
1565     return nt_event_impl_stream_impl(os, e.get_expression(), e.get_direction());
1566 }
1567 
1568 #if defined(HEYOKA_HAVE_REAL128)
1569 
1570 template <>
operator <<(std::ostream & os,const nt_event_impl<mppp::real128,false> & e)1571 std::ostream &operator<<(std::ostream &os, const nt_event_impl<mppp::real128, false> &e)
1572 {
1573     return nt_event_impl_stream_impl(os, e.get_expression(), e.get_direction());
1574 }
1575 
1576 template <>
operator <<(std::ostream & os,const nt_event_impl<mppp::real128,true> & e)1577 std::ostream &operator<<(std::ostream &os, const nt_event_impl<mppp::real128, true> &e)
1578 {
1579     return nt_event_impl_stream_impl(os, e.get_expression(), e.get_direction());
1580 }
1581 
1582 #endif
1583 
1584 template <>
operator <<(std::ostream & os,const t_event_impl<double,false> & e)1585 std::ostream &operator<<(std::ostream &os, const t_event_impl<double, false> &e)
1586 {
1587     return t_event_impl_stream_impl(os, e.get_expression(), e.get_direction(), e.get_callback(), e.get_cooldown());
1588 }
1589 
1590 template <>
operator <<(std::ostream & os,const t_event_impl<double,true> & e)1591 std::ostream &operator<<(std::ostream &os, const t_event_impl<double, true> &e)
1592 {
1593     return t_event_impl_stream_impl(os, e.get_expression(), e.get_direction(), e.get_callback(), e.get_cooldown());
1594 }
1595 
1596 template <>
operator <<(std::ostream & os,const t_event_impl<long double,false> & e)1597 std::ostream &operator<<(std::ostream &os, const t_event_impl<long double, false> &e)
1598 {
1599     return t_event_impl_stream_impl(os, e.get_expression(), e.get_direction(), e.get_callback(), e.get_cooldown());
1600 }
1601 
1602 template <>
operator <<(std::ostream & os,const t_event_impl<long double,true> & e)1603 std::ostream &operator<<(std::ostream &os, const t_event_impl<long double, true> &e)
1604 {
1605     return t_event_impl_stream_impl(os, e.get_expression(), e.get_direction(), e.get_callback(), e.get_cooldown());
1606 }
1607 
1608 #if defined(HEYOKA_HAVE_REAL128)
1609 
1610 template <>
operator <<(std::ostream & os,const t_event_impl<mppp::real128,false> & e)1611 std::ostream &operator<<(std::ostream &os, const t_event_impl<mppp::real128, false> &e)
1612 {
1613     return t_event_impl_stream_impl(os, e.get_expression(), e.get_direction(), e.get_callback(), e.get_cooldown());
1614 }
1615 
1616 template <>
operator <<(std::ostream & os,const t_event_impl<mppp::real128,true> & e)1617 std::ostream &operator<<(std::ostream &os, const t_event_impl<mppp::real128, true> &e)
1618 {
1619     return t_event_impl_stream_impl(os, e.get_expression(), e.get_direction(), e.get_callback(), e.get_cooldown());
1620 }
1621 
1622 #endif
1623 
1624 // Explicit instantiation of the implementation classes/functions.
1625 template class nt_event_impl<double, false>;
1626 template class t_event_impl<double, false>;
1627 
1628 template class nt_event_impl<double, true>;
1629 template class t_event_impl<double, true>;
1630 
1631 template class nt_event_impl<long double, false>;
1632 template class t_event_impl<long double, false>;
1633 
1634 template class nt_event_impl<long double, true>;
1635 template class t_event_impl<long double, true>;
1636 
1637 #if defined(HEYOKA_HAVE_REAL128)
1638 
1639 template class nt_event_impl<mppp::real128, false>;
1640 template class t_event_impl<mppp::real128, false>;
1641 
1642 template class nt_event_impl<mppp::real128, true>;
1643 template class t_event_impl<mppp::real128, true>;
1644 
1645 #endif
1646 
1647 // Add a function for computing the dense output
1648 // via polynomial evaluation.
1649 template <typename T>
taylor_add_d_out_function(llvm_state & s,std::uint32_t n_eq,std::uint32_t order,std::uint32_t batch_size,bool high_accuracy,bool external_linkage,bool optimise)1650 void taylor_add_d_out_function(llvm_state &s, std::uint32_t n_eq, std::uint32_t order, std::uint32_t batch_size,
1651                                bool high_accuracy, bool external_linkage, bool optimise)
1652 {
1653     assert(n_eq > 0u);
1654     assert(order > 0u);
1655     assert(batch_size > 0u);
1656 
1657     auto &builder = s.builder();
1658     auto &context = s.context();
1659 
1660     // The function arguments:
1661     // - the output pointer (read/write, used also for accumulation),
1662     // - the pointer to the Taylor coefficients (read-only),
1663     // - the pointer to the h values (read-only).
1664     // No overlap is allowed.
1665     std::vector<llvm::Type *> fargs(3, llvm::PointerType::getUnqual(to_llvm_type<T>(context)));
1666     // The function does not return anything.
1667     auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false);
1668     assert(ft != nullptr);
1669     // Now create the function.
1670     auto *f = llvm::Function::Create(
1671         ft, external_linkage ? llvm::Function::ExternalLinkage : llvm::Function::InternalLinkage, "d_out_f",
1672         &s.module());
1673     // LCOV_EXCL_START
1674     if (f == nullptr) {
1675         throw std::invalid_argument(
1676             "Unable to create a function for the dense output in an adaptive Taylor integrator");
1677     }
1678     // LCOV_EXCL_STOP
1679 
1680     // Set the names/attributes of the function arguments.
1681     auto *out_ptr = f->args().begin();
1682     out_ptr->setName("out_ptr");
1683     out_ptr->addAttr(llvm::Attribute::NoCapture);
1684     out_ptr->addAttr(llvm::Attribute::NoAlias);
1685 
1686     auto *tc_ptr = f->args().begin() + 1;
1687     tc_ptr->setName("tc_ptr");
1688     tc_ptr->addAttr(llvm::Attribute::NoCapture);
1689     tc_ptr->addAttr(llvm::Attribute::NoAlias);
1690     tc_ptr->addAttr(llvm::Attribute::ReadOnly);
1691 
1692     auto *h_ptr = f->args().begin() + 2;
1693     h_ptr->setName("h_ptr");
1694     h_ptr->addAttr(llvm::Attribute::NoCapture);
1695     h_ptr->addAttr(llvm::Attribute::NoAlias);
1696     h_ptr->addAttr(llvm::Attribute::ReadOnly);
1697 
1698     // Create a new basic block to start insertion into.
1699     auto *bb = llvm::BasicBlock::Create(context, "entry", f);
1700     assert(bb != nullptr);
1701     builder.SetInsertPoint(bb);
1702 
1703     // Load the value of h.
1704     auto *h = load_vector_from_memory(builder, h_ptr, batch_size);
1705 
1706     if (high_accuracy) {
1707         // Create the array for storing the running compensations.
1708         auto array_type = llvm::ArrayType::get(make_vector_type(to_llvm_type<T>(context), batch_size), n_eq);
1709         auto comp_arr
1710             = builder.CreateInBoundsGEP(builder.CreateAlloca(array_type), {builder.getInt32(0), builder.getInt32(0)});
1711 
1712         // Start by writing into out_ptr the zero-order coefficients
1713         // and by filling with zeroes the running compensations.
1714         llvm_loop_u32(s, builder.getInt32(0), builder.getInt32(n_eq), [&](llvm::Value *cur_var_idx) {
1715             // Load the coefficient from tc_ptr. The index is:
1716             // batch_size * (order + 1u) * cur_var_idx.
1717             auto *tc_idx = builder.CreateMul(builder.getInt32(batch_size * (order + 1u)), cur_var_idx);
1718             auto *tc = load_vector_from_memory(builder, builder.CreateInBoundsGEP(tc_ptr, {tc_idx}), batch_size);
1719 
1720             // Store it in out_ptr. The index is:
1721             // batch_size * cur_var_idx.
1722             auto *out_idx = builder.CreateMul(builder.getInt32(batch_size), cur_var_idx);
1723             store_vector_to_memory(builder, builder.CreateInBoundsGEP(out_ptr, {out_idx}), tc);
1724 
1725             // Zero-init the element in comp_arr.
1726             builder.CreateStore(vector_splat(builder, codegen<T>(s, number{0.}), batch_size),
1727                                 builder.CreateInBoundsGEP(comp_arr, {cur_var_idx}));
1728         });
1729 
1730         // Init the running updater for the powers of h.
1731         auto *cur_h = builder.CreateAlloca(h->getType());
1732         builder.CreateStore(h, cur_h);
1733 
1734         // Run the evaluation.
1735         llvm_loop_u32(s, builder.getInt32(1), builder.getInt32(order + 1u), [&](llvm::Value *cur_order) {
1736             // Load the current power of h.
1737             auto *cur_h_val = builder.CreateLoad(cur_h);
1738 
1739             llvm_loop_u32(s, builder.getInt32(0), builder.getInt32(n_eq), [&](llvm::Value *cur_var_idx) {
1740                 // Load the coefficient from tc_ptr. The index is:
1741                 // batch_size * (order + 1u) * cur_var_idx + batch_size * cur_order.
1742                 auto *tc_idx
1743                     = builder.CreateAdd(builder.CreateMul(builder.getInt32(batch_size * (order + 1u)), cur_var_idx),
1744                                         builder.CreateMul(builder.getInt32(batch_size), cur_order));
1745                 auto *cf = load_vector_from_memory(builder, builder.CreateInBoundsGEP(tc_ptr, {tc_idx}), batch_size);
1746                 auto *tmp = builder.CreateFMul(cf, cur_h_val);
1747 
1748                 // Compute the quantities for the compensation.
1749                 auto *comp_ptr = builder.CreateInBoundsGEP(comp_arr, {cur_var_idx});
1750                 auto *out_idx = builder.CreateMul(builder.getInt32(batch_size), cur_var_idx);
1751                 auto *res_ptr = builder.CreateInBoundsGEP(out_ptr, {out_idx});
1752                 auto *y = builder.CreateFSub(tmp, builder.CreateLoad(comp_ptr));
1753                 auto *cur_res = load_vector_from_memory(builder, res_ptr, batch_size);
1754                 auto *t = builder.CreateFAdd(cur_res, y);
1755 
1756                 // Update the compensation and the return value.
1757                 builder.CreateStore(builder.CreateFSub(builder.CreateFSub(t, cur_res), y), comp_ptr);
1758                 store_vector_to_memory(builder, res_ptr, t);
1759             });
1760 
1761             // Update the value of h.
1762             builder.CreateStore(builder.CreateFMul(cur_h_val, h), cur_h);
1763         });
1764     } else {
1765         // Start by writing into out_ptr the coefficients of the highest-degree
1766         // monomial in each polynomial.
1767         llvm_loop_u32(s, builder.getInt32(0), builder.getInt32(n_eq), [&](llvm::Value *cur_var_idx) {
1768             // Load the coefficient from tc_ptr. The index is:
1769             // batch_size * (order + 1u) * cur_var_idx + batch_size * order.
1770             auto *tc_idx
1771                 = builder.CreateAdd(builder.CreateMul(builder.getInt32(batch_size * (order + 1u)), cur_var_idx),
1772                                     builder.getInt32(batch_size * order));
1773             auto *tc = load_vector_from_memory(builder, builder.CreateInBoundsGEP(tc_ptr, {tc_idx}), batch_size);
1774 
1775             // Store it in out_ptr. The index is:
1776             // batch_size * cur_var_idx.
1777             auto *out_idx = builder.CreateMul(builder.getInt32(batch_size), cur_var_idx);
1778             store_vector_to_memory(builder, builder.CreateInBoundsGEP(out_ptr, {out_idx}), tc);
1779         });
1780 
1781         // Now let's run the Horner scheme.
1782         llvm_loop_u32(
1783             s, builder.getInt32(1), builder.CreateAdd(builder.getInt32(order), builder.getInt32(1)),
1784             [&](llvm::Value *cur_order) {
1785                 llvm_loop_u32(s, builder.getInt32(0), builder.getInt32(n_eq), [&](llvm::Value *cur_var_idx) {
1786                     // Load the current Taylor coefficient from tc_ptr.
1787                     // NOTE: we are loading the coefficients backwards wrt the order, hence
1788                     // we specify order - cur_order.
1789                     // NOTE: the index is:
1790                     // batch_size * (order + 1u) * cur_var_idx + batch_size * (order - cur_order).
1791                     auto *tc_idx
1792                         = builder.CreateAdd(builder.CreateMul(builder.getInt32(batch_size * (order + 1u)), cur_var_idx),
1793                                             builder.CreateMul(builder.getInt32(batch_size),
1794                                                               builder.CreateSub(builder.getInt32(order), cur_order)));
1795                     auto *tc
1796                         = load_vector_from_memory(builder, builder.CreateInBoundsGEP(tc_ptr, {tc_idx}), batch_size);
1797 
1798                     // Accumulate in out_ptr. The index is:
1799                     // batch_size * cur_var_idx.
1800                     auto *out_idx = builder.CreateMul(builder.getInt32(batch_size), cur_var_idx);
1801                     auto *out_p = builder.CreateInBoundsGEP(out_ptr, {out_idx});
1802                     auto *cur_out = load_vector_from_memory(builder, out_p, batch_size);
1803                     store_vector_to_memory(builder, out_p, builder.CreateFAdd(tc, builder.CreateFMul(cur_out, h)));
1804                 });
1805             });
1806     }
1807 
1808     // Create the return value.
1809     builder.CreateRetVoid();
1810 
1811     // Verify the function.
1812     s.verify_function(f);
1813 
1814     // Run the optimisation pass, if requested.
1815     if (optimise) {
1816         s.optimise();
1817     }
1818 }
1819 
1820 template void taylor_add_d_out_function<double>(llvm_state &, std::uint32_t, std::uint32_t, std::uint32_t, bool, bool,
1821                                                 bool);
1822 template void taylor_add_d_out_function<long double>(llvm_state &, std::uint32_t, std::uint32_t, std::uint32_t, bool,
1823                                                      bool, bool);
1824 
1825 #if defined(HEYOKA_HAVE_REAL128)
1826 
1827 template void taylor_add_d_out_function<mppp::real128>(llvm_state &, std::uint32_t, std::uint32_t, std::uint32_t, bool,
1828                                                        bool, bool);
1829 
1830 #endif
1831 
1832 } // namespace detail
1833 
1834 template <typename T>
add_c_out_function(std::uint32_t order,std::uint32_t dim,bool high_accuracy)1835 void continuous_output<T>::add_c_out_function(std::uint32_t order, std::uint32_t dim, bool high_accuracy)
1836 {
1837     // Overflow check: we want to be able to index into the arrays of
1838     // times and Taylor coefficients using 32-bit ints.
1839     // LCOV_EXCL_START
1840     if (m_tcs.size() > std::numeric_limits<std::uint32_t>::max()
1841         || m_times_hi.size() > std::numeric_limits<std::uint32_t>::max()) {
1842         throw std::overflow_error("Overflow detected while adding continuous output to a Taylor integrator");
1843     }
1844     // LCOV_EXCL_STOP
1845 
1846     auto &md = m_llvm_state.module();
1847     auto &builder = m_llvm_state.builder();
1848     auto &context = m_llvm_state.context();
1849 
1850     // Fetch the current insertion block.
1851     auto orig_bb = builder.GetInsertBlock();
1852 
1853     // Add the function for the computation of the dense output.
1854     detail::taylor_add_d_out_function<T>(m_llvm_state, dim, order, 1, high_accuracy, false, false);
1855 
1856     // Fetch it.
1857     auto d_out_f = md.getFunction("d_out_f");
1858     assert(d_out_f != nullptr); // LCOV_EXCL_LINE
1859 
1860     // Restore the original insertion block.
1861     builder.SetInsertPoint(orig_bb);
1862 
1863     // Establish the time direction.
1864     const detail::dfloat<T> df_t_start(m_times_hi[0], m_times_lo[0]), df_t_end(m_times_hi.back(), m_times_lo.back());
1865     const auto dir = df_t_start < df_t_end;
1866 
1867     // The function arguments:
1868     // - the output pointer (read/write, used also for accumulation),
1869     // - the time value,
1870     // - the pointer to the Taylor coefficients (read-only),
1871     // - the pointer to the hi times (read-only),
1872     // - the pointer to the lo times (read-only).
1873     // No overlap is allowed.
1874     auto fp_t = detail::to_llvm_type<T>(context);
1875     auto ptr_t = llvm::PointerType::getUnqual(fp_t);
1876     std::vector<llvm::Type *> fargs{ptr_t, fp_t, ptr_t, ptr_t, ptr_t};
1877     // The function does not return anything.
1878     auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false);
1879     assert(ft != nullptr); // LCOV_EXCL_LINE
1880     // Now create the function.
1881     auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "c_out", &md);
1882     // LCOV_EXCL_START
1883     if (f == nullptr) {
1884         throw std::invalid_argument("Unable to create a function for continuous output in a Taylor integrator");
1885     }
1886     // LCOV_EXCL_STOP
1887 
1888     // Set the names/attributes of the function arguments.
1889     auto out_ptr = f->args().begin();
1890     out_ptr->setName("out_ptr");
1891     out_ptr->addAttr(llvm::Attribute::NoCapture);
1892     out_ptr->addAttr(llvm::Attribute::NoAlias);
1893 
1894     auto tm = f->args().begin() + 1;
1895     tm->setName("tm");
1896 
1897     auto tc_ptr = f->args().begin() + 2;
1898     tc_ptr->setName("tc_ptr");
1899     tc_ptr->addAttr(llvm::Attribute::NoCapture);
1900     tc_ptr->addAttr(llvm::Attribute::NoAlias);
1901     tc_ptr->addAttr(llvm::Attribute::ReadOnly);
1902 
1903     auto times_ptr_hi = f->args().begin() + 3;
1904     times_ptr_hi->setName("times_ptr_hi");
1905     times_ptr_hi->addAttr(llvm::Attribute::NoCapture);
1906     times_ptr_hi->addAttr(llvm::Attribute::NoAlias);
1907     times_ptr_hi->addAttr(llvm::Attribute::ReadOnly);
1908 
1909     auto times_ptr_lo = f->args().begin() + 4;
1910     times_ptr_lo->setName("times_ptr_lo");
1911     times_ptr_lo->addAttr(llvm::Attribute::NoCapture);
1912     times_ptr_lo->addAttr(llvm::Attribute::NoAlias);
1913     times_ptr_lo->addAttr(llvm::Attribute::ReadOnly);
1914 
1915     // Create a new basic block to start insertion into.
1916     auto *bb = llvm::BasicBlock::Create(context, "entry", f);
1917     assert(bb != nullptr); // LCOV_EXCL_LINE
1918     builder.SetInsertPoint(bb);
1919 
1920     // Create the variable in which we will store the timestep size.
1921     // This is necessary because the d_out_f function requires a pointer
1922     // to the timestep.
1923     auto h_ptr = builder.CreateAlloca(fp_t);
1924 
1925     // Look for the index in the times vector corresponding to
1926     // a time greater than tm (less than tm in backwards integration).
1927     // This is essentially an implementation of std::upper_bound:
1928     // https://en.cppreference.com/w/cpp/algorithm/upper_bound
1929     auto tidx = builder.CreateAlloca(builder.getInt32Ty());
1930     auto count = builder.CreateAlloca(builder.getInt32Ty());
1931     auto step = builder.CreateAlloca(builder.getInt32Ty());
1932     auto first = builder.CreateAlloca(builder.getInt32Ty());
1933 
1934     // count is inited with the size of the range.
1935     builder.CreateStore(builder.getInt32(static_cast<std::uint32_t>(m_times_hi.size())), count);
1936     // first is inited to zero.
1937     builder.CreateStore(builder.getInt32(0), first);
1938 
1939     detail::llvm_while_loop(
1940         m_llvm_state, [&]() { return builder.CreateICmpNE(builder.CreateLoad(count), builder.getInt32(0)); },
1941         [&]() {
1942             // tidx = first.
1943             builder.CreateStore(builder.CreateLoad(first), tidx);
1944             // step = count / 2.
1945             builder.CreateStore(builder.CreateUDiv(builder.CreateLoad(count), builder.getInt32(2)), step);
1946             // tidx = tidx + step.
1947             builder.CreateStore(builder.CreateAdd(builder.CreateLoad(tidx), builder.CreateLoad(step)), tidx);
1948 
1949             // Logical condition:
1950             // - !(tm < *tidx), if integrating forward,
1951             // - !(tm > *tidx), if integrating backward.
1952             auto tidx_val_hi = builder.CreateLoad(builder.CreateInBoundsGEP(times_ptr_hi, {builder.CreateLoad(tidx)}));
1953             auto tidx_val_lo = builder.CreateLoad(builder.CreateInBoundsGEP(times_ptr_lo, {builder.CreateLoad(tidx)}));
1954             auto zero_val = llvm::ConstantFP::get(fp_t, 0.);
1955             auto cond = dir ? detail::llvm_dl_lt(m_llvm_state, tm, zero_val, tidx_val_hi, tidx_val_lo)
1956                             : detail::llvm_dl_gt(m_llvm_state, tm, zero_val, tidx_val_hi, tidx_val_lo);
1957             cond = builder.CreateNot(cond);
1958 
1959             detail::llvm_if_then_else(
1960                 m_llvm_state, cond,
1961                 [&]() {
1962                     // ++tidx.
1963                     builder.CreateStore(builder.CreateAdd(builder.CreateLoad(tidx), builder.getInt32(1)), tidx);
1964                     // first = tidx.
1965                     builder.CreateStore(builder.CreateLoad(tidx), first);
1966                     // count = count - (step + 1).
1967                     builder.CreateStore(
1968                         builder.CreateSub(builder.CreateLoad(count),
1969                                           builder.CreateAdd(builder.CreateLoad(step), builder.getInt32(1))),
1970                         count);
1971                 },
1972                 [&]() {
1973                     // count = step.
1974                     builder.CreateStore(builder.CreateLoad(step), count);
1975                 });
1976         });
1977 
1978     // NOTE: the output of the std::upper_bound algorithm
1979     // is in the 'first' variable.
1980     llvm::Value *tc_idx = builder.CreateLoad(first);
1981 
1982     // Normally, the TC index should be first - 1. The exceptions are:
1983     // - first == 0, in which case TC index is also 0,
1984     // - first == range size, in which case TC index is first - 2.
1985     // These two exceptions arise when tm is outside the range of validity
1986     // for the continuous output. In such cases, we will use either the first
1987     // or the last possible set of TCs.
1988     detail::llvm_if_then_else(
1989         m_llvm_state, builder.CreateICmpEQ(tc_idx, builder.getInt32(0)),
1990         [&]() {
1991             // first == 0, do nothing.
1992         },
1993         [&]() {
1994             detail::llvm_if_then_else(
1995                 m_llvm_state,
1996                 builder.CreateICmpEQ(tc_idx, builder.getInt32(static_cast<std::uint32_t>(m_times_hi.size()))),
1997                 [&]() {
1998                     // first == range size.
1999                     builder.CreateStore(builder.CreateSub(tc_idx, builder.getInt32(2)), first);
2000                 },
2001                 [&]() {
2002                     // The normal path.
2003                     builder.CreateStore(builder.CreateSub(tc_idx, builder.getInt32(1)), first);
2004                 });
2005         });
2006 
2007     // Reload tc_idx.
2008     tc_idx = builder.CreateLoad(first);
2009 
2010     // Load the time corresponding to tc_idx.
2011     auto start_tm_hi = builder.CreateLoad(builder.CreateInBoundsGEP(times_ptr_hi, {tc_idx}));
2012     auto start_tm_lo = builder.CreateLoad(builder.CreateInBoundsGEP(times_ptr_lo, {tc_idx}));
2013 
2014     // Compute and store the value of h = tm - start_tm.
2015     auto [h_hi, h_lo] = detail::llvm_dl_add(m_llvm_state, tm, llvm::ConstantFP::get(fp_t, 0.),
2016                                             builder.CreateFNeg(start_tm_hi), builder.CreateFNeg(start_tm_lo));
2017     builder.CreateStore(h_hi, h_ptr);
2018 
2019     // Compute the index into the Taylor coefficients array.
2020     tc_idx = builder.CreateMul(tc_idx, builder.getInt32(dim * (order + 1u)));
2021 
2022     // Invoke the d_out function.
2023     builder.CreateCall(d_out_f, {out_ptr, builder.CreateInBoundsGEP(tc_ptr, {tc_idx}), h_ptr});
2024 
2025     // Create the return value.
2026     builder.CreateRetVoid();
2027 
2028     // Verify the function.
2029     m_llvm_state.verify_function(f);
2030 
2031     // Run the optimisation pass.
2032     m_llvm_state.optimise();
2033 
2034     // Compile.
2035     m_llvm_state.compile();
2036 
2037     // Fetch the function pointer and assign it.
2038     m_f_ptr = reinterpret_cast<fptr_t>(m_llvm_state.jit_lookup("c_out"));
2039 }
2040 
2041 template <typename T>
2042 continuous_output<T>::continuous_output() = default;
2043 
2044 template <typename T>
continuous_output(llvm_state && s)2045 continuous_output<T>::continuous_output(llvm_state &&s) : m_llvm_state(std::move(s))
2046 {
2047 }
2048 
2049 template <typename T>
continuous_output(const continuous_output & o)2050 continuous_output<T>::continuous_output(const continuous_output &o)
2051     : m_llvm_state(o.m_llvm_state), m_tcs(o.m_tcs), m_times_hi(o.m_times_hi), m_times_lo(o.m_times_lo),
2052       m_output(o.m_output)
2053 {
2054     // If o is valid, fetch the function pointer from the copied state.
2055     // Otherwise, m_f_ptr will remain null.
2056     if (o.m_f_ptr != nullptr) {
2057         m_f_ptr = reinterpret_cast<fptr_t>(m_llvm_state.jit_lookup("c_out"));
2058     }
2059 }
2060 
2061 template <typename T>
2062 continuous_output<T>::continuous_output(continuous_output &&) noexcept = default;
2063 
2064 template <typename T>
2065 continuous_output<T>::~continuous_output() = default;
2066 
2067 template <typename T>
operator =(const continuous_output & o)2068 continuous_output<T> &continuous_output<T>::operator=(const continuous_output &o)
2069 {
2070     if (this != &o) {
2071         *this = continuous_output(o);
2072     }
2073 
2074     return *this;
2075 }
2076 
2077 template <typename T>
2078 continuous_output<T> &continuous_output<T>::operator=(continuous_output &&) noexcept = default;
2079 
2080 template <typename T>
call_impl(T t)2081 void continuous_output<T>::call_impl(T t)
2082 {
2083     using std::isfinite;
2084 
2085     if (m_f_ptr == nullptr) {
2086         throw std::invalid_argument("Cannot use a default-constructed continuous_output object");
2087     }
2088 
2089     // NOTE: run the assertions only after ensuring this
2090     // is a valid object.
2091 
2092     // LCOV_EXCL_START
2093 #if !defined(NDEBUG)
2094     // m_output must not be empty.
2095     assert(!m_output.empty());
2096     // Need at least 2 time points.
2097     assert(m_times_hi.size() >= 2u);
2098     // hi/lo parts of times must have the same sizes.
2099     assert(m_times_hi.size() == m_times_lo.size());
2100 #endif
2101     // LCOV_EXCL_STOP
2102 
2103     if (!isfinite(t)) {
2104         throw std::invalid_argument("Cannot compute the continuous output at the non-finite time {}"_format(t));
2105     }
2106 
2107     m_f_ptr(m_output.data(), t, m_tcs.data(), m_times_hi.data(), m_times_lo.data());
2108 }
2109 
2110 template <typename T>
get_llvm_state() const2111 const llvm_state &continuous_output<T>::get_llvm_state() const
2112 {
2113     return m_llvm_state;
2114 }
2115 
2116 template <typename T>
save(boost::archive::binary_oarchive & ar,unsigned) const2117 void continuous_output<T>::save(boost::archive::binary_oarchive &ar, unsigned) const
2118 {
2119     ar << m_llvm_state;
2120     ar << m_tcs;
2121     ar << m_times_hi;
2122     ar << m_times_lo;
2123     ar << m_output;
2124 }
2125 
2126 template <typename T>
load(boost::archive::binary_iarchive & ar,unsigned)2127 void continuous_output<T>::load(boost::archive::binary_iarchive &ar, unsigned)
2128 {
2129     ar >> m_llvm_state;
2130     ar >> m_tcs;
2131     ar >> m_times_hi;
2132     ar >> m_times_lo;
2133     ar >> m_output;
2134 
2135     // NOTE: if m_output is not empty, it means the archived
2136     // object had been initialised.
2137     if (m_output.empty()) {
2138         m_f_ptr = nullptr;
2139     } else {
2140         m_f_ptr = reinterpret_cast<fptr_t>(m_llvm_state.jit_lookup("c_out"));
2141     }
2142 }
2143 
2144 template <typename T>
get_bounds() const2145 std::pair<T, T> continuous_output<T>::get_bounds() const
2146 {
2147     if (m_f_ptr == nullptr) {
2148         throw std::invalid_argument("Cannot use a default-constructed continuous_output object");
2149     }
2150 
2151     return {m_times_hi[0], m_times_hi.back()};
2152 }
2153 
2154 template <typename T>
get_n_steps() const2155 std::size_t continuous_output<T>::get_n_steps() const
2156 {
2157     if (m_f_ptr == nullptr) {
2158         throw std::invalid_argument("Cannot use a default-constructed continuous_output object");
2159     }
2160 
2161     return boost::numeric_cast<std::size_t>(m_times_hi.size() - 1u);
2162 }
2163 
2164 template class continuous_output<double>;
2165 template class continuous_output<long double>;
2166 
2167 #if defined(HEYOKA_HAVE_REAL128)
2168 
2169 template class continuous_output<mppp::real128>;
2170 
2171 #endif
2172 
2173 namespace detail
2174 {
2175 
2176 template <typename T>
c_out_stream_impl(std::ostream & os,const continuous_output<T> & co)2177 std::ostream &c_out_stream_impl(std::ostream &os, const continuous_output<T> &co)
2178 {
2179     std::ostringstream oss;
2180     oss.exceptions(std::ios_base::failbit | std::ios_base::badbit);
2181     oss.imbue(std::locale::classic());
2182     oss << std::showpoint;
2183     oss.precision(std::numeric_limits<T>::max_digits10);
2184 
2185     if (co.get_output().empty()) {
2186         oss << "Default-constructed continuous_output";
2187     } else {
2188         const detail::dfloat<T> df_t_start(co.m_times_hi[0], co.m_times_lo[0]),
2189             df_t_end(co.m_times_hi.back(), co.m_times_lo.back());
2190         const auto dir = df_t_start < df_t_end;
2191 
2192         oss << "Direction : " << (dir ? "forward" : "backward") << '\n';
2193         oss << "Time range: "
2194             << (dir ? "[{}, {})"_format(co.m_times_hi[0], co.m_times_hi.back())
2195                     : "({}, {}]"_format(co.m_times_hi.back(), co.m_times_hi[0]))
2196             << '\n';
2197         oss << "N of steps: " << (co.m_times_hi.size() - 1u) << '\n';
2198     }
2199 
2200     return os << oss.str();
2201 }
2202 
2203 } // namespace detail
2204 
2205 template <>
operator <<(std::ostream & os,const continuous_output<double> & co)2206 std::ostream &operator<<(std::ostream &os, const continuous_output<double> &co)
2207 {
2208     return detail::c_out_stream_impl(os, co);
2209 }
2210 
2211 template <>
operator <<(std::ostream & os,const continuous_output<long double> & co)2212 std::ostream &operator<<(std::ostream &os, const continuous_output<long double> &co)
2213 {
2214     return detail::c_out_stream_impl(os, co);
2215 }
2216 
2217 #if defined(HEYOKA_HAVE_REAL128)
2218 
2219 template <>
operator <<(std::ostream & os,const continuous_output<mppp::real128> & co)2220 std::ostream &operator<<(std::ostream &os, const continuous_output<mppp::real128> &co)
2221 {
2222     return detail::c_out_stream_impl(os, co);
2223 }
2224 
2225 #endif
2226 
2227 #if !defined(NDEBUG)
2228 
2229 extern "C" {
2230 
2231 // Function to check, in debug mode, the indexing of the Taylor coefficients
2232 // in the batch mode continuous output implementation.
heyoka_continuous_output_batch_tc_idx_debug(const std::uint32_t * tc_idx,std::uint32_t times_size,std::uint32_t batch_size)2233 HEYOKA_DLL_PUBLIC void heyoka_continuous_output_batch_tc_idx_debug(const std::uint32_t *tc_idx,
2234                                                                    std::uint32_t times_size,
2235                                                                    std::uint32_t batch_size) noexcept
2236 {
2237     // LCOV_EXCL_START
2238     assert(batch_size != 0u);
2239     assert(times_size % batch_size == 0u);
2240     assert(times_size / batch_size >= 3u);
2241     // LCOV_EXCL_STOP
2242 
2243     for (std::uint32_t i = 0; i < batch_size; ++i) {
2244         assert(tc_idx[i] < times_size / batch_size - 2u); // LCOV_EXCL_LINE
2245     }
2246 }
2247 }
2248 
2249 #endif
2250 
// Continuous output for the batch integrator.
//
// Build, JIT-compile and look up the LLVM function "c_out", which evaluates
// the continuous output for all the batch elements at once. The generated function:
// - runs a vectorised binary search over the stored times in order to select,
//   for each batch element, the block of Taylor coefficients to be used
//   for the target time,
// - computes the offset h of the target time from the time at which the
//   selected block starts, using the hi/lo representation of the times,
// - evaluates the Taylor polynomials of all the state variables at h and
//   writes the result into the output buffer.
//
// order is the Taylor order (order + 1 coefficients are read for each polynomial),
// dim is the number of polynomials evaluated for each batch element. If high_accuracy
// is true, the polynomials are evaluated by summing up the monomials with a
// running compensation term (Kahan-style), otherwise Horner's scheme is used.
//
// On success, m_f_ptr points to the compiled function.
//
// Throws std::overflow_error if the sizes of m_tcs or m_times_hi do not fit
// in a 32-bit unsigned integer, and std::invalid_argument if the LLVM function
// cannot be created.
template <typename T>
void continuous_output_batch<T>::add_c_out_function(std::uint32_t order, std::uint32_t dim, bool high_accuracy)
{
    // Overflow check: we want to be able to index into the arrays of
    // times and Taylor coefficients using 32-bit ints.
    // LCOV_EXCL_START
    if (m_tcs.size() > std::numeric_limits<std::uint32_t>::max()
        || m_times_hi.size() > std::numeric_limits<std::uint32_t>::max()) {
        throw std::overflow_error(
            "Overflow detected while adding continuous output to a Taylor integrator in batch mode");
    }
    // LCOV_EXCL_STOP

    auto &md = m_llvm_state.module();
    auto &builder = m_llvm_state.builder();
    auto &context = m_llvm_state.context();

    // The function arguments:
    // - the output pointer (read/write, used also for accumulation),
    // - the pointer to the target time values (read-only),
    // - the pointer to the Taylor coefficients (read-only),
    // - the pointer to the hi times (read-only),
    // - the pointer to the lo times (read-only).
    // No overlap is allowed.
    auto fp_t = detail::to_llvm_type<T>(context);
    auto fp_vec_t = detail::make_vector_type(fp_t, m_batch_size);
    auto ptr_t = llvm::PointerType::getUnqual(fp_t);
    std::vector<llvm::Type *> fargs{ptr_t, ptr_t, ptr_t, ptr_t, ptr_t};
    // The function does not return anything.
    auto *ft = llvm::FunctionType::get(builder.getVoidTy(), fargs, false);
    assert(ft != nullptr); // LCOV_EXCL_LINE
    // Now create the function.
    auto *f = llvm::Function::Create(ft, llvm::Function::ExternalLinkage, "c_out", &md);
    // LCOV_EXCL_START
    if (f == nullptr) {
        throw std::invalid_argument("Unable to create a function for continuous output in a Taylor integrator");
    }
    // LCOV_EXCL_STOP

    // Set the names/attributes of the function arguments.
    // NOTE: all the pointers are flagged as non-captured and non-aliasing,
    // and all of them except the output pointer are flagged as read-only.
    auto out_ptr = f->args().begin();
    out_ptr->setName("out_ptr");
    out_ptr->addAttr(llvm::Attribute::NoCapture);
    out_ptr->addAttr(llvm::Attribute::NoAlias);

    auto tm_ptr = f->args().begin() + 1;
    tm_ptr->setName("tm_ptr");
    tm_ptr->addAttr(llvm::Attribute::NoCapture);
    tm_ptr->addAttr(llvm::Attribute::NoAlias);
    tm_ptr->addAttr(llvm::Attribute::ReadOnly);

    auto tc_ptr = f->args().begin() + 2;
    tc_ptr->setName("tc_ptr");
    tc_ptr->addAttr(llvm::Attribute::NoCapture);
    tc_ptr->addAttr(llvm::Attribute::NoAlias);
    tc_ptr->addAttr(llvm::Attribute::ReadOnly);

    auto times_ptr_hi = f->args().begin() + 3;
    times_ptr_hi->setName("times_ptr_hi");
    times_ptr_hi->addAttr(llvm::Attribute::NoCapture);
    times_ptr_hi->addAttr(llvm::Attribute::NoAlias);
    times_ptr_hi->addAttr(llvm::Attribute::ReadOnly);

    auto times_ptr_lo = f->args().begin() + 4;
    times_ptr_lo->setName("times_ptr_lo");
    times_ptr_lo->addAttr(llvm::Attribute::NoCapture);
    times_ptr_lo->addAttr(llvm::Attribute::NoAlias);
    times_ptr_lo->addAttr(llvm::Attribute::ReadOnly);

    // Create a new basic block to start insertion into.
    auto *bb = llvm::BasicBlock::Create(context, "entry", f);
    assert(bb != nullptr); // LCOV_EXCL_LINE
    builder.SetInsertPoint(bb);

    // Establish the time directions.
    // NOTE: the directions are computed here, on the C++ side, from the times
    // currently stored in this object, and they end up in the generated
    // code as constants (true for forward, false for backward).
    auto bool_vector_t = detail::make_vector_type(builder.getInt1Ty(), m_batch_size);
    assert(bool_vector_t != nullptr); // LCOV_EXCL_LINE
    llvm::Value *dir_vec{};
    if (m_batch_size == 1u) {
        // In scalar mode, the direction is a single value.
        const detail::dfloat<T> df_t_start(m_times_hi[0], m_times_lo[0]),
            // NOTE: we load from the padding values here.
            df_t_end(m_times_hi.back(), m_times_lo.back());
        const auto dir = df_t_start < df_t_end;

        dir_vec = builder.getInt1(dir);
    } else {
        dir_vec = llvm::UndefValue::get(bool_vector_t);
        for (std::uint32_t i = 0; i < m_batch_size; ++i) {
            const detail::dfloat<T> df_t_start(m_times_hi[i], m_times_lo[i]),
                // NOTE: we load from the padding values here.
                df_t_end(m_times_hi[m_times_hi.size() - m_batch_size + i],
                         m_times_lo[m_times_lo.size() - m_batch_size + i]);
            const auto dir = df_t_start < df_t_end;

            dir_vec = builder.CreateInsertElement(dir_vec, builder.getInt1(dir), i);
        }
    }

    // Look for the index in the times vector corresponding to
    // a time greater than tm (less than tm in backwards integration).
    // This is essentially an implementation of std::upper_bound:
    // https://en.cppreference.com/w/cpp/algorithm/upper_bound
    // NOTE: the state of the search (tidx, count, step, first) is kept
    // in stack slots, one 32-bit integer per batch element.
    auto int32_vec_t = detail::make_vector_type(builder.getInt32Ty(), m_batch_size);
    auto tidx = builder.CreateAlloca(int32_vec_t);
    auto count = builder.CreateAlloca(int32_vec_t);
    auto step = builder.CreateAlloca(int32_vec_t);
    auto first = builder.CreateAlloca(int32_vec_t);

    // count is inited with the size of the range.
    // NOTE: count includes the padding.
    builder.CreateStore(
        detail::vector_splat(builder, builder.getInt32(static_cast<std::uint32_t>(m_times_hi.size()) / m_batch_size),
                             m_batch_size),
        count);

    // first is inited to zero.
    auto zero_vec_i32 = detail::vector_splat(builder, builder.getInt32(0), m_batch_size);
    builder.CreateStore(zero_vec_i32, first);

    // Load the time value from tm_ptr.
    auto tm = detail::load_vector_from_memory(builder, tm_ptr, m_batch_size);

    // This is the vector [0, 1, 2, ..., (batch_size - 1)].
    llvm::Value *batch_offset{};
    if (m_batch_size == 1u) {
        // In scalar mode, use a single value.
        batch_offset = builder.getInt32(0);
    } else {
        batch_offset = llvm::UndefValue::get(int32_vec_t);
        for (std::uint32_t i = 0; i < m_batch_size; ++i) {
            batch_offset = builder.CreateInsertElement(batch_offset, builder.getInt32(i), i);
        }
    }

    // Splatted version of the batch size.
    auto batch_splat = detail::vector_splat(builder, builder.getInt32(m_batch_size), m_batch_size);

    // Splatted versions of the base pointers for the time data.
    auto times_ptr_hi_vec = detail::vector_splat(builder, times_ptr_hi, m_batch_size);
    auto times_ptr_lo_vec = detail::vector_splat(builder, times_ptr_lo, m_batch_size);

    // fp vector of zeroes.
    auto zero_vec_fp = detail::vector_splat(builder, llvm::ConstantFP::get(fp_t, 0.), m_batch_size);

    // Vector of i32 ones.
    auto one_vec_i32 = detail::vector_splat(builder, builder.getInt32(1), m_batch_size);

    // The binary search loop. The first lambda generates the loop condition,
    // the second one the loop body.
    detail::llvm_while_loop(
        m_llvm_state,
        [&]() -> llvm::Value * {
            // NOTE: the condition here is that any value in count is not zero.
            auto cmp = builder.CreateICmpNE(builder.CreateLoad(count), zero_vec_i32);

            // NOTE: in scalar mode, no reduction is needed.
            return (m_batch_size == 1u) ? cmp : builder.CreateOrReduce(cmp);
        },
        [&]() {
            // tidx = first.
            builder.CreateStore(builder.CreateLoad(first), tidx);
            // step = count / 2.
            auto two_vec_i32 = detail::vector_splat(builder, builder.getInt32(2), m_batch_size);
            builder.CreateStore(builder.CreateUDiv(builder.CreateLoad(count), two_vec_i32), step);
            // tidx = tidx + step.
            builder.CreateStore(builder.CreateAdd(builder.CreateLoad(tidx), builder.CreateLoad(step)), tidx);

            // Compute the indices for loading the times from the pointers.
            // NOTE: the index for the batch element j is tidx[j] * batch_size + j.
            auto tl_idx = builder.CreateAdd(builder.CreateMul(builder.CreateLoad(tidx), batch_splat), batch_offset);

            // Compute the pointers for loading the time data.
            auto tptr_hi = builder.CreateInBoundsGEP(times_ptr_hi_vec, {tl_idx});
            auto tptr_lo = builder.CreateInBoundsGEP(times_ptr_lo_vec, {tl_idx});

            // Gather the hi/lo values.
            auto tidx_val_hi = detail::gather_vector_from_memory(builder, fp_vec_t, tptr_hi, alignof(T));
            auto tidx_val_lo = detail::gather_vector_from_memory(builder, fp_vec_t, tptr_lo, alignof(T));

            // Compute the two conditions !(tm < *tidx) and !(tm > *tidx).
            // NOTE: tm is passed with a lo part of zero.
            auto cmp_lt
                = builder.CreateNot(detail::llvm_dl_lt(m_llvm_state, tm, zero_vec_fp, tidx_val_hi, tidx_val_lo));
            auto cmp_gt
                = builder.CreateNot(detail::llvm_dl_gt(m_llvm_state, tm, zero_vec_fp, tidx_val_hi, tidx_val_lo));

            // Select cmp_lt if integrating forward, cmp_gt when integrating backward.
            auto cond = builder.CreateSelect(dir_vec, cmp_lt, cmp_gt);

            // tidx += (1 or 0).
            builder.CreateStore(
                builder.CreateAdd(builder.CreateLoad(tidx), builder.CreateSelect(cond, one_vec_i32, zero_vec_i32)),
                tidx);

            // first = (tidx or first).
            builder.CreateStore(builder.CreateSelect(cond, builder.CreateLoad(tidx), builder.CreateLoad(first)), first);

            // count = count - (step or count).
            auto old_count = builder.CreateLoad(count);
            auto new_count
                = builder.CreateSub(old_count, builder.CreateSelect(cond, builder.CreateLoad(step), old_count));

            // count = count + (-1 or step).
            // NOTE: the net effect of the two updates is
            // count = count - step - 1 if cond is true, count = step otherwise.
            new_count = builder.CreateAdd(
                new_count, builder.CreateSelect(cond, builder.CreateNeg(one_vec_i32), builder.CreateLoad(step)));
            builder.CreateStore(new_count, count);
        });

    // NOTE: the output of the std::upper_bound algorithm
    // is in the 'first' variable.
    llvm::Value *tc_idx = builder.CreateLoad(first);

    // Normally, the TC index should be first - 1. The exceptions are:
    // - first == 0, in which case TC index is also 0,
    // - first == (range size - 1), in which case TC index is first - 2.
    // These two exceptions arise when tm is outside the range of validity
    // for the continuous output. In such cases, we will use either the first
    // or the last possible set of TCs.
    // NOTE: the second check is range size - 1 (rather than just range size
    // like in the scalar case) due to padding.
    // In order to vectorise the check, we compute:
    // tc_idx = tc_idx - (tc_idx != 0) - (tc_idx == range size - 1).
    auto tc_idx_cmp1 = builder.CreateZExt(builder.CreateICmpNE(tc_idx, zero_vec_i32), int32_vec_t);
    auto tc_idx_cmp2 = builder.CreateZExt(
        builder.CreateICmpEQ(
            tc_idx, detail::vector_splat(
                        builder, builder.getInt32(static_cast<std::uint32_t>(m_times_hi.size() / m_batch_size - 1u)),
                        m_batch_size)),
        int32_vec_t);
    tc_idx = builder.CreateSub(tc_idx, tc_idx_cmp1);
    tc_idx = builder.CreateSub(tc_idx, tc_idx_cmp2);

#if !defined(NDEBUG)

    // In debug mode, invoke the index checking function.
    // NOTE: the indices are first spilled to a stack array so that
    // they can be passed by pointer to the external function.
    auto tc_idx_debug_ptr
        = builder.CreateInBoundsGEP(builder.CreateAlloca(llvm::ArrayType::get(builder.getInt32Ty(), m_batch_size)),
                                    {builder.getInt32(0), builder.getInt32(0)});
    detail::store_vector_to_memory(builder, tc_idx_debug_ptr, tc_idx);
    detail::llvm_invoke_external(m_llvm_state, "heyoka_continuous_output_batch_tc_idx_debug", builder.getVoidTy(),
                                 {tc_idx_debug_ptr, builder.getInt32(static_cast<std::uint32_t>(m_times_hi.size())),
                                  builder.getInt32(m_batch_size)});

#endif

    // Convert tc_idx into an index for loading from the time vectors.
    auto tc_l_idx = builder.CreateAdd(builder.CreateMul(tc_idx, batch_splat), batch_offset);

    // Load the times corresponding to tc_idx.
    auto start_tm_hi = detail::gather_vector_from_memory(
        builder, fp_vec_t, builder.CreateInBoundsGEP(times_ptr_hi_vec, {tc_l_idx}), alignof(T));
    auto start_tm_lo = detail::gather_vector_from_memory(
        builder, fp_vec_t, builder.CreateInBoundsGEP(times_ptr_lo_vec, {tc_l_idx}), alignof(T));

    // Compute the value of h = tm - start_tm.
    // NOTE: only the first element of the result of the hi/lo addition is kept.
    auto h = detail::llvm_dl_add(m_llvm_state, tm, zero_vec_fp, builder.CreateFNeg(start_tm_hi),
                                 builder.CreateFNeg(start_tm_lo))
                 .first;

    // Compute the base pointers in the array of TC for the computation
    // of Horner's scheme.
    // NOTE: the size of the data block of a single timestep is
    // dim * (order + 1) * batch_size.
    tc_idx = builder.CreateAdd(
        builder.CreateMul(
            tc_idx, detail::vector_splat(builder, builder.getInt32(dim * (order + 1u) * m_batch_size), m_batch_size)),
        batch_offset);
    // NOTE: each pointer in tc_ptrs points to the Taylor coefficient of
    // order 0 for the first state variable in the timestep data block selected
    // for that batch index.
    auto tc_ptrs = builder.CreateInBoundsGEP(tc_ptr, {tc_idx});

    // Run the Horner scheme.
    if (high_accuracy) {
        // NOTE: in high accuracy mode the monomials cf * h**cur_order are
        // summed up in increasing order, with a running compensation term
        // for each state variable (Kahan-style summation).

        // Create the array for storing the running compensations.
        auto array_type = llvm::ArrayType::get(fp_vec_t, dim);
        auto comp_arr
            = builder.CreateInBoundsGEP(builder.CreateAlloca(array_type), {builder.getInt32(0), builder.getInt32(0)});

        // Start by writing into out_ptr the zero-order coefficients
        // and by filling with zeroes the running compensations.
        detail::llvm_loop_u32(m_llvm_state, builder.getInt32(0), builder.getInt32(dim), [&](llvm::Value *cur_var_idx) {
            // Load the coefficient from tc_ptrs. The index is:
            // m_batch_size * (order + 1u) * cur_var_idx.
            auto *load_idx = builder.CreateMul(builder.getInt32(m_batch_size * (order + 1u)), cur_var_idx);
            auto *tcs = detail::gather_vector_from_memory(builder, fp_vec_t,
                                                          builder.CreateInBoundsGEP(tc_ptrs, {load_idx}), alignof(T));

            // Store it in out_ptr. The index is:
            // m_batch_size * cur_var_idx.
            auto *out_idx = builder.CreateMul(builder.getInt32(m_batch_size), cur_var_idx);
            detail::store_vector_to_memory(builder, builder.CreateInBoundsGEP(out_ptr, {out_idx}), tcs);

            // Zero-init the element in comp_arr.
            builder.CreateStore(zero_vec_fp, builder.CreateInBoundsGEP(comp_arr, {cur_var_idx}));
        });

        // Init the running updater for the powers of h.
        auto *cur_h = builder.CreateAlloca(h->getType());
        builder.CreateStore(h, cur_h);

        // Run the evaluation.
        detail::llvm_loop_u32(
            m_llvm_state, builder.getInt32(1), builder.getInt32(order + 1u), [&](llvm::Value *cur_order) {
                // Load the current power of h.
                auto *cur_h_val = builder.CreateLoad(cur_h);

                detail::llvm_loop_u32(
                    m_llvm_state, builder.getInt32(0), builder.getInt32(dim), [&](llvm::Value *cur_var_idx) {
                        // Load the coefficient from tc_ptrs. The index is:
                        // m_batch_size * (order + 1u) * cur_var_idx + m_batch_size * cur_order.
                        auto *load_idx = builder.CreateAdd(
                            builder.CreateMul(builder.getInt32(m_batch_size * (order + 1u)), cur_var_idx),
                            builder.CreateMul(builder.getInt32(m_batch_size), cur_order));
                        auto *cf = detail::gather_vector_from_memory(
                            builder, fp_vec_t, builder.CreateInBoundsGEP(tc_ptrs, {load_idx}), alignof(T));
                        auto *tmp = builder.CreateFMul(cf, cur_h_val);

                        // Compute the quantities for the compensation.
                        // NOTE: y is the new term minus the running compensation,
                        // t is the updated (uncompensated) sum.
                        auto *comp_ptr = builder.CreateInBoundsGEP(comp_arr, {cur_var_idx});
                        auto *out_idx = builder.CreateMul(builder.getInt32(m_batch_size), cur_var_idx);
                        auto *res_ptr = builder.CreateInBoundsGEP(out_ptr, {out_idx});
                        auto *y = builder.CreateFSub(tmp, builder.CreateLoad(comp_ptr));
                        auto *cur_res = detail::load_vector_from_memory(builder, res_ptr, m_batch_size);
                        auto *t = builder.CreateFAdd(cur_res, y);

                        // Update the compensation and the return value.
                        // NOTE: the new compensation is (t - cur_res) - y.
                        builder.CreateStore(builder.CreateFSub(builder.CreateFSub(t, cur_res), y), comp_ptr);
                        detail::store_vector_to_memory(builder, res_ptr, t);
                    });

                // Update the value of h.
                builder.CreateStore(builder.CreateFMul(cur_h_val, h), cur_h);
            });
    } else {
        // Start by writing into out_ptr the coefficients of the highest-degree
        // monomial in each polynomial.
        detail::llvm_loop_u32(m_llvm_state, builder.getInt32(0), builder.getInt32(dim), [&](llvm::Value *cur_var_idx) {
            // Load the coefficient from tc_ptrs. The index is:
            // m_batch_size * (order + 1u) * cur_var_idx + m_batch_size * order.
            auto *load_idx
                = builder.CreateAdd(builder.CreateMul(builder.getInt32(m_batch_size * (order + 1u)), cur_var_idx),
                                    builder.getInt32(m_batch_size * order));
            auto *tcs = detail::gather_vector_from_memory(builder, fp_vec_t,
                                                          builder.CreateInBoundsGEP(tc_ptrs, {load_idx}), alignof(T));

            // Store it in out_ptr. The index is:
            // m_batch_size * cur_var_idx.
            auto *out_idx = builder.CreateMul(builder.getInt32(m_batch_size), cur_var_idx);
            detail::store_vector_to_memory(builder, builder.CreateInBoundsGEP(out_ptr, {out_idx}), tcs);
        });

        // Now let's run the Horner scheme.
        detail::llvm_loop_u32(
            m_llvm_state, builder.getInt32(1), builder.CreateAdd(builder.getInt32(order), builder.getInt32(1)),
            [&](llvm::Value *cur_order) {
                detail::llvm_loop_u32(
                    m_llvm_state, builder.getInt32(0), builder.getInt32(dim), [&](llvm::Value *cur_var_idx) {
                        // Load the current Taylor coefficients from tc_ptrs.
                        // NOTE: we are loading the coefficients backwards wrt the order, hence
                        // we specify order - cur_order.
                        // NOTE: the index is:
                        // m_batch_size * (order + 1u) * cur_var_idx + m_batch_size * (order - cur_order).
                        auto *load_idx = builder.CreateAdd(
                            builder.CreateMul(builder.getInt32(m_batch_size * (order + 1u)), cur_var_idx),
                            builder.CreateMul(builder.getInt32(m_batch_size),
                                              builder.CreateSub(builder.getInt32(order), cur_order)));
                        auto *tcs = detail::gather_vector_from_memory(
                            builder, fp_vec_t, builder.CreateInBoundsGEP(tc_ptrs, {load_idx}), alignof(T));

                        // Accumulate in out_ptr. The index is:
                        // m_batch_size * cur_var_idx.
                        // NOTE: the update is out = tcs + out * h.
                        auto *out_idx = builder.CreateMul(builder.getInt32(m_batch_size), cur_var_idx);
                        auto *out_p = builder.CreateInBoundsGEP(out_ptr, {out_idx});
                        auto *cur_out = detail::load_vector_from_memory(builder, out_p, m_batch_size);
                        detail::store_vector_to_memory(builder, out_p,
                                                       builder.CreateFAdd(tcs, builder.CreateFMul(cur_out, h)));
                    });
            });
    }

    // Create the return value.
    builder.CreateRetVoid();

    // Verify the function.
    m_llvm_state.verify_function(f);

    // Run the optimisation pass.
    m_llvm_state.optimise();

    // Compile.
    m_llvm_state.compile();

    // Fetch the function pointer and assign it.
    m_f_ptr = reinterpret_cast<fptr_t>(m_llvm_state.jit_lookup("c_out"));
}
2643 
// Default constructor (defaulted).
// NOTE: several member functions in this file throw an error stating that
// a default-constructed object cannot be used when m_f_ptr is null.
template <typename T>
continuous_output_batch<T>::continuous_output_batch() = default;
2646 
// Constructor from an LLVM state: s is moved into m_llvm_state.
// No other data member is set up explicitly here.
template <typename T>
continuous_output_batch<T>::continuous_output_batch(llvm_state &&s) : m_llvm_state(std::move(s))
{
}
2651 
2652 template <typename T>
continuous_output_batch(const continuous_output_batch & o)2653 continuous_output_batch<T>::continuous_output_batch(const continuous_output_batch &o)
2654     : m_batch_size(o.m_batch_size), m_llvm_state(o.m_llvm_state), m_tcs(o.m_tcs), m_times_hi(o.m_times_hi),
2655       m_times_lo(o.m_times_lo), m_output(o.m_output), m_tmp_tm(o.m_tmp_tm)
2656 {
2657     // If o is valid, fetch the function pointer from the copied state.
2658     // Otherwise, m_f_ptr will remain null.
2659     if (o.m_f_ptr != nullptr) {
2660         m_f_ptr = reinterpret_cast<fptr_t>(m_llvm_state.jit_lookup("c_out"));
2661     }
2662 }
2663 
// Move constructor (defaulted, non-throwing).
template <typename T>
continuous_output_batch<T>::continuous_output_batch(continuous_output_batch &&) noexcept = default;
2666 
// Destructor (defaulted).
template <typename T>
continuous_output_batch<T>::~continuous_output_batch() = default;
2669 
2670 template <typename T>
operator =(const continuous_output_batch & o)2671 continuous_output_batch<T> &continuous_output_batch<T>::operator=(const continuous_output_batch &o)
2672 {
2673     if (this != &o) {
2674         *this = continuous_output_batch(o);
2675     }
2676 
2677     return *this;
2678 }
2679 
// Move assignment (defaulted, non-throwing).
template <typename T>
continuous_output_batch<T> &continuous_output_batch<T>::operator=(continuous_output_batch &&) noexcept = default;
2682 
2683 template <typename T>
call_impl(const T * t)2684 void continuous_output_batch<T>::call_impl(const T *t)
2685 {
2686     using std::isfinite;
2687 
2688     if (m_f_ptr == nullptr) {
2689         throw std::invalid_argument("Cannot use a default-constructed continuous_output_batch object");
2690     }
2691 
2692     // NOTE: run the assertions only after ensuring this
2693     // is a valid object.
2694 
2695     // LCOV_EXCL_START
2696 #if !defined(NDEBUG)
2697     // The batch size must not be zero.
2698     assert(m_batch_size != 0u);
2699     // m_batch_size must divide m_output exactly.
2700     assert(m_output.size() % m_batch_size == 0u);
2701     // m_tmp_tm must be of size m_batch_size.
2702     assert(m_tmp_tm.size() == m_batch_size);
2703     // m_batch_size must divide the time and tcs vectors exactly.
2704     assert(m_times_hi.size() % m_batch_size == 0u);
2705     assert(m_tcs.size() % m_batch_size == 0u);
2706     // Need at least 3 time points (2 + 1 for padding).
2707     assert(m_times_hi.size() / m_batch_size >= 3u);
2708     // hi/lo parts of times must have the same sizes.
2709     assert(m_times_hi.size() == m_times_lo.size());
2710 #endif
2711     // LCOV_EXCL_STOP
2712 
2713     // Copy over the times to the temp buffer and check that they are finite.
2714     // NOTE: this copy ensures we avoid aliasing issues with the
2715     // other data members.
2716     for (std::uint32_t i = 0; i < m_batch_size; ++i) {
2717         if (!isfinite(t[i])) {
2718             throw std::invalid_argument("Cannot compute the continuous output in batch mode "
2719                                         "for the batch index {} at the non-finite time {}"_format(i, t[i]));
2720         }
2721 
2722         m_tmp_tm[i] = t[i];
2723     }
2724 
2725     m_f_ptr(m_output.data(), m_tmp_tm.data(), m_tcs.data(), m_times_hi.data(), m_times_lo.data());
2726 }
2727 
2728 template <typename T>
operator ()(const std::vector<T> & tm)2729 const std::vector<T> &continuous_output_batch<T>::operator()(const std::vector<T> &tm)
2730 {
2731     if (m_f_ptr == nullptr) {
2732         throw std::invalid_argument("Cannot use a default-constructed continuous_output_batch object");
2733     }
2734 
2735     if (tm.size() != m_batch_size) {
2736         throw std::invalid_argument(
2737             "An invalid time vector was passed to the call operator of continuous_output_batch: the "
2738             "vector size is {}, but a size of {} was expected instead"_format(tm.size(), m_batch_size));
2739     }
2740 
2741     return (*this)(tm.data());
2742 }
2743 
// Read-only access to the internal LLVM state.
template <typename T>
const llvm_state &continuous_output_batch<T>::get_llvm_state() const
{
    return m_llvm_state;
}
2749 
// Fetch the batch size.
// NOTE: unlike other getters in this file, no check on m_f_ptr is performed here.
template <typename T>
std::uint32_t continuous_output_batch<T>::get_batch_size() const
{
    return m_batch_size;
}
2755 
2756 template <typename T>
save(boost::archive::binary_oarchive & ar,unsigned) const2757 void continuous_output_batch<T>::save(boost::archive::binary_oarchive &ar, unsigned) const
2758 {
2759     ar << m_batch_size;
2760     ar << m_llvm_state;
2761     ar << m_tcs;
2762     ar << m_times_hi;
2763     ar << m_times_lo;
2764     ar << m_output;
2765     ar << m_tmp_tm;
2766 }
2767 
2768 template <typename T>
load(boost::archive::binary_iarchive & ar,unsigned)2769 void continuous_output_batch<T>::load(boost::archive::binary_iarchive &ar, unsigned)
2770 {
2771     ar >> m_batch_size;
2772     ar >> m_llvm_state;
2773     ar >> m_tcs;
2774     ar >> m_times_hi;
2775     ar >> m_times_lo;
2776     ar >> m_output;
2777     ar >> m_tmp_tm;
2778 
2779     // NOTE: if m_output is not empty, it means the archived
2780     // object had been initialised.
2781     if (m_output.empty()) {
2782         m_f_ptr = nullptr;
2783     } else {
2784         m_f_ptr = reinterpret_cast<fptr_t>(m_llvm_state.jit_lookup("c_out"));
2785     }
2786 }
2787 
2788 template <typename T>
get_bounds() const2789 std::pair<std::vector<T>, std::vector<T>> continuous_output_batch<T>::get_bounds() const
2790 {
2791     if (m_f_ptr == nullptr) {
2792         throw std::invalid_argument("Cannot use a default-constructed continuous_output_batch object");
2793     }
2794 
2795     std::vector<T> lb, ub;
2796     lb.resize(boost::numeric_cast<decltype(lb.size())>(m_batch_size));
2797     ub.resize(boost::numeric_cast<decltype(ub.size())>(m_batch_size));
2798 
2799     for (std::uint32_t i = 0; i < m_batch_size; ++i) {
2800         lb[i] = m_times_hi[i];
2801         // NOTE: take into account the padding.
2802         ub[i] = m_times_hi[m_times_hi.size() - 2u * m_batch_size + i];
2803     }
2804 
2805     return std::make_pair(std::move(lb), std::move(ub));
2806 }
2807 
2808 template <typename T>
get_n_steps() const2809 std::size_t continuous_output_batch<T>::get_n_steps() const
2810 {
2811     if (m_f_ptr == nullptr) {
2812         throw std::invalid_argument("Cannot use a default-constructed continuous_output_batch object");
2813     }
2814 
2815     // NOTE: account for padding.
2816     return boost::numeric_cast<std::size_t>(m_times_hi.size() / m_batch_size - 2u);
2817 }
2818 
// Explicit instantiations of continuous_output_batch for the floating-point types
// used in this file. The mppp::real128 instantiation is available only
// if HEYOKA_HAVE_REAL128 is defined.
template class continuous_output_batch<double>;
template class continuous_output_batch<long double>;

#if defined(HEYOKA_HAVE_REAL128)

template class continuous_output_batch<mppp::real128>;

#endif
2827 
2828 namespace detail
2829 {
2830 
2831 template <typename T>
c_out_batch_stream_impl(std::ostream & os,const continuous_output_batch<T> & co)2832 std::ostream &c_out_batch_stream_impl(std::ostream &os, const continuous_output_batch<T> &co)
2833 {
2834     std::ostringstream oss;
2835     oss.exceptions(std::ios_base::failbit | std::ios_base::badbit);
2836     oss.imbue(std::locale::classic());
2837     oss << std::showpoint;
2838     oss.precision(std::numeric_limits<T>::max_digits10);
2839 
2840     if (co.get_output().empty()) {
2841         oss << "Default-constructed continuous_output_batch";
2842     } else {
2843         const auto batch_size = co.m_batch_size;
2844 
2845         oss << "Directions : [";
2846         for (std::uint32_t i = 0; i < batch_size; ++i) {
2847             const detail::dfloat<T> df_t_start(co.m_times_hi[i], co.m_times_lo[i]),
2848                 df_t_end(co.m_times_hi[co.m_times_hi.size() - 2u * batch_size + i],
2849                          co.m_times_lo[co.m_times_lo.size() - 2u * batch_size + i]);
2850             const auto dir = df_t_start < df_t_end;
2851 
2852             oss << (dir ? "forward" : "backward");
2853 
2854             if (i != batch_size - 1u) {
2855                 oss << ", ";
2856             }
2857         }
2858         oss << "]\n";
2859 
2860         oss << "Time ranges: [";
2861         for (std::uint32_t i = 0; i < batch_size; ++i) {
2862             const detail::dfloat<T> df_t_start(co.m_times_hi[i], co.m_times_lo[i]),
2863                 df_t_end(co.m_times_hi[co.m_times_hi.size() - 2u * batch_size + i],
2864                          co.m_times_lo[co.m_times_lo.size() - 2u * batch_size + i]);
2865             const auto dir = df_t_start < df_t_end;
2866             oss << (dir ? "[{}, {})"_format(df_t_start.hi, df_t_end.hi)
2867                         : "({}, {}]"_format(df_t_end.hi, df_t_start.hi));
2868 
2869             if (i != batch_size - 1u) {
2870                 oss << ", ";
2871             }
2872         }
2873         oss << "]\n";
2874 
2875         oss << "N of steps : " << co.get_n_steps() << '\n';
2876     }
2877 
2878     return os << oss.str();
2879 }
2880 
2881 } // namespace detail
2882 
2883 template <>
operator <<(std::ostream & os,const continuous_output_batch<double> & co)2884 std::ostream &operator<<(std::ostream &os, const continuous_output_batch<double> &co)
2885 {
2886     return detail::c_out_batch_stream_impl(os, co);
2887 }
2888 
2889 template <>
operator <<(std::ostream & os,const continuous_output_batch<long double> & co)2890 std::ostream &operator<<(std::ostream &os, const continuous_output_batch<long double> &co)
2891 {
2892     return detail::c_out_batch_stream_impl(os, co);
2893 }
2894 
2895 #if defined(HEYOKA_HAVE_REAL128)
2896 
2897 template <>
operator <<(std::ostream & os,const continuous_output_batch<mppp::real128> & co)2898 std::ostream &operator<<(std::ostream &os, const continuous_output_batch<mppp::real128> &co)
2899 {
2900     return detail::c_out_batch_stream_impl(os, co);
2901 }
2902 
2903 #endif
2904 
2905 } // namespace heyoka
2906